torch.gather是把tensor A的值基于dim顺序,根据index取出来;
torch.scatter是把tensor A的值基于dim顺序,根据index替换为src中的值;
torch.scatter_reduce是把tensor A的值基于dim顺序,根据index取出后,与src对应的值做reduce聚合。(注意:torch.scatter_reduce在torch>=1.13才有,否则建议使用torch_scatter包里的scatter函数)
举例torch.scatter:
Tensor.scatter_(dim, index, src, reduce=None) → Tensor
更新方式:
self[index[i][j][k]][j][k] *= src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] *= src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] *= src[i][j][k] # if dim == 2
代码块中等号左边:
dim=?表示在self第?维度上取index的值,self其他维度取index所在的索引对应的值;
代码块中等号右边:
赋值的值,来自src。index中每个值所在的位置,对应src所在的位置。(因此要求index.shape <= src.shape)src取相应的值与index的值无关,只与index的位置(索引)有关。
>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
[0, 2, 0, 0, 0],
[0, 0, 3, 0, 0]])
'''
dim=0,在self第0维度对应的索引放index的值,其他维度的索引放index对应的索引;
src各维度放index对应的索引。
步骤:
index第一个元素:
dim=0,在self第0维度对应的索引放index的值,即0,得到:
self[0];
其他维度的索引放index对应的索引,(dim=0有索引了,只差dim=1的索引),得到:
self[0][0];
赋予的值是 src各维度放index对应的索引,index第一个元素所在位置为[0][0],得到:
self[0][0]=src[0][0]=1;
index第二个元素:
dim=0,在self第0维度对应的索引放index的值,即1,得到:
self[1];
其他维度的索引放index对应的索引,(dim=0有索引了,只差dim=1的索引,该索引为1),得到:
self[1][1];
赋予的值是 src各维度放index对应的索引,index第一个元素所在位置为[0][1],得到:
self[1][1]=src[0][1]=2;
...
index第四个元素:
dim=0,在self第0维度对应的索引放index的值,即0,得到:
self[0];
其他维度的索引放index对应的索引,(dim=0有索引了,只差dim=1的索引,该索引为4),得到:
self[0][4];
赋予的值是 src各维度放index对应的索引,index第四个元素所在位置为[0][4],得到:
self[0][4]=src[0][4]=4;
'''
第二个例子
>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
[6, 7, 0, 0, 8],
[0, 0, 0, 0, 0]])
'''
dim=1,在self第1维度对应的索引放index的值,其他维度的索引放index对应的索引;
src各维度放index对应的索引。
步骤:
index第一个元素:
dim=1,在self第1维度对应的索引放index的值,即0,得到:
self[?][0];
其他维度的索引放index对应的索引,(dim=1有索引了,只差dim=0的索引),得到:
self[0][0];
赋予的值是 src各维度放index对应的索引,index第一个元素所在位置为[0][0],得到:
self[0][0]=src[0][0]=1;
...
index第三个元素:
dim=1,在self第1维度对应的索引放index的值,即2,得到:
self[?][2];
其他维度的索引放index对应的索引,(dim=1有索引了,只差dim=0的索引,该索引为0,因为在index第0行),得到:
self[0][1];
赋予的值是 src各维度放index对应的索引,index第一个元素所在位置为[0][1],得到:
self[0][1]=src[0][1]=2;
...
index第六个元素:
dim=1,在self第1维度对应的索引放index的值,即4,得到:
self[0][4];
其他维度的索引放index对应的索引,(dim=1有索引了,只差dim=0的索引,该索引为1, 因为在index第1行),得到:
self[1][4];
赋予的值是 src各维度放index对应的索引,index第四个元素所在位置为[0][4],得到:
self[1][4]=src[1][2]=8;
'''
题外话:脑子有点木,看半天绕不过来,感谢lhb大神的讲解。