首页人工智能Pytorch【深度学习(PyTorch...

【深度学习(PyTorch篇)】49.Tensor的高级索引

本系列文章配套代码获取有以下两种途径:

  • 通过百度网盘获取:
链接:https://pan.baidu.com/s/1XuxKa9_G00NznvSK0cr5qw?pwd=mnsj 提取码:mnsj
  • 前往GitHub获取
https://github.com/returu/PyTorch





在之前文章中介绍过一些简单的Tensor的索引与切片操作:

【深度学习(PyTorch篇)】8.Tensor的索引与切片

本次将介绍一些高级索引操作。
01

None索引


None索引可以直观地表示维度的扩展(增加维度或者在特定位置插入新维度),在广播法则中充当1的作用。使用None索引本质上与使用unsqueeze方法是等价的,都能起到扩展维度的作用。
例如,有一个一维张量x,形状为(2, 3),在第0维(即最外层)增加一个新维度(这种操作在深度学习中非常有用,特别是在需要将一维向量转换为可以作为网络输入的批处理数据时):
x = torch.arange(6).reshape(2,3)
y = x[None, :]
y.shape
# 输出结果:torch.Size([1, 2, 3])

# 或者使用unsqueeze()方法
y = x.unsqueeze(0) # 等价于torch.unsqueeze(x ,0)
如果想在第二维增加一个新维度:
y = x[: , None , :]
y.shape
# 输出结果:torch.Size([2, 1, 3])

# 或者使用unsqueeze()方法
y = torch.unsqueeze(x ,1) # 等价于x.unsqueeze(1)

02

整数数组索引


整数数组索引允许用户通过整数数组来指定张量中元素的位置,从而获取这些位置的元素。这种索引方式可以非常精确地选取张量中的指定元素。
  • 单维度索引:
当对一个维度使用整数数组索引时,将根据数组中的每个整数值在该维度上获取对应的切片或元素。索引数组中的每个整数指定了要选择的数据在该维度上的位置。

例如,indices数组指定了要从tensor中选取的行索引(0和2),因此结果张量包含了这两行的元素:

a = torch.arange(12).view(3,4)

# 使用整数数组索引获取指定位置的元素  
indices = torch.tensor([0, 2])  # 指定要获取的行索引:第一行和第三行

a[torch.tensor([0,2]) , :]
# 输出结果:
# tensor([[ 0,  1,  2,  3],
#         [ 8,  9, 10, 11]])
tensor中选取的列索引(0和2):
# 使用整数数组索引获取指定位置的元素  
indices = torch.tensor([0, 2])  # 指定要获取的列索引:第一列和第三列

a[: , torch.tensor([0,2])]
# 输出结果:
# tensor([[ 0,  2],
#         [ 4,  6],
#         [ 8, 10]])

  • 多维度索引:
多维度索引允许用户通过多个索引数组来同时指定张量在多个维度上的位置,从而选取这些位置的元素。这种索引方式在处理多维张量时非常有用。
整数数组索引都有一种相对固定的模式,即tensor[index1 , index2 , index3 , … , indexN],其中N必须小于等于要索引Tensor的维度,一般情况下各个index的形状是相同的。
例如,有一个形状为(3, 4)的张量A,可以使用两个索引数组row_indicescol_indices来选择A中的元素。如果row_indices=torch.tensor([0, 2])col_indices=torch.tensor([1, 3]),那么A[row_indices, col_indices]将返回A中第0行和第2行中的第1列和第3列的元素组成的张量。
a = torch.arange(12).view(3,4)

# 创建两个索引数组,用于选择元素  
# 获取索引为 [0,1]、[1,2]、[2,0]、[0,2]的元素
indices1 = torch.tensor([[0, 1], [2, 0]])  # 选择第0个维度(行)的索引  
indices2 = torch.tensor([[1, 2], [0, 2]])  # 选择第1个维度(列)的索引 

# 使用多维数据索引选择元素  
a[indices1, indices2]
# 输出结果:
# tensor([[1, 6],
#         [8, 2]]
)

如果索引数组index的形状不完全相同,但是满足广播法则,那么它们将会自动对齐成一样的形状,从而完成整数数组索引操作。对于不满足广播法则或者不能得到相同形状的整数索引,则无法进行索引操作。

# 不同形状的index索引,满足广播法则
# 创建两个索引数组,用于选择元素  
# 获取索引为 [1,0]、[2,0]、[1,2]、[2,2]的元素
indices1 = torch.tensor([1,2])[None , :]  # 选择第0个维度(行)的索引:[[1,2]]
indices2 = torch.tensor([0,2])[: , None]  # 选择第1个维度(列)的索引:[[0],[2]]

# 使用多维数据索引选择元素  
a[indices1, indices2]
# 输出结果:
# tensor([[ 4,  8],
#         [ 6, 10]]
)


  • 混合使用高级索引和基本索引:

有时可能需要混合使用高级索引和基本索引完成索引操作,此时根据高级索引所处的位置可以分为以下两种情况:

  • 1)、所有的高级索引相邻(例如tensor[:,idx1,idx2]):

直接将所有的高级索引所在区域的维度转换成高级索引的维度,Tensor的其他维度按照基本索引正常计算。

a = torch.arange(24).view(2,3,4)

indices1 = torch.tensor([[1,0]]) # torch.Size([12])
indices2 = torch.tensor([[0,2]]) # torch.Size([12])

# a.shape[0] = 2
# indices1.shape = torch.Size([12])
# 保留a的第一个维度,后两个维度是索引维度
a[: , indices1 , indices2].shape

a[: , indices1 , indices2]
# 输出结果:
# tensor([[[ 4,  2]],
#         [[16, 14]]])
b = torch.arange(120).view(2,3,4,5)

# 保留a的第一个和最后一个维度,中间两个维度是索引维度
b[: , indices1 , indices2 , :3].shape
# 输出结果:
# torch.Size([2, 1, 2, 3])
  • 2)、高级索引被划分到不同的区域(例如tensor[idx1,:,idx2]):

此时所有的高级索引并不相邻,因此无法确定高级索引的维度应该替换Tensor的哪些维度,因此统一放在输出Tensor维度的开头,剩下部分补齐基本索引的维度。

a = torch.arange(24).view(2,3,4)

# indices1.shape = torch.Size([1, 2])
# 索引维度放在输出Tensor维度的开头,即前两个维度,
# 剩下部分补齐基本索引的维度,即最后一个维度,本次基本索引所在的原始维度为3,因此输出维度也为3
a[indices1 , : , indices2].shape
# 输出结果:
# torch.Size([1, 2, 3])
b = torch.arange(120).view(2,3,4,5)

# indices1.shape = torch.Size([1, 2]),索引维度放在输出Tensor维度的开头,即前两个维度,为【1,2】
# 剩下部分补齐基本索引的维度,即后两个维度,本次基本索引所在的原始维度为2和4,因此输出维度也为【2,4】
b[: , indices1 , : , indices2].shape
# 输出结果:
# torch.Size([1, 2, 2, 4])

另外,当当基本索引和高级索引的总数小于Tensor的维度数时,会自动补上 “…” 操作:

b = torch.arange(120).view(2,3,4,5)

indices1 = torch.tensor([[1,0]]) # torch.Size([1, 2])

# 以下三种索引方式是等价的
b[: , indices1].shape
b[: , indices1 , ...].shape
b[: , indices1 , : , :].shape
# 输出结果:
# torch.Size([2, 1, 2, 4, 5])

高级索引还可以与None索引组合使用,完成更加复杂的索引操作:

a = torch.arange(12).view(3,4)

indices1 = torch.tensor([[1,0]]) # torch.Size([1, 2])
indices2 = torch.tensor([[0,2]]) # torch.Size([1, 2])

# 索引维度放在输出Tensor维度的开头,即前两个维度,
# 本次使用了None索引,因此将在最后增加一个维度
a[indices1 , None , indices2].shape
# 输出结果:
# torch.Size([1, 2, 1])

a[indices1 , None , indices2]
# 输出结果:
# tensor([[[4],
#          [2]]])


03

布尔数组索引


布尔索引是使用布尔掩码(由TrueFalse值组成的张量)来选择元素的一种方法。掩码的形状应该与所索引的张量的形状相匹配。当掩码中的值为True时,对应的张量元素会被选中;反之,当掩码中的值为False时,对应的张量元素会被忽略。

之前介绍过通过mask(掩码)进行索引,但是该方式返回的结果是包含了所有满足条件的元素的一维张量。如果想要保持原始张量的形状,可以使用torch.where()函数来实现。而使用布尔索引通过高级索引的方式可以简化计算。

例如,对一个Tensor中所有正数元素进行乘以10操作:

a = torch.tensor([[1,-2,3] , [4,5,-6] , [-7,8,-9]])

# 使用torch.where()函数
res = torch.where(a>0 , a*10 , a)
res
# 输出结果:
# tensor([[10, -2, 30],
#         [40, 50, -6],
#         [-7, 80, -9]]
)


# 使用布尔数组索引
a[a>0] *= 10
a
# 输出结果:
# tensor([[10, -2, 30],
#         [40, 50, -6],
#         [-7, 80, -9]]
)

比如,返回一个Tensor中所有行和大于0的行:

a = torch.tensor([[1,-2,3] , [4,5,-6] , [-7,8,-9]])

row_sum = a.sum(axis=1)

# 使用torch.where()函数
# 返回满足条件的元素的索引
res = torch.where(row_sum>0)
res
# 输出结果:
# (tensor([0, 1]),)])

# 根据索引获取对应元素
a[res]
# 输出结果:
# tensor([[ 1, -2,  3],
#         [ 4,  5, -6]])


# 使用布尔索引可以简化计算
a[row_sum>0 , :]
# 输出结果:
# tensor([[ 1, -2,  3],
#         [ 4,  5, -6]])



更多内容可以前往官网查看

https://pytorch.org/


本篇文章来源于微信公众号: 码农设计师

RELATED ARTICLES

欢迎留下您的宝贵建议

Please enter your comment!
Please enter your name here

- Advertisment -

Most Popular

Recent Comments