本系列文章配套代码获取有以下两种途径:
-
通过百度网盘获取:
链接:https://pan.baidu.com/s/1XuxKa9_G00NznvSK0cr5qw?pwd=mnsj
提取码:mnsj
-
前往GitHub获取:
https://github.com/returu/PyTorch
Pytorch中的Tensor支持与NumPy数组类似的索引和切片操作,用于提取Tensor的部分数据或对其进行修改。
其中,切片操作返回的是原始Tensor的视图,这意味着对切片的修改会影响到原始Tensor。如果需要独立的副本,需使用clone()方法创建一个Tensor的深拷贝。
索引与切片方法:
-
1.一般索引:
>>> a = torch.rand(4,3,28,28)
>>> a[0]
>>> a[0].shape
torch.Size([3, 28, 28])
>>> a[0,0].shape
torch.Size([28, 28])
>>> a[0,0,2,4].shape
torch.Size([])
# 返回的是一个scalar标量
>>> a[0,0,2,4]
tensor(0.3200)
-
2.一般的切片索引:
>>> a.shape
torch.Size([4, 3, 28, 28])
# 默认从dim 0开始索引
>>> a[:2].shape
torch.Size([2, 3, 28, 28])
# 此时后面的:可以不写,即表示dim 3和dim 4索引全部
>>> a[:2,:1,:,:].shape
torch.Size([2, 1, 28, 28])
>>> a[:2,1,:,:].shape
torch.Size([2, 28, 28])
>>> a[:2,-1:,:,:].shape
torch.Size([2, 1, 28, 28])
-
3.使用step进行索引:
>>> a.shape
torch.Size([4, 3, 28, 28])
>>> a[:2,-1:,:,:].shape
torch.Size([2, 1, 28, 28])
# 通过设置step(步长),进行隔行采样
>>> a[:,:,0:28:2,0:28:2].shape
torch.Size([4, 3, 14, 14])
# 下面的方式等同于上面的方式
>>> a[:,:,::2,::2].shape
torch.Size([4, 3, 14, 14])
-
4.通过特定下标进行索引:
index_select() 函数会根据指定维度上的索引选择张量(tensor)进行索引切片,函数语法如下:
tensor.index_select(dim, index)
-
dim:要在其上进行索引选择的维度; -
index:索引选择张量,数据类型为LongTensor。
返回值是一个新的张量,其中包含按指定索引和维度选择的元素。
>>> a.shape
torch.Size([4, 3, 28, 28])
# 索引dim 0的第0和第2个数据
>>> a.index_select(0,torch.tensor([0,2] , dtype=torch.long)).shape
torch.Size([2, 3, 28, 28])
# 索引dim 1的第0和第2个数据
>>> a.index_select(1,torch.tensor([0,2])).shape
torch.Size([4, 2, 28, 28])
# 索引dim 2的第0-7的8个数据
>>> a.index_select(2,torch.arange(8)).shape
torch.Size([4, 3, 8, 28])
-
5.通过…(任意多维度)进行索引:
>>> a.shape
torch.Size([4, 3, 28, 28])
# 此种情况代表四个维度都取
>>> a[...].shape
torch.Size([4, 3, 28, 28])
# dim 0取第0个,后面3个dim都取
>>> a[0,...].shape
torch.Size([3, 28, 28])
# dim 0全取,dim 1取第1个,后面2个dim都取
>>> a[:,1,...].shape
torch.Size([4, 28, 28])
# 最后一个dim取第0和第1,前面的dim都取
>>> a[...,:2].shape
torch.Size([4, 3, 28, 2])
-
6.通过mask(掩码)进行索引:
在PyTorch中,可以使用布尔掩码(mask)来索引张量(tensor)。掩码是一个与张量形状相同或可广播到相同形状的布尔张量,其中的每个元素都是True或False。通过掩码索引,你可以选择张量中满足特定条件的元素。
主要注意的是,当使用掩码索引时,返回的张量是一维的,并且包含了所有满足条件的元素。如果想要保持原始张量的形状,可以使用torch.where()函数来实现:
>>> x = torch.randn(3,4)
>>> x
tensor([[-0.5722, -0.0718, 0.0897, -0.6847],
[-0.7293, 0.7165, -1.0048, 0.7590],
[ 0.1698, 0.7343, 0.5072, 0.6891]])
# 得到>0.5的数据的掩码
>>> mask = x.ge(0.5)
>>> mask
tensor([[False, False, False, False],
[False, True, False, True],
[False, True, True, True]], dtype=torch.uint8)
# 使用torch.masked_select()函数选择
>>> torch.masked_select(x,mask)
tensor([0.7165, 0.7590, 0.7343, 0.5072, 0.6891])
# 得到的数据会被打平(dim为1)
>>> torch.masked_select(x,mask).shape
torch.Size([5])
>>> x[mask]
tensor([2.2139, 2.1848, 0.5279, 0.5202, 2.0046])
>>> x[mask].shape
torch.Size([5])
-
7.通过take(掩码)进行索引:
>>> x = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
>>> x
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# take()方法是对数据进行打平之后进行的操作
# 先将数据打平,得到tensor([1,2,3,4,5,6,7,8,9]),然后索引0,2,5,7个数据
>>> torch.take(a,torch.tensor([0,2,5,7]))
tensor([1, 3, 6, 8])
应用举例:
在数据预处理阶段,索引和切片操作常用于从数据集中提取特定的样本或特征。
# 假设有一个形状为(1000, 10, 28, 28)的数据集,表示1000个样本,每个样本有10个通道,每个通道是28x28的图像
data = torch.rand(1000, 10, 28, 28)
# 提取第5个样本的所有通道
sample = data[4] # 形状为(10, 28, 28)
# 提取前100个样本的第1个通道
channel = data[:100, 0] # 形状为(100, 28, 28)
在数据处理阶段,索引和切片操作可以用于实现各种数据增强技术,如随机裁剪、翻转等。
# 随机裁剪一个图像
image = torch.rand(3, 224, 224) # 假设有一个3通道224x224的图像
crop_size = (200, 200)
start = (torch.rand(2) * (224 - 200)).long() # 随机计算裁剪的起始位置
cropped_image = image[:, start[0]:start[0]+crop_size[0], start[1]:start[1]+crop_size[1]]
在模型训练阶段,索引和切片操作可以用于提取一个batch中的样本,或者从模型的输出中提取特定的信息。
# 假设模型的输出是一个形状为(batch_size, num_classes)的Tensor
outputs = torch.rand(64, 10)
# 提取第一个样本的输出
first_output = outputs[0]
# 提取一个batch中所有样本的第3个类别的输出
third_class_outputs = outputs[:, 2]
更多内容可以前往官网查看:
https://pytorch.org/
本篇文章来源于微信公众号: 码农设计师