1.维度变换:
调整Tensor的shape,新版本中提供了reshape()方法,在老版本中这个函数是view(),两者功能上都是一样的。
>>> a = torch.rand(4,1,28,28)
>>> a.shape
torch.Size([4, 1, 28, 28])
>>> a.reshape(4,28*28)
tensor([[0.8371, 0.9728, 0.5271, ..., 0.3528, 0.8527, 0.6425],
[0.5836, 0.0016, 0.3612, ..., 0.8342, 0.5880, 0.9205],
[0.6959, 0.9789, 0.6670, ..., 0.4895, 0.2700, 0.4503],
[0.0050, 0.3783, 0.7148, ..., 0.4339, 0.3403, 0.2258]])
>>> a.reshape(4,28*28).shape
torch.Size([4, 784])
>>> a.reshape(4*28,28).shape
torch.Size([112, 28])
>>> a.reshape(4*1,28,28).shape
torch.Size([4, 28, 28])
2.增加与删减维度:
- unsqueeze,增加维度
正数表示在该维度原本的位置前面插入这个新增加的维度,负数表示在该维度原本的位置之后插入。
>>> a.shape
torch.Size([4, 1, 28, 28])
>>> a.unsqueeze(0).shape
torch.Size([1, 4, 1, 28, 28])
>>> a.unsqueeze(-1).shape
torch.Size([4, 1, 28, 28, 1])
>>> a.unsqueeze(4).shape
torch.Size([4, 1, 28, 28, 1])
>>> a.unsqueeze(-4).shape
torch.Size([4, 1, 1, 28, 28])
>>> a.unsqueeze(-5).shape
torch.Size([1, 4, 1, 28, 28])
# 数据维度范围为[-a.dim()-1,a.dim()+1)
# 因为5不在范围内,所以会报错
>>> a.unsqueeze(5).shape
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Dimension out of range (expected to be in range of [-5, 4], but got 5)
# 此时生成了一个dim 0的tensor
>>> a = torch.tensor([1.2,3.4])
>>> a.shape
torch.Size([2])
# 在最后位置的后面增加一个维度,shape将从[2]变为[2, 1]
>>> a.unsqueeze(-1)
tensor([[1.2000],
[3.4000]])
>>> a.unsqueeze(-1).shape
torch.Size([2, 1])
# 在0位置之前增加一个维度,shape将从[2]变为[2, 1]
>>> a.unsqueeze(0)
tensor([[1.2000, 3.4000]])
>>> a.unsqueeze(0).shape
torch.Size([1, 2])
- squeeze,删减维度
删减维度实际上是一个压榨的过程,直观地看是把那些多余的[]
给去掉,也就是只是去删除那些size=1的维度。
>>> a = torch.rand(1,32,1,1)
>>> a.shape
torch.Size([1, 32, 1, 1])
# .squeeze()不给参数时,会将能删减的都删减掉(dim的size为1的位置)
>>> a.squeeze().shape
torch.Size([32])
# 给定具体的维度进行删减
>>> a.squeeze(0).shape
torch.Size([32, 1, 1])
>>> a.squeeze(-1).shape
torch.Size([1, 32, 1])
# .squeeze()只会删减size为1的位置,因为此时size为32,因此不会被删减,也不会报错
>>> a.squeeze(1).shape
torch.Size([1, 32, 1, 1])
>>> a.squeeze(-4).shape
torch.Size([32, 1, 1])
3.扩展维度:
- expand,不会主动的复制数据,因此执行速度快并且节约内存,推荐使用此方法
expand就是在某个size=1的维度上改变size,改成更大的一个size,实际就是在每个size=1的维度上的标量的广播操作。
>>> a = torch.rand(4,32,14,14)
>>> b = torch.rand(1,32,1,1)
>>> a.shape
torch.Size([4, 32, 14, 14])
>>> b.shape
torch.Size([1, 32, 1, 1])
# .expand()只能扩展size=1的维度,给定的参数是新的size
>>> b.expand(4,32,14,14).shape
torch.Size([4, 32, 14, 14])
# -1表示不改变该维度的size
>>> b.expand(-1,32,-1,14).shape
torch.Size([1, 32, 1, 14])
# -4
>>> b.expand(-1,32,-1,-4).shape
torch.Size([1, 32, 1, -4])
<!-- wp:code {"lineNumbers":true} -->
<pre class="wp-block-code"><code lang="python" class="language-python line-numbers">>>> a = torch.rand(4,32,14,14)
>>> b = torch.rand(1,32,1,1)
>>> a.shape
torch.Size([4, 32, 14, 14])
>>> b.shape
torch.Size([1, 32, 1, 1])
# 使用expand_as()扩展为与其他张量相同的大小
>>> b.expand_as(a).shape
torch.Size([4, 32, 14, 14])
# .expand()只能扩展size=1的维度,给定的参数是新的size
>>> b.expand(4,32,14,14).shape
torch.Size([4, 32, 14, 14])
# -1表示不改变该维度的size
>>> b.expand(-1,32,-1,14).shape
torch.Size([1, 32, 1, 14])
# 一个小bug
>>> b.expand(-1,32,-1,-4).shape
torch.Size([1, 32, 1, -4])
- repeat,该方式会重新申请内存空间,主动复制数据
repeat就是将每个位置的维度都重复至指定的次数,以形成新的Tensor。
>>> b.shape
torch.Size([1, 32, 1, 1])
# .repeat()方法给定的参数是每一个dim要repeat 拷贝的次数
>>> b.repeat(4,32,1,1).shape
torch.Size([4, 1024, 1, 1])
>>> b.repeat(4,1,1,1).shape
torch.Size([4, 32, 1, 1])
>>> b.repeat(4,1,32,32).shape
torch.Size([4, 32, 32, 32])
4.转置:
需注意的是转置方法只适用于dim=2的Tensor。
>>> a = torch.randn(3,4)
>>> a
tensor([[ 1.1570, -0.1703, -0.6097, 0.4807],
[ 0.8426, -0.5735, -1.9075, 0.3949],
[ 0.4538, -1.5075, 0.2424, 1.1875]])
>>> a.shape
torch.Size([3, 4])
>>> a.t()
tensor([[ 1.1570, 0.8426, 0.4538],
[-0.1703, -0.5735, -1.5075],
[-0.6097, -1.9075, 0.2424],
[ 0.4807, 0.3949, 1.1875]])
>>> a.t().shape
torch.Size([4, 3])
# 当dim>2时,进行转置操作会报错
>>> b = torch.randn(3,4,3)
>>> b.shape
torch.Size([3, 4, 3])
>>> b.t()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: t() expects a 2D tensor, but self is 3D
5.维度交换:
注意这种交换使得存储不再连续,再执行一些reshape的操作肯定是执行不了的,所以要调用一下contiguous()
使其变成连续的维度。
>>> a = torch.randn(1,2,3,4)
>>> a.shape
torch.Size([1, 2, 3, 4])
>>> a.transpose(1,3).shape
torch.Size([1, 4, 3, 2])
因为持续的转置操作可能会导致数据变化,为了检验变换后的数据是否与原数据一致,可以采用eq()
方法。
>>> a = torch.rand(3,4,5,6)
>>> b = a.transpose(1,3).contiguous().reshape(3,6*5*4).reshape(3,6,5,4).transpose(1,3)
>>> b.shape
torch.Size([3, 4, 5, 6])
>>> a.shape
torch.Size([3, 4, 5, 6])
# 使用eq()来比较数据内容是否一致,all()函数是全部数据都一致时才返回true(1)
>>> torch.all(torch.eq(a,b))
tensor(True)
6.序列改变permute:
permute()方法可以直接指定维度所处的新位置,相比transpose()方法可以很方便的操作。
例如四个维度为[batch,channel,h,w],如果想把channel放到最后的位置,形成[batch,h,w,channel],那么如果使用transpose()方法,至少要交换两次(先13交换再12交换),而使用permute()可以很方便的操作。
>>> a = torch.rand(1,2,3,4)
>>> a.shape
torch.Size([1, 2, 3, 4])
# .permute()方法的参数为所需的原来维度的位置
# 第一个位置放原来的0维,第二个位置放原来的2维,以此类推
>>> a.permute(0,2,3,1).shape
torch.Size([1, 3, 4, 2])