(八)还没写,先跳过。。。
总说
简单来说detach就是截断反向传播的梯度流。
def detach(self):
"""Returns a new Variable, detached from the current graph.
Result will never require gradient. If the input is volatile, the output
will be volatile too.
.. note::
Returned Variable uses the same data tensor, as the original one, and
in-place modifications on either of them will be seen, and may trigger
errors in correctness checks.
"""
result = NoGrad()(self) # this is needed, because it merges version counters
result._grad_fn = None
return result
可以看到Returns a new Variable, detached from the current graph。将某个node变成不需要梯度的Varibale。因此当反向传播经过这个node时,梯度就不会从这个node往前面传播。
从GAN的代码中看detach()
GAN的G的更新,主要是GAN loss。就是G生成的fake图让D来判别,得到的损失,计算梯度进行反传。这个梯度只能影响G,不能影响D!可以看到,由于torch是非自动求导的,每一层的梯度的计算必须用net:backward
才能计算gradInput
和网络中的参数的梯度。
先看Torch版本的代码
local fGx = function(x)
netD:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)
netG:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)
gradParametersG:zero()
-- GAN loss
local df_dg = torch.zeros(fake_B:size())
if opt.use_GAN==1 then
<