torch.take_along_dim
torch.take_along_dim(input, indices, dim, *, out=None) → Tensor
从输入中根据维度和下标选择值,其中下标的类型必须为long tensor
这个方法往往需要和能够返回long tensor类型的下标方式一起使用。比如argmax和argsort。
t = torch.tensor([[10, 30, 20], [60, 40, 50]])
max_idx = torch.argmax(t)#返回t中最大值的下标,返回类型为long tensor
torch.take_along_dim(t, max_idx)
tensor([60])
sorted_idx = torch.argsort(t, dim=1)#返回排序下标,返回类型为long tensor
torch.take_along_dim(t, sorted_idx, dim=1)
tensor([[10, 20, 30],
[40, 50, 60]])