Bootstrap

简单速成Pytorch的scatter_函数理解

首先明确,这个函数实现的功能是”放“
怎么个放法呢,看这个函数的参数:

Tensor.scatter_(dim, index, src, reduce=None) → Tensor
  1. src:将src这个tensor中的值,到self里(也就是”.“符号前面的那个Tensor)。src不一定要是一个tensor,也可以是一个值。
  2. dim及index:指示要的具体位置。

self [index[i][j][k]] [j][k] = src[i][j][k] # if dim == 0
self[i] [index[i][j][k]] [k] = src[i][j][k] # if dim == 1
self[i][j] [index[i][j][k]] = src[i][j][k] # if dim == 2

  1. reduce则根据官方文档的陈述,这个,可以是替换(None),或者是加、乘到原先的值上。

reduce (str, optional) – reduction operation to apply, can be either ‘add’ or ‘multiply’.

不好理解哈,首先举个理解性的例子(考虑到什么方法来自pytorch库,大家对以下代码应该非常熟悉)

	#代码来源于znxlwm/pytorch-MNIST-CelebA-cGAN-cDCGAN
    for x_,y_ in train_loader:
        #首先训练D
        D
;