1.where操作:
也就是说使用Z=torch.where(condition,x,y)
Z中的元素来自x或者来自y,是由condition中相应位置的元素是1(true)还是0(false)来决定的,需要注意的是其中x,y,Z,condition是shape相同的tensor。
>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[ 1.2202, -1.1671],
[-0.0739, 1.4761],
[ 0.6622, 0.8530]])
>>> y
tensor([[1., 1.],
[1., 1.],
[1., 1.]])
# 当x>0时,Z在该位置取x中的值,否则取y中的值
>>> Z =torch.where(x > 0, x, y)
>>> Z
tensor([[1.2202, 1.0000],
[1.0000, 1.4761],
[0.6622, 0.8530]])
2.gather操作:
使用torch.gather(input,dim,index)
实现的是一个查表映射的操作。
>>> a = torch.tensor([[1,2],[3,4]])
>>> a
tensor([[1, 2],
[3, 4]])
# 在dim 1上进行查表映射操作,input为[[1,2],[3,4]],index为[[0,0],[1,0]],因此output为[[1, 1],[4, 3]]([1,2]上都取第0个也就是1,[3,4]上分别取第1个和第0个也就是4和3)
>>> torch.gather(a, 1, torch.tensor([[0,0],[1,0]]))
tensor([[1, 1],
[4, 3]])
>>> prob = torch.randn(4,10)
>>> idx = prob.topk(k=3,dim=1)
>>> idx = idx[1]
>>> idx
tensor([[7, 4, 8],
[6, 5, 9],
[0, 3, 4],
[5, 1, 9]])
>>> label = torch.arange(10)*10
>>> label = label.expand(4,10)
>>> label
tensor([[ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90],
[ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90],
[ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90],
[ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90]])
>>> torch.gather(label , dim=1 , index=idx)
tensor([[70, 40, 80],
[60, 50, 90],
[ 0, 30, 40],
[50, 10, 90]])
更多操作可以查看PyTorch官方文档。