本系列文章配套代码获取有以下两种途径:
-
通过百度网盘获取:
链接:https://pan.baidu.com/s/1XuxKa9_G00NznvSK0cr5qw?pwd=mnsj
提取码:mnsj
-
前往GitHub获取:
https://github.com/returu/PyTorch
where()函数:
torch.where(condition, input, other)
-
condition (BoolTensor):一个布尔张量,用于确定哪些元素满足条件。 -
input (Tensor or Scalar):如果提供了这两个参数,函数将根据condition从input (condition为True)和other(condition为False)中选择元素。
如果只提供condition,where()会返回满足条件的元素的索引。如果同时提供input和other,则会根据condition的形状从input和other中选择元素。
where()函数在需要根据特定条件筛选数据或根据条件从两个数据源中选择数据时非常有用。例如,在机器学习中,可能需要根据某个阈值选择模型输出的一部分。
>>> x = torch.randn(3, 2)
>>> x
tensor([[ 1.3652, -1.4215],
[ 0.0562, 1.1291],
[ 1.2951, -0.5436]])
>>> y = torch.ones(3, 2)
>>> y
tensor([[1., 1.],
[1., 1.],
[1., 1.]])
# 返回x>0的元素的索引
>>> torch.where(x > 0)
(tensor([0, 1, 1, 2]), tensor([0, 0, 1, 0]))
# 当x>0时,Z在该位置取x中的值,否则取y中的值
>>> torch.where(x > 0, x, y)
tensor([[1.3652, 1.0000],
[0.0562, 1.1291],
[1.2951, 1.0000]])
gather()函数:
torch.gather(input, dim, index)
-
input (Tensor):源张量。 -
dim (int):要沿其收集的维度。 -
index (LongTensor):索引张量,具有与input维度数相同,用于指定要收集的元素的索引。输出张量与索引张量的形状相同。
gather()函数在处理需要基于索引重新排列或选择数据的任务时非常有用。例如,在排序算法或数据重排任务中,可能需要根据计算出的索引重新排列数据。
-
2-D tensor:
>>> prob = torch.rand(4,10)
>>> prob
tensor([[0.4382, 0.1816, 0.0103, 0.8079, 0.5625, 0.3524, 0.9400, 0.7869, 0.2158,
0.4823],
[0.2992, 0.3542, 0.5600, 0.0941, 0.9609, 0.6544, 0.0146, 0.1859, 0.1265,
0.7516],
[0.1850, 0.4351, 0.1013, 0.8147, 0.8545, 0.8827, 0.9082, 0.0540, 0.2516,
0.8884],
[0.8478, 0.8168, 0.2865, 0.7098, 0.1472, 0.0452, 0.9378, 0.7390, 0.2482,
0.6517]])
>>> idx = prob.topk(k=3,dim=1)
>>> idx
torch.return_types.topk(
values=tensor([[0.9400, 0.8079, 0.7869],
[0.9609, 0.7516, 0.6544],
[0.9082, 0.8884, 0.8827],
[0.9378, 0.8478, 0.8168]]),
indices=tensor([[6, 3, 7],
[4, 9, 5],
[6, 9, 5],
[6, 0, 1]]))
>>> index = idx[1]
>>> index
tensor([[6, 3, 7],
[4, 9, 5],
[6, 9, 5],
[6, 0, 1]])
# 为了与索引值区分,本次乘以10
>>> label = (torch.arange(10)*10).expand(4,10)
>>> label
tensor([[ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90],
[ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90],
[ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90],
[ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90]])
>>> torch.gather(label , dim=1 , index=index)
tensor([[60, 30, 70],
[40, 90, 50],
[60, 90, 50],
[60, 0, 10]])
-
3-D tensor:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
>>> t = torch.randn(2,3,4)
>>> t
tensor([[[-0.5826, -0.6819, -0.7033, 0.0523],
[-1.4312, 0.9060, -0.8302, -0.2463],
[-0.2213, -0.0234, -1.8575, -1.4648]],
[[-0.3952, 2.1159, 0.4586, -0.7825],
[ 0.4127, -0.0337, 0.3258, 0.1194],
[ 0.3513, 0.0743, -2.6382, 0.7267]]])
>>> index = torch.tensor([[[2, 0],
... [1, 2]],
... [[1, 2],
... [1, 0]]])
>>> index.shape
torch.Size([2, 2, 2])
>>> torch.gather(t , dim=1 , index=index)
tensor([[[-0.2213, -0.6819],
[-1.4312, -0.0234]],
[[ 0.4127, 0.0743],
[ 0.4127, 2.1159]]])
更多内容可以前往官网查看:
https://pytorch.org/
本篇文章来源于微信公众号: 码农设计师