首页人工智能Pytorch5.Tensor的维度变换...

5.Tensor的维度变换

1.维度变换:

调整Tensor的shape,新版本中提供了reshape()方法,在老版本中这个函数是view(),两者功能上都是一样的。

>>> a = torch.rand(4,1,28,28)
>>> a.shape
torch.Size([4, 1, 28, 28])
>>> a.reshape(4,28*28)
tensor([[0.8371, 0.9728, 0.5271,  ..., 0.3528, 0.8527, 0.6425],
        [0.5836, 0.0016, 0.3612,  ..., 0.8342, 0.5880, 0.9205],
        [0.6959, 0.9789, 0.6670,  ..., 0.4895, 0.2700, 0.4503],
        [0.0050, 0.3783, 0.7148,  ..., 0.4339, 0.3403, 0.2258]])
>>> a.reshape(4,28*28).shape
torch.Size([4, 784])
>>> a.reshape(4*28,28).shape
torch.Size([112, 28])
>>> a.reshape(4*1,28,28).shape
torch.Size([4, 28, 28])

2.增加与删减维度:

  • unsqueeze,增加维度

正数表示在该维度原本的位置前面插入这个新增加的维度,负数表示在该维度原本的位置之后插入。

>>> a.shape
torch.Size([4, 1, 28, 28])
>>> a.unsqueeze(0).shape
torch.Size([1, 4, 1, 28, 28])
>>> a.unsqueeze(-1).shape
torch.Size([4, 1, 28, 28, 1])
>>> a.unsqueeze(4).shape
torch.Size([4, 1, 28, 28, 1])
>>> a.unsqueeze(-4).shape
torch.Size([4, 1, 1, 28, 28])
>>> a.unsqueeze(-5).shape
torch.Size([1, 4, 1, 28, 28])
# 数据维度范围为[-a.dim()-1,a.dim()+1)
# 因为5不在范围内,所以会报错
>>> a.unsqueeze(5).shape
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Dimension out of range (expected to be in range of [-5, 4], but got 5)
# 此时生成了一个dim 0的tensor
>>> a = torch.tensor([1.2,3.4])
>>> a.shape
torch.Size([2])
# 在最后位置的后面增加一个维度,shape将从[2]变为[2, 1]
>>> a.unsqueeze(-1)
tensor([[1.2000],
        [3.4000]])
>>> a.unsqueeze(-1).shape
torch.Size([2, 1])
# 在0位置之前增加一个维度,shape将从[2]变为[2, 1]
>>> a.unsqueeze(0)
tensor([[1.2000, 3.4000]])
>>> a.unsqueeze(0).shape
torch.Size([1, 2])
  • squeeze,删减维度

删减维度实际上是一个压榨的过程,直观地看是把那些多余的[]给去掉,也就是只是去删除那些size=1的维度。

>>> a = torch.rand(1,32,1,1)
>>> a.shape
torch.Size([1, 32, 1, 1])
# .squeeze()不给参数时,会将能删减的都删减掉(dim的size为1的位置)
>>> a.squeeze().shape
torch.Size([32])
# 给定具体的维度进行删减
>>> a.squeeze(0).shape
torch.Size([32, 1, 1])
>>> a.squeeze(-1).shape
torch.Size([1, 32, 1])
# .squeeze()只会删减size为1的位置,因为此时size为32,因此不会被删减,也不会报错
>>> a.squeeze(1).shape
torch.Size([1, 32, 1, 1])
>>> a.squeeze(-4).shape
torch.Size([32, 1, 1])

3.扩展维度:

  • expand,不会主动的复制数据,因此执行速度快并且节约内存,推荐使用此方法

expand就是在某个size=1的维度上改变size,改成更大的一个size,实际就是在每个size=1的维度上的标量的广播操作。

>>> a = torch.rand(4,32,14,14)
>>> b = torch.rand(1,32,1,1)
>>> a.shape
torch.Size([4, 32, 14, 14])
>>> b.shape
torch.Size([1, 32, 1, 1])
# .expand()只能扩展size=1的维度,给定的参数是新的size
>>> b.expand(4,32,14,14).shape
torch.Size([4, 32, 14, 14])
# -1表示不改变该维度的size
>>> b.expand(-1,32,-1,14).shape
torch.Size([1, 32, 1, 14])
# -4
>>> b.expand(-1,32,-1,-4).shape
torch.Size([1, 32, 1, -4])
<!-- wp:code {"lineNumbers":true} -->
<pre class="wp-block-code"><code lang="python" class="language-python line-numbers">>>> a = torch.rand(4,32,14,14)
>>> b = torch.rand(1,32,1,1)
>>> a.shape
torch.Size([4, 32, 14, 14])
>>> b.shape
torch.Size([1, 32, 1, 1])
# 使用expand_as()扩展为与其他张量相同的大小
>>> b.expand_as(a).shape
torch.Size([4, 32, 14, 14])
# .expand()只能扩展size=1的维度,给定的参数是新的size
>>> b.expand(4,32,14,14).shape
torch.Size([4, 32, 14, 14])
# -1表示不改变该维度的size
>>> b.expand(-1,32,-1,14).shape
torch.Size([1, 32, 1, 14])
# 一个小bug
>>> b.expand(-1,32,-1,-4).shape
torch.Size([1, 32, 1, -4])
  • repeat,该方式会重新申请内存空间,主动复制数据

repeat就是将每个位置的维度都重复至指定的次数,以形成新的Tensor。

>>> b.shape
torch.Size([1, 32, 1, 1])
# .repeat()方法给定的参数是每一个dim要repeat 拷贝的次数
>>> b.repeat(4,32,1,1).shape
torch.Size([4, 1024, 1, 1])
>>> b.repeat(4,1,1,1).shape
torch.Size([4, 32, 1, 1])
>>> b.repeat(4,1,32,32).shape
torch.Size([4, 32, 32, 32])

4.转置:

需注意的是转置方法只适用于dim=2的Tensor。

>>> a = torch.randn(3,4)
>>> a
tensor([[ 1.1570, -0.1703, -0.6097,  0.4807],
        [ 0.8426, -0.5735, -1.9075,  0.3949],
        [ 0.4538, -1.5075,  0.2424,  1.1875]])
>>> a.shape
torch.Size([3, 4])
>>> a.t()
tensor([[ 1.1570,  0.8426,  0.4538],
        [-0.1703, -0.5735, -1.5075],
        [-0.6097, -1.9075,  0.2424],
        [ 0.4807,  0.3949,  1.1875]])
>>> a.t().shape
torch.Size([4, 3])
# 当dim>2时,进行转置操作会报错
>>> b = torch.randn(3,4,3)
>>> b.shape
torch.Size([3, 4, 3])
>>> b.t()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: t() expects a 2D tensor, but self is 3D

5.维度交换:

注意这种交换使得存储不再连续,再执行一些reshape的操作肯定是执行不了的,所以要调用一下contiguous()使其变成连续的维度。

>>> a = torch.randn(1,2,3,4)
>>> a.shape
torch.Size([1, 2, 3, 4])
>>> a.transpose(1,3).shape
torch.Size([1, 4, 3, 2])

因为持续的转置操作可能会导致数据变化,为了检验变换后的数据是否与原数据一致,可以采用eq()方法。

>>> a = torch.rand(3,4,5,6)
>>> b = a.transpose(1,3).contiguous().reshape(3,6*5*4).reshape(3,6,5,4).transpose(1,3)
>>> b.shape
torch.Size([3, 4, 5, 6])
>>> a.shape
torch.Size([3, 4, 5, 6])
# 使用eq()来比较数据内容是否一致,all()函数是全部数据都一致时才返回true(1)
>>> torch.all(torch.eq(a,b))
tensor(True)

6.序列改变permute:

permute()方法可以直接指定维度所处的新位置,相比transpose()方法可以很方便的操作。
例如四个维度为[batch,channel,h,w],如果想把channel放到最后的位置,形成[batch,h,w,channel],那么如果使用transpose()方法,至少要交换两次(先13交换再12交换),而使用permute()可以很方便的操作。

>>> a = torch.rand(1,2,3,4)
>>> a.shape
torch.Size([1, 2, 3, 4])
# .permute()方法的参数为所需的原来维度的位置
# 第一个位置放原来的0维,第二个位置放原来的2维,以此类推
>>> a.permute(0,2,3,1).shape
torch.Size([1, 3, 4, 2])
RELATED ARTICLES

欢迎留下您的宝贵建议

Please enter your comment!
Please enter your name here

- Advertisment -

Most Popular

Recent Comments