Bootstrap

Pytorch入门学习(九)---detach()的作用(从GAN代码分析)

(八)还没写,先跳过。。。

总说

简单来说detach就是截断反向传播的梯度流

    def detach(self):
        """Returns a new Variable, detached from the current graph.

        Result will never require gradient. If the input is volatile, the output
        will be volatile too.

        .. note::

          Returned Variable uses the same data tensor, as the original one, and
          in-place modifications on either of them will be seen, and may trigger
          errors in correctness checks.
        """
        result = NoGrad()(self)  # this is needed, because it merges version counters
        result._grad_fn = None
        return result

可以看到Returns a new Variable, detached from the current graph。将某个node变成不需要梯度的Varibale。因此当反向传播经过这个node时,梯度就不会从这个node往前面传播。

从GAN的代码中看detach()

GAN的G的更新,主要是GAN loss。就是G生成的fake图让D来判别,得到的损失,计算梯度进行反传。这个梯度只能影响G,不能影响D!可以看到,由于torch是非自动求导的,每一层的梯度的计算必须用net:backward才能计算gradInput和网络中的参数的梯度。

先看Torch版本的代码

local fGx = function(x)
    netD:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)
    netG:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)

    gradParametersG:zero()

    -- GAN loss
    local df_dg = torch.zeros(fake_B:size())
    if opt.use_GAN==1 then
       <
;