本系列文章配套代码获取有以下两种途径:
-
通过百度网盘获取:
链接:https://pan.baidu.com/s/1XuxKa9_G00NznvSK0cr5qw?pwd=mnsj
提取码:mnsj
-
前往GitHub获取:
https://github.com/returu/PyTorch
Tensor的拼接:
-
torch.cat()函数:
-
tensors (sequence of Tensors):需要拼接的张量序列。 -
dim (int):拼接的维度索引。默认值为0,表示在第0维上拼接。
需要注意的是,torch.cat()函数要求所有输入张量在非拼接维度上的形状必须相同。
>>> a = torch.rand(4,32,8)
>>> b = torch.rand(5,32,8)
>>> torch.cat([a,b],dim=0).shape
torch.Size([9, 32, 8])
-
torch.stack()函数:
-
tensors (sequence of Tensors):需要堆叠的张量序列。 -
dim (int):插入新维度的索引。默认值为0,表示在最外层维度堆叠。
需要注意的是,torch.stack()函数要求所有输入张量的形状必须完全相同。
>>> a = torch.rand(4,3,32,32)
>>> b = torch.rand(4,3,32,32)
>>> torch.stack([a,b],dim=0).shape
torch.Size([2, 4, 3, 32, 32])
>>> torch.stack([a,b],dim=2).shape
torch.Size([4, 3, 2, 32, 32])
Tensor的拆分:
-
split()函数:
-
tensor:要拆分的输入张量。 -
split_size_or_sections :如果传入一个整数,表示每个输出张量的大小,此时最后一个张量可能小于这个整数,以确保所有元素都被拆分;如果传入一个整数列表,表示每个输出张量在拆分维度上的索引位置,此时列表中的元素之和必须等于拆分维度的大小。 -
dim:沿着哪个维度进行拆分,默认值为0。
>>> a = torch.rand(6,4,3,32,32)
>>> a.shape
torch.Size([6, 4, 3, 32, 32])
# 按大小为4拆分第0维
>>> a1 , a2 = a.split(4 , dim=0)
>>> a1.shape , a2.shape
(torch.Size([4, 4, 3, 32, 32]), torch.Size([2, 4, 3, 32, 32]))
# 按大小列表拆分第2维
>>> a3 , a4 = torch.split(a , [1,2] , dim=2)
>>> a3.shape , a4.shape
(torch.Size([6, 4, 1, 32, 32]), torch.Size([6, 4, 2, 32, 32]))
-
chunk()函数:
-
tensor:要拆分的输入张量。 -
chunks:要拆分的块数。 -
dim:沿着哪个维度进行拆分,默认值为0。
chunks有最大值限制,如果指定的拆分数量大于最大值限制,则只拆分为最大值数量的数组,chunks最大值的计算如下(在指定dim上大小为a):
>>> a = torch.rand(2,5,32,32)
>>> a.shape
torch.Size([2, 5, 32, 32])
# Situation 1
>>> a1,a2,a3,a4 = torch.chunk(a , 4 , dim=2)
>>> a1.shape,a2.shape,a3.shape,a4.shape
(torch.Size([2, 5, 8, 32]), torch.Size([2, 5, 8, 32]), torch.Size([2, 5, 8, 32]), torch.Size([2, 5, 8, 32]))
# Situation 2
>>> a1,a2,a3 = a.chunk(3,dim=1)
>>> a1.shape,a2.shape,a3.shape
(torch.Size([2, 2, 32, 32]), torch.Size([2, 2, 32, 32]), torch.Size([2, 1, 32, 32]))
更多内容可以前往官网查看:
https://pytorch.org/
本篇文章来源于微信公众号: 码农设计师