本系列文章配套代码获取有以下两种途径:
-
通过百度网盘获取:
链接:https://pan.baidu.com/s/1XuxKa9_G00NznvSK0cr5qw?pwd=mnsj
提取码:mnsj
-
前往GitHub获取:
https://github.com/returu/PyTorch
Tensor的维度:
-
Tensor的维度计算方法:
-
shape属性:shape是Tensor对象的属性,它返回一个表示Tensor维度的元组。 -
size()方法:size()是Tensor对象的一个方法,它也可以用来获取Tensor的维度信息。
除此之外,还有以下两种方式:
-
tensor.dim():用来查看Tensor的维度,等价于len(tensor.shape); -
tensor.numel():查看Tensor中元素的数量。
>>> a = torch.rand(2,3,28,28)
# shape属性
>>> a.shape
torch.Size([2, 3, 28, 28])
# size()方法
>>> a.size()
torch.Size([2, 3, 28, 28])
# dim()方法
>>> a.dim()
4
# numel()方法
>>> a.numel() # 2*3*28*28
4704
维度变换操作:
-
view() / reshape():这两个函数用于改变张量的形状,而不改变其数据。view()在形状兼容的情况下更为灵活,而reshape()在某些情况下会返回一个新的副本。 -
squeeze() / unsqueeze():squeeze()用于移除张量中所有大小为1的维度,而unsqueeze()则用于在指定位置添加一个新的大小为1的维度。 -
expand() / broadcast_to():这些函数用于扩展张量的维度以匹配另一个张量的形状,常用于广播操作。 -
flatten() / ravel():这两个函数用于将多维张量展平成一维张量,但flatten()通常更常用。
-
view() / reshape():
>>> a = torch.rand(2,3,28,28)
>>> a.reshape(2,3,28*28).shape
torch.Size([2, 3, 784])
>>> a.view(2,3,-1).shape
torch.Size([2, 3, 784])
>>> a = torch.rand(2,3,28,28)
# 经过维度交换操作后,会使得tensor不再连续
>>> b = a.transpose(1,3)
>>> b.shape
torch.Size([2, 28, 28, 3])
# reshape方法仍然适用
>>> b.reshape(2,-1)
tensor([[0.2235, 0.1918, 0.9721, ..., 0.2302, 0.4788, 0.8718],
[0.6034, 0.8511, 0.1370, ..., 0.0080, 0.6528, 0.6254]])
# view方法会报错
>>> b.view(2,-1)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: view size is not compatible with input tensor s size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
# 通过contiguous()方法将tensor变得连续,再使用view方法
>>> b.contiguous().view(2,-1)
tensor([[0.2235, 0.1918, 0.9721, ..., 0.2302, 0.4788, 0.8718],
[0.6034, 0.8511, 0.1370, ..., 0.0080, 0.6528, 0.6254]])
-
squeeze() / unsqueeze():
>>> a = torch.rand(2,3,28,28)
>>> a.shape
torch.Size([2, 3, 28, 28])
>>> a.unsqueeze(0).shape
torch.Size([1, 2, 3, 28, 28])
>>> a.unsqueeze(-1).shape
torch.Size([2, 3, 28, 28, 1])
>>> a = torch.rand(1,32,1,1)
>>> a.shape
torch.Size([1, 32, 1, 1])
# 不给参数时,会将能删减的都删减掉(dim的size为1的位置)
>>> a.squeeze().shape
torch.Size([32])
# 给定具体的维度进行删减
>>> a.squeeze(dim=0).shape
torch.Size([32, 1, 1])
# squeeze()只会删减size为1的位置,因为此时为32,因此不会被删减,也不会报错
>>> a.squeeze(1).shape
torch.Size([1, 32, 1, 1])
-
expand() / broadcast_to():
Tensor的Broadcasting操作跟pandas中的广播操作一样,能够自动实现维度扩展,并且不需要对数据进行拷贝,进而能节省内存。
-
从最后面的维度开始匹配。
-
在前面插入若干维度。
-
将维度的size从1通过expand变到和某个Tensor相同的维度。
>>> a = torch.rand(1,3,1,1)
# 只能扩展size=1的维度,给定的参数是新的size
>>> a.expand(4,3,28,28).shape
torch.Size([4, 3, 28, 28])
# -1表示不改变该维度的size
>>> a.expand(-1,3,28,28).shape
torch.Size([1, 3, 28, 28])
# 使用expand_as()扩展为与其他张量相同的大小
>>> x = torch.rand(4,3,28,28)
>>> y = torch.rand(1,3,1,1)
>>> y.expand_as(x).shape
torch.Size([4, 3, 28, 28])
>>> a = torch.rand(1,3,1,1)
# 只能扩展size=1的维度,给定的参数是新的size
>>> torch.broadcast_to(a , (4,-1,28,28)).shape
torch.Size([4, 3, 28, 28])
# -1表示不改变该维度的size
>>> torch.broadcast_to(a , (-1,3,28,28)).shape
torch.Size([1, 3, 28, 28])
-
flatten() / ravel():
torch.ravel(input)
>>> t = torch.arange(1,7).reshape(2,3)
>>> t
tensor([[1, 2, 3],
[4, 5, 6]])
>>> t.ravel()
tensor([1, 2, 3, 4, 5, 6])
torch.flatten(input, start_dim=0, end_dim=-1)
>>> t = torch.arange(1,7).reshape(2,3)
>>> t
tensor([[1, 2, 3],
[4, 5, 6]])
# 展平成一维
>>> t.flatten()
tensor([1, 2, 3, 4, 5, 6])
# 指定展平维度
>>> t = torch.arange(1,121).reshape(2,3,4,5)
>>> t.flatten(start_dim=1 , end_dim=2).shape
torch.Size([2, 12, 5])
更多内容可以前往官网查看:
https://pytorch.org/
本篇文章来源于微信公众号: 码农设计师