首页人工智能Pytorch【深度学习(PyTorch...

【深度学习(PyTorch篇)】11.Tensor的拼接(merge)与拆分(split)


本系列文章配套代码获取有以下两种途径:

  • 通过百度网盘获取:
链接:https://pan.baidu.com/s/1XuxKa9_G00NznvSK0cr5qw?pwd=mnsj 提取码:mnsj
  • 前往GitHub获取
https://github.com/returu/PyTorch






01

Tensor的拼接:


Pytorch中,torch.cat()torch.stack()都是用于拼接张量(Tensor)的操作,但它们在功能和使用场景上存在一些差异。
  • torch.cat()函数:
torch.cat(tensors, dim=0)函数用于将多个张量按指定的维度拼接起来。拼接操作不会增加张量的维度,而是在指定的维度上扩展。返回一个拼接后的新张量。
其中:
  • tensors (sequence of Tensors):需要拼接的张量序列。
  • dim (int):拼接的维度索引。默认值为0,表示在第0维上拼接。

需要注意的是,torch.cat()函数要求所有输入张量在非拼接维度上的形状必须相同

>>> a = torch.rand(4,32,8)
>>> b = torch.rand(5,32,8)

>>> torch.cat([a,b],dim=0).shape
torch.Size([9328])
  • torch.stack()函数:
torch.stack(tensors, dim=0)函数用于将多个张量在一个新的维度堆叠起来。堆叠操作会增加张量的维度。返回一个堆叠后的新张量。
其中
  • tensors (sequence of Tensors):需要堆叠的张量序列。
  • dim (int):插入新维度的索引。默认值为0,表示在最外层维度堆叠。

需要注意的是,torch.stack()函数要求所有输入张量的形状必须完全相同。

>>> a = torch.rand(4,3,32,32)
>>> b = torch.rand(4,3,32,32)

>>> torch.stack([a,b],dim=0).shape
torch.Size([2433232])

>>> torch.stack([a,b],dim=2).shape
torch.Size([4323232])
总结来说,torch.cat()用于在指定维度上拼接张量,不增加张量的总维度数;而torch.stack()用于在新增的维度上堆叠张量,增加张量的总维度数。在使用时,需要根据具体的需求和张量的形状来选择适当的操作。
02

Tensor的拆分:


在PyTorch中,torch/tensor.split()torch/tensor.chunk()都是用于拆分张量(Tensor)的操作。虽然它们都用于将张量拆分成多个较小的张量,但在操作方式、参数和使用场景上存在一些差异。
上述函数方法均存在模块级函数、张量级方法两类接口。
  • split()函数:
torch.split(tensor, split_size_or_sections, dim=0)函数用于将张量按指定的大小拆分成多个较小的张量。拆分操作沿着指定的维度进行。返回一个张量的元组,包含了拆分后的多个较小张量。
其中:
  • tensor:要拆分的输入张量。
  • split_size_or_sections :如果传入一个整数,表示每个输出张量的大小,此时最后一个张量可能小于这个整数,以确保所有元素都被拆分;如果传入一个整数列表,表示每个输出张量在拆分维度上的索引位置,此时列表中的元素之和必须等于拆分维度的大小。
  • dim:沿着哪个维度进行拆分,默认值为0。
>>> a = torch.rand(6,4,3,32,32)
>>> a.shape
torch.Size([6433232])

# 按大小为4拆分第0维
>>> a1 , a2 = a.split(4 , dim=0)
>>> a1.shape , a2.shape
(torch.Size([4433232]), torch.Size([2433232]))

# 按大小列表拆分第2维
>>> a3 , a4 = torch.split(a , [1,2] , dim=2)
>>> a3.shape , a4.shape
(torch.Size([6413232]), torch.Size([6423232]))
  • chunk()函数:
torch.chunk(input, chunks, dim=0)函数用于将张量按指定的块数拆分成多个较小的张量。拆分操作沿着指定的维度进行,且每个块的大小尽可能相等,但最后一个块可能小于其他块,以确保所有元素都被拆分返回一个张量的元组,包含了拆分后的多个块。
其中:
  • tensor:要拆分的输入张量。
  • chunks:要拆分的块数。
  • dim:沿着哪个维度进行拆分,默认值为0。

chunks有最大值限制,如果指定的拆分数量大于最大值限制,则只拆分为最大值数量的数组,chunks最大值的计算如下(在指定dim上大小为a):

>>> a = torch.rand(2,5,32,32)
>>> a.shape
torch.Size([253232])

# Situation 1
>>> a1,a2,a3,a4 = torch.chunk(a , 4 , dim=2)
>>> a1.shape,a2.shape,a3.shape,a4.shape
(torch.Size([25832]), torch.Size([25832]), torch.Size([25832]), torch.Size([25832]))

# Situation 2
>>> a1,a2,a3 = a.chunk(3,dim=1)
>>> a1.shape,a2.shape,a3.shape
(torch.Size([223232]), torch.Size([223232]), torch.Size([213232]))

更多内容可以前往官网查看

https://pytorch.org/


本篇文章来源于微信公众号: 码农设计师

RELATED ARTICLES

欢迎留下您的宝贵建议

Please enter your comment!
Please enter your name here

- Advertisment -

Most Popular

Recent Comments