首页人工智能Pytorch8.Tensor的高阶操作...

8.Tensor的高阶操作

1.where操作:

也就是说使用Z=torch.where(condition,x,y)
Z中的元素来自x或者来自y,是由condition中相应位置的元素是1(true)还是0(false)来决定的,需要注意的是其中x,y,Z,condition是shape相同的tensor。

>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[ 1.2202, -1.1671],
        [-0.0739,  1.4761],
        [ 0.6622,  0.8530]])
>>> y
tensor([[1., 1.],
        [1., 1.],
        [1., 1.]])
# 当x>0时,Z在该位置取x中的值,否则取y中的值
>>> Z =torch.where(x > 0, x, y)
>>> Z
tensor([[1.2202, 1.0000],
        [1.0000, 1.4761],
        [0.6622, 0.8530]])

2.gather操作:

使用torch.gather(input,dim,index)实现的是一个查表映射的操作。

>>> a = torch.tensor([[1,2],[3,4]])
>>> a
tensor([[1, 2],
        [3, 4]])
# 在dim 1上进行查表映射操作,input为[[1,2],[3,4]],index为[[0,0],[1,0]],因此output为[[1, 1],[4, 3]]([1,2]上都取第0个也就是1,[3,4]上分别取第1个和第0个也就是4和3)
>>> torch.gather(a, 1, torch.tensor([[0,0],[1,0]]))
tensor([[1, 1],
        [4, 3]])
>>> prob = torch.randn(4,10)
>>> idx = prob.topk(k=3,dim=1)
>>> idx = idx[1]
>>> idx
tensor([[7, 4, 8],
        [6, 5, 9],
        [0, 3, 4],
        [5, 1, 9]])
>>> label = torch.arange(10)*10
>>> label = label.expand(4,10)
>>> label
tensor([[ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90],
        [ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90],
        [ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90],
        [ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90]])
>>> torch.gather(label , dim=1 , index=idx)
tensor([[70, 40, 80],
        [60, 50, 90],
        [ 0, 30, 40],
        [50, 10, 90]])

更多操作可以查看PyTorch官方文档

RELATED ARTICLES

欢迎留下您的宝贵建议

Please enter your comment!
Please enter your name here

- Advertisment -

Most Popular

Recent Comments