首页人工智能Pytorch【深度学习(PyTorch...

【深度学习(PyTorch篇)】14.Tensor的高阶选择函数

本系列文章配套代码获取有以下两种途径:

  • 通过百度网盘获取:
链接:https://pan.baidu.com/s/1XuxKa9_G00NznvSK0cr5qw?pwd=mnsj 提取码:mnsj
  • 前往GitHub获取
https://github.com/returu/PyTorch





01

where()函数:


用途和功能:
where()函数用于根据条件选择张量中的元素。它返回满足条件的元素的索引或根据索引选择元素。
torch.where(condition, input, other)
其中:
  • condition (BoolTensor):一个布尔张量,用于确定哪些元素满足条件。
  • input (Tensor or Scalar)如果提供了这两个参数,函数将根据conditioninput (condition为True)other(condition为False)中选择元素。

如果只提供conditionwhere()会返回满足条件的元素的索引。如果同时提供inputother,则会根据condition的形状从inputother中选择元素。

where()函数在需要根据特定条件筛选数据或根据条件从两个数据源中选择数据时非常有用。例如,在机器学习中,可能需要根据某个阈值选择模型输出的一部分。

>>> x = torch.randn(32)
>>> x
tensor([[ 1.3652, -1.4215],
        [ 0.0562,  1.1291],
        [ 1.2951, -0.5436]]
)

>>> y = torch.ones(32)
>>> y
tensor([[1., 1.],
        [1., 1.],
        [1., 1.]]
)

# 返回x>0的元素的索引
>>> torch.where(x > 0)
(tensor([0112]), tensor([0010]))

# 当x>0时,Z在该位置取x中的值,否则取y中的值
>>> torch.where(x > 0, x, y)
tensor([[1.3652, 1.0000],
        [0.0562, 1.1291],
        [1.2951, 1.0000]]
)

02

gather()函数:


gather()函数用于根据索引张量沿着指定维度从源张量中收集切数据。实现的是一个查表映射的操作。
torch.gather(inputdimindex)
其中:
  • input (Tensor):源张量。
  • dim (int):要沿其收集的维度。
  • index (LongTensor):索引张量,具有与input维度数相同,用于指定要收集的元素的索引。输出张量与索引张量的形状

gather()函数在处理需要基于索引重新排列或选择数据的任务时非常有用。例如,在排序算法或数据重排任务中,可能需要根据计算出的索引重新排列数据。

  • 2-D tensor:
例如,我们通过神经网络得到了4个样本的每个分类的预测概率prob(数据仅为示意,因为值的总和不为1),然后需要得到前三概率的label标签数据:
>>> prob = torch.rand(4,10)
>>> prob
tensor([[0.4382, 0.1816, 0.0103, 0.8079, 0.5625, 0.3524, 0.9400, 0.7869, 0.2158,
         0.4823],
        [0.2992, 0.3542, 0.5600, 0.0941, 0.9609, 0.6544, 0.0146, 0.1859, 0.1265,
         0.7516],
        [0.1850, 0.4351, 0.1013, 0.8147, 0.8545, 0.8827, 0.9082, 0.0540, 0.2516,
         0.8884],
        [0.8478, 0.8168, 0.2865, 0.7098, 0.1472, 0.0452, 0.9378, 0.7390, 0.2482,
         0.6517]]
)

>>> idx = prob.topk(k=3,dim=1)
>>> idx
torch.return_types.topk(
values=tensor([[0.9400, 0.8079, 0.7869],
        [0.9609, 0.7516, 0.6544],
        [0.9082, 0.8884, 0.8827],
        [0.9378, 0.8478, 0.8168]]
),
indices=tensor([[6, 3, 7],
        [4, 9, 5],
        [6, 9, 5],
        [6, 0, 1]]
))

>>> index = idx[1]
>>> index
tensor([[6, 3, 7],
        [4, 9, 5],
        [6, 9, 5],
        [6, 0, 1]]
)

# 为了与索引值区分,本次乘以10
>>> label = (torch.arange(10)*10).expand(4,10)
>>> label
tensor([[ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90],
        [ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90],
        [ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90],
        [ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90]]
)

>>> torch.gather(label , dim=1 , index=index)
tensor([[60, 30, 70],
        [40, 90, 50],
        [60, 90, 50],
        [60,  0, 10]]
)
  • 3-D tensor:
对于 3-D 张量,输出由以下公式指定:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
例如:
>>> t = torch.randn(2,3,4)
>>> t
tensor([[[-0.5826, -0.6819, -0.7033,  0.0523],
         [-1.4312,  0.9060, -0.8302, -0.2463],
         [-0.2213, -0.0234, -1.8575, -1.4648]]
,

        [[-0.3952,  2.1159,  0.4586, -0.7825],
         [ 0.4127, -0.0337,  0.3258,  0.1194],
         [ 0.3513,  0.0743, -2.6382,  0.7267]]
])

>>> index = torch.tensor([[[2, 0],
...                        [1, 2]]
,
...                       [[1, 2],
...                        [1, 0]]
])
>>> index.shape
torch.Size([222])

>>> torch.gather(t , dim=1 , index=index)
tensor([[[-0.2213, -0.6819],
         [-1.4312, -0.0234]]
,

        [[ 0.4127,  0.0743],
         [ 0.4127,  2.1159]]
])

更多内容可以前往官网查看

https://pytorch.org/


本篇文章来源于微信公众号: 码农设计师

RELATED ARTICLES

欢迎留下您的宝贵建议

Please enter your comment!
Please enter your name here

- Advertisment -

Most Popular

Recent Comments