Bootstrap

记录nvidia官方对Pix2PixHD的源码实现细节

好记性不如烂笔头,这篇论文我认为主要还是在feature matching的loss上,不过既然看了论文,就认真看一下源码吧

数据读取部分

经过一系列的追踪,我们可以在AlignedDataset中看到数据读取的方式,init里面都是些路径的设置,直接看get_item部分。

 A = Image.open(A_path)        
 params = get_params(self.opt, A.size)  #  {'crop_pos': (x, y), 'flip': flip}

在这里我们看到,先读取一张图像,然后调用get_params函数,这个函数是根据用户指定的方式resize或者crop出合适大小的输入尺寸。

transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
A_tensor = transform_A(A) * 255.0   # 0-255  

对数据预处理,有经过to_tensor操作,再乘255,所以输入范围是0-255,
接着读入真实图像

   if self.opt.isTrain or self.opt.use_encoded_image:
            B_path = self.B_paths[index]   
            B = Image.open(B_path).convert('RGB')
            transform_B = get_transform(self.opt, params)      
            B_tensor = transform_B(B)  # 经过归一化了  -1 ,1

接着读入instance,后续还会处理成边缘图,和论文中描述一致。

 if not self.opt.no_instance:
            inst_path = self.inst_paths[index]
            inst = Image.open(inst_path)
            inst_tensor = transform_A(inst)  # 和semantic的处理方式一样  0-1
            
  			if self.opt.load_features:
	                feat_path = self.feat_paths[index]            
	                feat = Image.open(feat_path).convert('RGB')
	                norm = normalize()
	                feat_tensor = norm(transform_A(feat))         

注意self.opt.load_features的作用是是否读取每个类别的预先计算的特征,论文中有10类,由聚类形成的。但默认是不执行的。我本人看论文对这一部分也是一知半解,以后有需求之后再研究。

 input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor, 
                      'feat': feat_tensor, 'path': A_path}

 return input_dict

之后就返回一个字典,记录了上述读取并经过处理的数据。
回到 train.py,看到model = create_model(opt),跟进去看看模型结构。

模型结构

又是经过一系列追踪,看到Pix2PixHDModel的类。这个类的内容非常多,有搭建模型,定义优化器和损失函数,导入模型等操作。
在initialize函数里面看看对Pix2PixHDModel的一些设置。

2.1 define_G

define_G函数里面主要内容就是这几行,我们知道在Pix2PixHD中,G是有两部分的,一部分是global net,另一部分是local net,就是前两个if语句对应的分支。第三个if语句对应的是论文中E的部分,用来预先计算类别特征,实现同样的semantic label的多样性输出。

  if netG == 'global':    
        netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)       
    elif netG == 'local':        
        netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, 
                                  n_local_enhancers, n_blocks_local, norm_layer)
    elif netG == 'encoder':
        netG = Encoder(input_nc, output_nc, ngf, n_downsample_global, norm_layer)

接着我们进入看看GlobalGenerator的结构,就不看LocalEnhancer了。
先定义第一层,用的不是zero_padding,而是Reflection padding。因为第一层用的是7x7的卷积核,不想改变图的分辨率,就要用padding=3

  model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]

之后就是下采样环节,每一层卷积的stride都是2

 for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                      norm_layer(ngf * mult * 2), activation]

之后就是残差块,残差块不改变分辨率。和论文描述一致。

 for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]

之后就是和下采样数目一样的上采样部分,上采样部分不像Unet结构,没有用到下采样得到的特征图。

 for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
                       norm_layer(int(ngf * mult / 2)), activation]

之后就是模型的输出层。注意没有使用BN。

  model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]        

2.2 define_D

按照论文的说法,Pix2PixHD的D是有多个D,对于生成1024x2048的图像,需要使用3个D,对于生成512x1024的图像,需要使用两个D就行了。多个D输入不同比例的图像。
在define_D中,主要内容是下面这行。

  netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat)   

我们再跳进 MultiscaleDiscriminator看看里面是啥。
在这个类的init函数里面,我们发现

  for i in range(num_D):
            netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
            if getIntermFeat:           #忽略这条分支             
                for j in range(n_layers+2):
                    setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))                                   
            else:
                setattr(self, 'layer'+str(i), netD.model)

生成的NLayerDiscriminator类,被设置为当前类(self)的一个属性。生成num_D个D。
跳到NLayerDiscriminator看看。

    sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence += [[
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                norm_layer(nf), nn.LeakyReLU(0.2, True)
            ]]

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [[
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]]

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]

就是普通的encoder网络了。注意use_sigmiod是False,即D的输出不使用sigmiod

至此,G和D就定义完成了,至于E,不是主要内容,就不多说了。

2.3 损失函数

self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss) 
self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
self.criterionFeat = torch.nn.L1Loss()
if not opt.no_vgg_loss:             
         self.criterionVGG = networks.VGGLoss(self.gpu_ids)
  • self.criterionGAN 是训练G和D的损失函数定义
  • self.criterionFeat 是feature matching损失项的定义,使用的是L1 loss。
  • self.criterionVGG 是percetual loss的定义。这是可选项,对最终结果也有帮助。

我们进入GANLoss看看如何定义的G和D的损失。
在构造函数中我们可以看出,如果没有在网络最后使用sigmoid,那就使用MSE loss

     if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

2.4 优化器

 self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))                            
 params = list(self.netD.parameters())    
 self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

前向传播

 losses, generated = model(Variable(data['label']), Variable(data['inst']), 
            Variable(data['image']), Variable(data['feat']), infer=save_fake)
  input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)  

这是model.forward的第一句,目的是做一下预处理,得到one hot编码,同时将instance map转换为edge,转换方法是对比四领域之间的差异,和论文方法一致。

G的前向传播没啥好说的,主要看下D的前向传播过程。

  for i in range(num_D):
            if self.getIntermFeat:
                model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]
            else:
                model = getattr(self, 'layer'+str(num_D-1-i))
            result.append(self.singleD_forward(model, input_downsampled)) # 两个结果大小是2倍关系
            if i != (num_D-1):
                input_downsampled = self.downsample(input_downsampled)

调用每一个之前定义的D,对每个D输入分辨率不一样的数据,这里的数据是图像,label的concat形式。

  loss_D_fake = self.criterionGAN(pred_fake_pool, False)        

  # Real Detection and Loss        
  pred_real = self.discriminate(input_label, real_image)
  loss_D_real = self.criterionGAN(pred_real, True)

  # GAN loss (Fake Passability Loss)        
  pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))        
  loss_G_GAN = self.criterionGAN(pred_fake, True)               

G和D的损失,G的损失只有一项,就是loss_G_GAN ,D的损失有两项,loss_D_real 和loss_D_fake 。
feature matching的损失的计算方式,就是计算相同的位置上,假样本和真样本的特征的L1距离。

 loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat
   optimizer_G.zero_grad()
   if opt.fp16:                                
       with amp.scale_loss(loss_G, optimizer_G) as scaled_loss: scaled_loss.backward()                
   else:
       loss_G.backward()          
   optimizer_G.step()

   # update discriminator weights
   optimizer_D.zero_grad()
   if opt.fp16:                                
       with amp.scale_loss(loss_D, optimizer_D) as scaled_loss: scaled_loss.backward()                
   else:
       loss_D.backward()        
   optimizer_D.step()        

按照GAN的交替训练方法训练就对了。有个地方我很奇怪,计算VGG loss的时候,按理说更新完G的参数,也要清除一下VGG网络计算得到的梯度才对,但是我找了整个工程,也没有见到清楚vgg网络梯度的语句。
还有计算VGG loss的是,对于真实样本的vgg特征,要detach一下。很重要

loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())    # 这个detach非常有必要,避免了多余的梯度累加。
;