首页人工智能Pytorch6.Tensor的Broa...

6.Tensor的Broadcasting、拼接与拆分、基本运算

1.Tensor的Broadcasting操作:

Broadcasting跟pandas中的广播操作一样,能够自动实现维度扩展,并且不需要对数据进行拷贝,进而能节省内存。

  • 从最后面的维度开始匹配。
  • 在前面插入若干维度。
  • 将维度的size从1通过expand变到和某个Tensor相同的维度。

总之,Broadcasting操作就是自动实现了若干unsqueeze和expand操作,以使两个tensor的shape一致,从而完成某些操作(往往是加法)。

# Situation 1:(1,32,1,1)-->(4,32,14,14),先复制再叠加
>>> a = torch.rand(4,32,14,14)
>>> b = torch.rand(1,32,1,1)
>>> a.shape
torch.Size([4, 32, 14, 14])
>>> b.shape
torch.Size([1, 32, 1, 1])
>>> (a+b).shape
torch.Size([4, 32, 14, 14])
# Situation 2:(14,14)-->(1,1,14,14)-->(4,32,14,14),先扩展在复制再叠加
>>> c = torch.rand(14,14)
>>> c.shape
torch.Size([14, 14])
>>> a.shape
torch.Size([4, 32, 14, 14])
>>> (a+c).shape
torch.Size([4, 32, 14, 14])
# Situation 3:(4,32,14,14)+(2,32,14,14),此时不符合Broadcasting操作规则,因此会报错
# dim 1 有dim并且不是size 1,不能insert和expand到相同维度
>>> d = torch.rand(2,32,14,14)
>>> d.shape
torch.Size([2, 32, 14, 14])
>>> a.shape
torch.Size([4, 32, 14, 14])
>>> (a+d).shaoe
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 0

2.Tensor的拼接(merge)与拆分(split):

  • 维度合并cat

cat()操作需要保证其它维度的size是相同的。

>>> a = torch.rand(4,32,8)
>>> b = torch.rand(5,32,8)
# 第一个参数为包含所有需要操作的tensor的list,第二个参数dim决定了在哪个维度进行合并
>>> torch.cat([a,b],dim=0).shape
torch.Size([9, 32, 8])
  • 合并新增stack

stack()操作需要保证两个tensor的shape是一致的,这就像是有两个种类的东西,它们的属性都是一样的,stack()会在指定的维度位置置前插入一个新的维度,因为是两类东西合并过来的所以这个新的维度size是2,通过指定这个维度是0或者1来分别选择。

>>> a = torch.rand(4,3,32,32)
>>> b = torch.rand(4,3,32,32)
>>> torch.stack([a,b],dim=0).shape
torch.Size([2, 4, 3, 32, 32])
>>> torch.stack([a,b],dim=2).shape
torch.Size([4, 3, 2, 32, 32])
  • 按照size的长度拆分split

对一个tensor而言,在给定的dim上进行拆分。

# Situation 1:
>>> a = torch.rand(2,4,3,32,32)
>>> a.shape
torch.Size([2, 4, 3, 32, 32])
>>> a1,a2 = a.split(1,dim=0)
>>> a1.shape
torch.Size([1, 4, 3, 32, 32])
>>> a2.shape
torch.Size([1, 4, 3, 32, 32])
# Situation 2:
>>> a = torch.rand(2,4,3,32,32)
>>> a.shape
torch.Size([2, 4, 3, 32, 32])
>>> a1,a2 = a.split([1,2],dim=2)
>>> a1.shape
torch.Size([2, 4, 1, 32, 32])
>>> a2.shape
torch.Size([2, 4, 2, 32, 32])
  • 按照份数等量拆分chunk

给定在指定的维度上要拆分得的份数,就会按照指定的份数尽量等量地进行拆分。函数返回的是由拆分后的数组组成的元祖类型;如果输入的数组再指定维度下不能被整除,则拆分得到的最后一个数组的dim维度大小将小于前面数组的维度大小。

chunk(input , chunks , dim)

input:待拆分的数组;

chunks:拆分数量;

dim:拆分的维度,默认沿第一维度拆分。

chunks有最大值限制,如果指定的拆分数量大于最大值限制,则只拆分为最大值数量的数组,chunks最大值的计算如下(在指定dim上大小为a):

# Situation 1:
>>> a = torch.rand(2,5,32,32)
>>> a.shape
torch.Size([2, 5, 32, 32])
>>> a1,a2,a3,a4 = a.chunk(4,dim=2)
>>> a1.shape,a2.shape,a3.shape,a4.shape
(torch.Size([2, 5, 8, 32]), torch.Size([2, 5, 8, 32]), torch.Size([2, 5, 8, 32]), torch.Size([2, 5, 8, 32]))
# Situation 2:
>>> a = torch.rand(2,5,32,32)
>>> a.shape
torch.Size([2, 5, 32, 32])
>>> a1,a2,a3 = a.chunk(3,dim=1)
>>> a1.shape,a2.shape,a3.shape
(torch.Size([2, 2, 32, 32]), torch.Size([2, 2, 32, 32]), torch.Size([2, 1, 32, 32]))

3.Tensor的基本运算:

  • 加add/减sub/乘mul/除div

可以看到通过使用运算符号和使用函数方法,得到的结果是一致的。

>>> a = torch.rand(3,4)
>>> b = torch.rand(4)
>>> torch.all(torch.eq(a+b,torch.add(a,b)))
tensor(True)
>>> torch.all(torch.eq(a-b,torch.sub(a,b)))
tensor(True)
>>> torch.all(torch.eq(a*b,torch.mul(a,b)))
tensor(True)
>>> torch.all(torch.eq(a/b,torch.div(a,b)))
tensor(True)
  • 矩阵相乘matmul
    • torch.mm(只适合于dim=2的tensor,因此不推荐使用);
    • torch.matmul(),推荐使用此方法;
    • 符号@。

dim=2时:

>>> a = torch.rand(2,2)
>>> a
tensor([[0.5080, 0.8932],
        [0.5370, 0.2404]])
>>> b = torch.ones(2,2)
>>> b
tensor([[1., 1.],
        [1., 1.]])
>>> torch.mm(a,b)
tensor([[1.4011, 1.4011],
        [0.7774, 0.7774]])
>>> torch.matmul(a,b)
tensor([[1.4011, 1.4011],
        [0.7774, 0.7774]])
>>> a@b
tensor([[1.4011, 1.4011],
        [0.7774, 0.7774]])

dim>2时:
对于高维的tensor,其矩阵乘法仅在最后的两个维度上,这就要求前面的维度必须保持一致。

# Situation 1:前面维度一致
>>> a = torch.rand(4,3,28,64)
>>> b = torch.rand(4,3,64,32)
>>> torch.matmul(a,b).shape
torch.Size([4, 3, 28, 32])
# Situation 2:前面维度不一致,但是符合Broadcasting机制,会自动做广播,然后相乘
>>> a = torch.rand(4,3,28,64)
>>> b = torch.rand(4,1,64,32)
>>> torch.matmul(a,b).shape
torch.Size([4, 3, 28, 32])
  • 幂运算pow/平方根运算sqrt/平方根倒数运算rsqrt
# 使用full()创建一个全部都是3的,shape是2*2的tensor
>>> a = torch.full([2,2],3)
>>> a
tensor([[3., 3.],
        [3., 3.]])
# 幂运算
>>> a.pow(3)
tensor([[27., 27.],
        [27., 27.]])
>>> a**3
tensor([[27., 27.],
        [27., 27.]])
# 平方根运算
>>> a.sqrt()
tensor([[1.7321, 1.7321],
        [1.7321, 1.7321]])
>>> a**0.5
tensor([[1.7321, 1.7321],
        [1.7321, 1.7321]])
# 平方根倒数运算
>>> a.rsqrt()
tensor([[0.5774, 0.5774],
        [0.5774, 0.5774]])
  • 指数运算/对数运算
>>> a = torch.ones(2,2)
>>> a
tensor([[1., 1.],
        [1., 1.]])
# e为底的指数函数运算
>>> b = torch.exp(a)
>>> b
tensor([[2.7183, 2.7183],
        [2.7183, 2.7183]])
# 对数运算,log是以自然对数为底数的,以2为底的用log2,以10为底的用log10
>>> torch.log(b)
tensor([[1., 1.],
        [1., 1.]])
  • 近似值运算floor/ceil/trunc/frac/round
>>> a = torch.tensor(3.14)
>>> a
tensor(3.1400)
# floor()往下取整数,a.ceil()往上取整数,a.trunc()取数据的整数部分,a.frac()取数据的小数部分
>>> a.floor(),a.ceil(),a.trunc(),a.frac()
(tensor(3.), tensor(4.), tensor(3.), tensor(0.1400))
# round()为四舍五入
>>> a = torch.tensor(3.499)
>>> a.round()
tensor(3.)
>>> a = torch.tensor(3.5)
>>> a.round()
tensor(4.)
  • 裁剪clamp

对tensor中的元素进行范围过滤,不符合条件的可以把它变换到范围内部(边界)上,常用于梯度裁剪(gradient clipping),即在发生梯度离散或者梯度爆炸时对梯度的处理。

>>> grade = torch.rand(2,3)*15
>>> grade
tensor([[6.8350, 1.1972, 9.4472],
        [9.2184, 6.6084, 8.2121]])
>>> grade.max()
tensor(9.4472)
>>> grade.median()
tensor(6.8350)
# clamp(7)将<7的数据变为7
>>> grade.clamp(7)
tensor([[7.0000, 7.0000, 9.4472],
        [9.2184, 7.0000, 8.2121]])
# clamp(0,7)将数据限定在(2,7),将<2的数据变为2,将>7的数据变为7
>>> grade.clamp(2,7)
tensor([[6.8350, 2.0000, 7.0000],
        [7.0000, 6.6084, 7.0000]])
RELATED ARTICLES

欢迎留下您的宝贵建议

Please enter your comment!
Please enter your name here

- Advertisment -

Most Popular

Recent Comments