本系列文章配套代码获取有以下两种途径:
-
通过百度网盘获取:
链接:https://pan.baidu.com/s/1XuxKa9_G00NznvSK0cr5qw?pwd=mnsj
提取码:mnsj
-
前往GitHub获取:
https://github.com/returu/PyTorch
【深度学习(PyTorch篇)】8.Tensor的索引与切片
None索引:
x = torch.arange(6).reshape(2,3)
y = x[None, :]
y.shape
# 输出结果:torch.Size([1, 2, 3])
# 或者使用unsqueeze()方法
y = x.unsqueeze(0) # 等价于torch.unsqueeze(x ,0)
y = x[: , None , :]
y.shape
# 输出结果:torch.Size([2, 1, 3])
# 或者使用unsqueeze()方法
y = torch.unsqueeze(x ,1) # 等价于x.unsqueeze(1)
整数数组索引:
-
单维度索引:
例如,indices数组指定了要从tensor中选取的行索引(0和2),因此结果张量包含了这两行的元素:
a = torch.arange(12).view(3,4)
# 使用整数数组索引获取指定位置的元素
indices = torch.tensor([0, 2]) # 指定要获取的行索引:第一行和第三行
a[torch.tensor([0,2]) , :]
# 输出结果:
# tensor([[ 0, 1, 2, 3],
# [ 8, 9, 10, 11]])
# 使用整数数组索引获取指定位置的元素
indices = torch.tensor([0, 2]) # 指定要获取的列索引:第一列和第三列
a[: , torch.tensor([0,2])]
# 输出结果:
# tensor([[ 0, 2],
# [ 4, 6],
# [ 8, 10]])
-
多维度索引:
a = torch.arange(12).view(3,4)
# 创建两个索引数组,用于选择元素
# 获取索引为 [0,1]、[1,2]、[2,0]、[0,2]的元素
indices1 = torch.tensor([[0, 1], [2, 0]]) # 选择第0个维度(行)的索引
indices2 = torch.tensor([[1, 2], [0, 2]]) # 选择第1个维度(列)的索引
# 使用多维数据索引选择元素
a[indices1, indices2]
# 输出结果:
# tensor([[1, 6],
# [8, 2]])
如果索引数组index的形状不完全相同,但是满足广播法则,那么它们将会自动对齐成一样的形状,从而完成整数数组索引操作。对于不满足广播法则或者不能得到相同形状的整数索引,则无法进行索引操作。
# 不同形状的index索引,满足广播法则
# 创建两个索引数组,用于选择元素
# 获取索引为 [1,0]、[2,0]、[1,2]、[2,2]的元素
indices1 = torch.tensor([1,2])[None , :] # 选择第0个维度(行)的索引:[[1,2]]
indices2 = torch.tensor([0,2])[: , None] # 选择第1个维度(列)的索引:[[0],[2]]
# 使用多维数据索引选择元素
a[indices1, indices2]
# 输出结果:
# tensor([[ 4, 8],
# [ 6, 10]])
-
混合使用高级索引和基本索引:
有时可能需要混合使用高级索引和基本索引完成索引操作,此时根据高级索引所处的位置可以分为以下两种情况:
-
1)、所有的高级索引相邻(例如tensor[:,idx1,idx2]):
直接将所有的高级索引所在区域的维度转换成高级索引的维度,Tensor的其他维度按照基本索引正常计算。
a = torch.arange(24).view(2,3,4)
indices1 = torch.tensor([[1,0]]) # torch.Size([1, 2])
indices2 = torch.tensor([[0,2]]) # torch.Size([1, 2])
# a.shape[0] = 2
# indices1.shape = torch.Size([1, 2])
# 保留a的第一个维度,后两个维度是索引维度
a[: , indices1 , indices2].shape
a[: , indices1 , indices2]
# 输出结果:
# tensor([[[ 4, 2]],
# [[16, 14]]])
b = torch.arange(120).view(2,3,4,5)
# 保留a的第一个和最后一个维度,中间两个维度是索引维度
b[: , indices1 , indices2 , :3].shape
# 输出结果:
# torch.Size([2, 1, 2, 3])
-
2)、高级索引被划分到不同的区域(例如tensor[idx1,:,idx2]):
此时所有的高级索引并不相邻,因此无法确定高级索引的维度应该替换Tensor的哪些维度,因此统一放在输出Tensor维度的开头,剩下部分补齐基本索引的维度。
a = torch.arange(24).view(2,3,4)
# indices1.shape = torch.Size([1, 2])
# 索引维度放在输出Tensor维度的开头,即前两个维度,
# 剩下部分补齐基本索引的维度,即最后一个维度,本次基本索引所在的原始维度为3,因此输出维度也为3
a[indices1 , : , indices2].shape
# 输出结果:
# torch.Size([1, 2, 3])
b = torch.arange(120).view(2,3,4,5)
# indices1.shape = torch.Size([1, 2]),索引维度放在输出Tensor维度的开头,即前两个维度,为【1,2】
# 剩下部分补齐基本索引的维度,即后两个维度,本次基本索引所在的原始维度为2和4,因此输出维度也为【2,4】
b[: , indices1 , : , indices2].shape
# 输出结果:
# torch.Size([1, 2, 2, 4])
另外,当当基本索引和高级索引的总数小于Tensor的维度数时,会自动补上 “…” 操作:
b = torch.arange(120).view(2,3,4,5)
indices1 = torch.tensor([[1,0]]) # torch.Size([1, 2])
# 以下三种索引方式是等价的
b[: , indices1].shape
b[: , indices1 , ...].shape
b[: , indices1 , : , :].shape
# 输出结果:
# torch.Size([2, 1, 2, 4, 5])
高级索引还可以与None索引组合使用,完成更加复杂的索引操作:
a = torch.arange(12).view(3,4)
indices1 = torch.tensor([[1,0]]) # torch.Size([1, 2])
indices2 = torch.tensor([[0,2]]) # torch.Size([1, 2])
# 索引维度放在输出Tensor维度的开头,即前两个维度,
# 本次使用了None索引,因此将在最后增加一个维度
a[indices1 , None , indices2].shape
# 输出结果:
# torch.Size([1, 2, 1])
a[indices1 , None , indices2]
# 输出结果:
# tensor([[[4],
# [2]]])
布尔数组索引:
布尔索引是使用布尔掩码(由True和False值组成的张量)来选择元素的一种方法。掩码的形状应该与所索引的张量的形状相匹配。当掩码中的值为True时,对应的张量元素会被选中;反之,当掩码中的值为False时,对应的张量元素会被忽略。
之前介绍过通过mask(掩码)进行索引,但是该方式返回的结果是包含了所有满足条件的元素的一维张量。如果想要保持原始张量的形状,可以使用torch.where()函数来实现。而使用布尔索引通过高级索引的方式可以简化计算。
例如,对一个Tensor中所有正数元素进行乘以10操作:
a = torch.tensor([[1,-2,3] , [4,5,-6] , [-7,8,-9]])
# 使用torch.where()函数
res = torch.where(a>0 , a*10 , a)
res
# 输出结果:
# tensor([[10, -2, 30],
# [40, 50, -6],
# [-7, 80, -9]])
# 使用布尔数组索引
a[a>0] *= 10
a
# 输出结果:
# tensor([[10, -2, 30],
# [40, 50, -6],
# [-7, 80, -9]])
比如,返回一个Tensor中所有行和大于0的行:
a = torch.tensor([[1,-2,3] , [4,5,-6] , [-7,8,-9]])
row_sum = a.sum(axis=1)
# 使用torch.where()函数
# 返回满足条件的元素的索引
res = torch.where(row_sum>0)
res
# 输出结果:
# (tensor([0, 1]),)])
# 根据索引获取对应元素
a[res]
# 输出结果:
# tensor([[ 1, -2, 3],
# [ 4, 5, -6]])
# 使用布尔索引可以简化计算
a[row_sum>0 , :]
# 输出结果:
# tensor([[ 1, -2, 3],
# [ 4, 5, -6]])
更多内容可以前往官网查看:
https://pytorch.org/


本篇文章来源于微信公众号: 码农设计师