1.一般索引:
类似python中的数据索引,从前往后进行索引,即依次在每个维度上做索引。
>>> import torch
>>> a = torch.rand(4,3,28,28)
>>> a[0]
>>> a[0].shape
torch.Size([3, 28, 28])
>>> a[0,0].shape
torch.Size([28, 28])
>>> a[0,0,2,4].shape
torch.Size([])
# 返回的是一个scalar标量
>>> a[0,0,2,4]
tensor(0.3200)
2.一般的切片索引:
类似python中数据的切片索引。
>>> a.shape
torch.Size([4, 3, 28, 28])
# 默认从dim 0开始索引
>>> a[:2].shape
torch.Size([2, 3, 28, 28])
# 此时后面的:可以不写,即表示dim 3和dim 4索引全部
>>> a[:2,:1,:,:].shape
torch.Size([2, 1, 28, 28])
>>> a[:2,1,:,:].shape
torch.Size([2, 28, 28])
>>> a[:2,-1:,:,:].shape
torch.Size([2, 1, 28, 28])
3.使用step进行索引:
>>> a.shape
torch.Size([4, 3, 28, 28])
>>> a[:2,-1:,:,:].shape
torch.Size([2, 1, 28, 28])
# 通过设置step(步长),进行隔行采样
>>> a[:,:,0:28:2,0:28:2].shape
torch.Size([4, 3, 14, 14])
# 下面的方式等同于上面的方式
>>> a[:,:,::2,::2].shape
torch.Size([4, 3, 14, 14])
4.通过特定下标进行索引:
>>> a.shape
torch.Size([4, 3, 28, 28])
# 索引dim 0的第0和第2个数据
# .index_select(dim,index),需注意的是index需要时tensor
>>> a.index_select(0,torch.tensor([0,2])).shape
torch.Size([2, 3, 28, 28])
# 索引dim 1的第0和第2个数据
>>> a.index_select(1,torch.tensor([0,2])).shape
torch.Size([4, 2, 28, 28])
# 索引dim 2的第0-7的8个数据
>>> a.index_select(2,torch.arange(8)).shape
torch.Size([4, 3, 8, 28])
5.通过...
(任意多维度)进行索引:
...
符号表示任意多维度,具体多少维度系统会根据具体的情况进行推测。
>>> a.shape
torch.Size([4, 3, 28, 28])
# 此种情况代表四个维度都取
>>> a[...].shape
torch.Size([4, 3, 28, 28])
# dim 0取第0个,后面3个dim都取
>>> a[0,...].shape
torch.Size([3, 28, 28])
# dim 0全取,dim 1取第1个,后面2个dim都取
>>> a[:,1,...].shape
torch.Size([4, 28, 28])
# 最后一个dim取第0和第1,前面的dim都取
>>> a[...,:2].shape
torch.Size([4, 3, 28, 2])
6.通过mask(掩码)进行索引:
>>> a = torch.randn(3,4)
>>> a
tensor([[-0.5722, -0.0718, 0.0897, -0.6847],
[-0.7293, 0.7165, -1.0048, 0.7590],
[ 0.1698, 0.7343, 0.5072, 0.6891]])
# 得到>0.5的数据的掩码
>>> mask = a.ge(0.5)
>>> mask
tensor([[False, False, False, False],
[False, True, False, True],
[False, True, True, True]], dtype=torch.uint8)
# 通过.masked_select()方法取得>0.5的数据
>>> torch.masked_select(a,mask)
tensor([0.7165, 0.7590, 0.7343, 0.5072, 0.6891])
# 通过.masked_select()方法得到的数据会被打平(dim为1),因此较少采用
>>> torch.masked_select(a,mask).shape
torch.Size([5])
6.通过take(掩码)进行索引:
take索引是基于目标Tensor的flatten形式下的,即摊平后的Tensor的索引。
>>> a = torch.tensor([[3,4,5],[6,7,8]])
>>> a
tensor([[3, 4, 5],
[6, 7, 8]])
# .take()方法是对数据进行打平之后进行的操作,即flatten之后的Tensor的索引
# 现将数据打平,得到tensor([3,4,5,6,7,8]),然后索引第0,2,3个数据
>>> torch.take(a,torch.tensor([0,2,3]))
tensor([3, 5, 6])