首页人工智能Pytorch【深度学习(PyTorch...

【深度学习(PyTorch篇)】9.Tensor的维度变换


本系列文章配套代码获取有以下两种途径:

  • 通过百度网盘获取:
链接:https://pan.baidu.com/s/1XuxKa9_G00NznvSK0cr5qw?pwd=mnsj 提取码:mnsj
  • 前往GitHub获取
https://github.com/returu/PyTorch





维度变换操作是张量(Tensor)操作中非常基础和重要的一部分,它允许我们重新排列、扩展或缩减张量的维度,以满足不同深度学习模型和数据处理任务的需求。
01

Tensor的维度:


维度(Dimension)是描述Tensor规模和结构的基本属性。简单来说,Tensor的维度就是其索引的数量,也即我们常说的“几维”。
例如,一个标量(Scalar)可以被视为0维Tensor,因为它没有索引;一个向量(Vector)是一维Tensor,因为它有一个索引;一个矩阵(Matrix)是二维Tensor,因为它有两个索引(行和列);更高维的Tensor则有更多的索引。
  • Tensor的维度计算方法:
PyTorch中,查看Tensor的维度主要有两种方法:使用.shape属性或使用size()方法。
  • shape属性:shapeTensor对象的属性,它返回一个表示Tensor维度的元组。
  • size()方法:size()Tensor对象的一个方法,它也可以用来获取Tensor的维度信息。

除此之外,还有以下两种方式:

  • tensor.dim():用来查看Tensor的维度,等价于len(tensor.shape);
  • tensor.numel():查看Tensor中元素的数量
>>> a = torch.rand(2,3,28,28)

# shape属性
>>> a.shape
torch.Size([232828])

# size()方法
>>> a.size()
torch.Size([232828])

# dim()方法
>>> a.dim()
4

# numel()方法
>>> a.numel() # 2*3*28*28
4704


02

维度变换操作:


PyTorch提供了多种维度变换操作:
  • view() / reshape():这两个函数用于改变张量的形状,而不改变其数据。view()在形状兼容的情况下更为灵活,而reshape()在某些情况下会返回一个新的副本。
  • squeeze() / unsqueeze()squeeze()用于移除张量中所有大小为1的维度,而unsqueeze()则用于在指定位置添加一个新的大小为1的维度。
  • expand() / broadcast_to():这些函数用于扩展张量的维度以匹配另一个张量的形状,常用于广播操作。
  • flatten() / ravel():这两个函数用于将多维张量展平成一维张量,但flatten()通常更常用。
  • view() / reshape():
这两个函数用于改变张量的形状,而不改变其数据。
>>> a = torch.rand(2,3,28,28)

>>> a.reshape(2,3,28*28).shape
torch.Size([23784])

>>> a.view(2,3,-1).shape
torch.Size([23784])
view()仅能处理内存空间中连续的Tensor,经过view操作之后的Tensor仍然共享存储区
reshape()首先会把内存空间中不连续的Tensor变成连续的,然后在进行形状变化,等价于tensor.contiguous().view()
>>> a = torch.rand(2,3,28,28)

# 经过维度交换操作后,会使得tensor不再连续
>>> b = a.transpose(1,3)
>>> b.shape
torch.Size([228283])

# reshape方法仍然适用
>>> b.reshape(2,-1)
tensor([[0.22350.19180.9721,  ..., 0.23020.47880.8718],
        [0.60340.85110.1370,  ..., 0.00800.65280.6254]])

# view方法会报错
>>> b.view(2,-1)
Traceback (most recent call last):
  File "<stdin>", line 1in <module>
RuntimeError: view size is not compatible with input tensor s size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

# 通过contiguous()方法将tensor变得连续,再使用view方法
>>> b.contiguous().view(2,-1)
tensor([[0.22350.19180.9721,  ..., 0.23020.47880.8718],
        [0.60340.85110.1370,  ..., 0.00800.65280.6254]])

  • squeeze() / unsqueeze():
unsqueeze()方法用于在张量中增加一个新的维度,这个新维度的大小为1。这在你想要增加张量的维度数时非常有用,例如为了进行广播操作。正数表示在该维度原本的位置前面插入这个新增加的维度,负数表示在该维度原本的位置之后插入。
>>> a = torch.rand(2,3,28,28)
>>> a.shape
torch.Size([232828])

>>> a.unsqueeze(0).shape
torch.Size([1232828])

>>> a.unsqueeze(-1).shape
torch.Size([2328281])
squeeze()用于删减维度,删减维度实际上是一个压榨的过程,直观地看是把那些多余的[]给去掉,也就是只是去删除那些size=1的维度。这在想要减少张量的维度数时非常有用。
>>> a = torch.rand(1,32,1,1)
>>> a.shape
torch.Size([13211])

# 不给参数时,会将能删减的都删减掉(dim的size为1的位置)
>>> a.squeeze().shape
torch.Size([32])

# 给定具体的维度进行删减
>>> a.squeeze(dim=0).shape
torch.Size([3211])

# squeeze()只会删减size为1的位置,因为此时为32,因此不会被删减,也不会报错
>>> a.squeeze(1).shape
torch.Size([13211])

  • expand() / broadcast_to():
两个方法方法都通过张量广播(broadcasting)技术扩展张量的维度,但它们的行为和用法略有不同。

Tensor的Broadcasting操作跟pandas中的广播操作一样,能够自动实现维度扩展,并且不需要对数据进行拷贝,进而能节省内存。

  • 从最后面的维度开始匹配。

  • 在前面插入若干维度。

  • 将维度的size从1通过expand变到和某个Tensor相同的维度。

总之,Broadcasting操作就是自动实现了若干unsqueeze和expand操作,以使两个tensor的shape一致,从而完成某些操作(往往是加法)。

expand()函数是张量对象的成员方法(直接附加到张量对象上的)。
>>> a = torch.rand(1,3,1,1)

# 只能扩展size=1的维度,给定的参数是新的size
>>> a.expand(4,3,28,28).shape
torch.Size([432828])

# -1表示不改变该维度的size
>>> a.expand(-1,3,28,28).shape
torch.Size([132828])

# 使用expand_as()扩展为与其他张量相同的大小
>>> x = torch.rand(4,3,28,28)
>>> y = torch.rand(1,3,1,1)
>>> y.expand_as(x).shape
torch.Size([432828])
broadcast_to()是函数级别的(位于torch模块模块下)。
>>> a = torch.rand(1,3,1,1)

# 只能扩展size=1的维度,给定的参数是新的size
>>> torch.broadcast_to(a , (4,-1,28,28)).shape
torch.Size([432828])

# -1表示不改变该维度的size
>>> torch.broadcast_to(a , (-1,3,28,28)).shape
torch.Size([132828])

  • flatten() / ravel():
两个方法都可以用来对多维张量展平操作。但它们的行为和用法略有不同。
ravel()函数用于将多维张量展平成一维张量,函数语法如下:
torch.ravel(input)
其中,input输入的张量。
>>> t = torch.arange(1,7).reshape(2,3)
>>> t
tensor([[123],
        [456]])

>>> t.ravel()
tensor([123456])
flatten()函数用于将多维张量展平成一维,或者展平指定的连续维度,函数语法如下
torch.flatten(input, start_dim=0, end_dim=-1
其中,input输入的张量,start_dim、end_dim分别为展平开始和结束的维度。
>>> t = torch.arange(1,7).reshape(2,3)
>>> t
tensor([[123],
        [456]])

# 展平成一维
>>> t.flatten()
tensor([123456])

# 指定展平维度
>>> t = torch.arange(1,121).reshape(2,3,4,5)
>>> t.flatten(start_dim=1 , end_dim=2).shape
torch.Size([2125])


更多内容可以前往官网查看

https://pytorch.org/


本篇文章来源于微信公众号: 码农设计师

RELATED ARTICLES

欢迎留下您的宝贵建议

Please enter your comment!
Please enter your name here

- Advertisment -

Most Popular

Recent Comments