
PyTorch (10) -- Knowledge Distillation

1. Introduction

        This post covers knowledge distillation. In short, a complex but well-performing model acts as the teacher and guides the training of a student model with a much simpler structure, pulling the student's predictions toward the teacher's outputs and thereby improving the simple model's accuracy.

        The code below uses a model trained to 95% accuracy on CIFAR-10 as the teacher and distills it into a simple 3-layer convolutional network.
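
        Before the full code, here is a minimal sketch of the loss combination used below (the function name is illustrative): a hard cross-entropy term against the ground-truth labels plus a softened KL-divergence term against the teacher, weighted by alpha and scaled by the temperature T.

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, alpha=0.95, T=20.0):
    # hard loss: ordinary cross-entropy against the labels
    hard = F.cross_entropy(student_logits, labels)
    # soft loss: KL divergence between the softened student and teacher distributions,
    # scaled by T^2 to keep gradient magnitudes comparable across temperatures
    soft = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                    F.softmax(teacher_logits / T, dim=1),
                    reduction='batchmean') * T * T
    return (1 - alpha) * hard + alpha * soft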

2. Code

        The teacher model is the official PyTorch resnet18, with the locally trained weights loaded into it:

teach_model = resnet18(pretrained=False)
inchannel = teach_model.fc.in_features
teach_model.fc = nn.Linear(inchannel, 10)   # replace the final fc layer with a 10-class head for CIFAR-10
teach_model.load_state_dict(torch.load("cnn_resnet18.pth", map_location=torch.device('cpu')))
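
        Since the teacher is only consulted for its outputs, it should stay in eval mode with its weights frozen during distillation; a small sketch (the training script below achieves the same effect with eval() and torch.no_grad()):

teach_model.eval()                        # fix BatchNorm running statistics
for p in teach_model.parameters():
    p.requires_grad = False               # the teacher's weights are never updated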

        The student model is a self-built 3-layer convolutional network:

class simpleNet(nn.Module):
    def __init__(self, num_classes, input_nc):
        super(simpleNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(input_nc, 16, 5, 2, 2, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(7),
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, 3, 1, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(4),
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, 3, 1, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )

        self.dropout = nn.Dropout(p=0.5)
        self.fc = nn.Linear(2*2*64, 10)     # expects a 2x2x64 feature map from layer3
        self.out = nn.Linear(10, num_classes)
        
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)  
        x = self.layer3(x)  
      
        x = x.view(-1, 2*2*64 )
        x = self.dropout(x)
        x = self.fc(x)
        x = self.out(x)
        return x  
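
        The 2*2*64 input of self.fc implies that images are resized to 224x224 before entering the network (the input size the resnet18 teacher normally expects); a quick shape check under that assumption:

net = simpleNet(num_classes=10, input_nc=3)
x = torch.randn(1, 3, 224, 224)   # assumed input size; raw 32x32 CIFAR-10 images would be too small for these pooling layers
print(net(x).shape)               # expected: torch.Size([1, 10])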

        Training code; the key is tuning the hyperparameters alpha and T well:

# distill.py
#
# import the modules and set up the losses
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.models.resnet import resnet18

from trainVal import trainset, testset
from model import simpleNet


criterion = nn.CrossEntropyLoss()
# 'batchmean' returns the true KL divergence per sample, the usual choice for distillation
criterion2 = nn.KLDivLoss(reduction='batchmean')

def train(teach_model, model, device, train_loader, optimizer, epoch, scheduler):
    loss_sigma = 0.0
    correct = 0.0
    total = 0.0
    alpha = 0.95    # weight of the soft-target loss; try 0.5, 0.9, 0.95
    
        
    for i, data in enumerate(train_loader):
        inputs, labels = data
            
        inputs = inputs.to(device)
        labels = labels.to(device)   
            
        labels = labels.squeeze().long()
            
        # zero the accumulated gradients
        optimizer.zero_grad()
            
        outputs = model(inputs.float())
        
        loss1 = criterion(outputs, labels)
        
        # teacher forward pass; no gradient is needed through the teacher
        with torch.no_grad():
            teacher_outputs = teach_model(inputs.float())
        T = 20  # distillation temperature; try 2, 10, 20
        outputs_S = F.log_softmax(outputs/T, dim=1)
        outputs_T = F.softmax(teacher_outputs/T, dim=1)
        loss2 = criterion2(outputs_S, outputs_T)*T*T
        
        loss = loss1*(1-alpha) + loss2*alpha
            
        loss.backward()
        optimizer.step()
        
        _, predicted = torch.max(outputs.data, dim = 1)
        total += labels.size(0)
        correct += (predicted.cpu()==labels.cpu()).squeeze().sum().numpy()
        loss_sigma += loss.item()
        
        
        if i % 100 == 99:
            loss_avg = loss_sigma / 100   # average loss over the last 100 batches
            loss_sigma = 0.0
            print('loss_avg:{:.2}   '.format(loss_avg))
            print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch, (i + 1) * len(inputs), len(train_loader.dataset),
                    100. * (i + 1) / len(train_loader), loss.item()
                    ))
                
    scheduler.step()    
    
def test(net, device, test_loader, train_loader, epoch):
    net.eval()  # eval mode is required because the model uses BatchNorm and Dropout
    test_loss = 0
    correct = 0
    with torch.no_grad():   # no backward pass / gradient tracking during evaluation
        for data, label in test_loader:
            data, label = data.to(device), label.to(device)
            output = net(data) 
            
            test_loss +=  F.cross_entropy(output, label, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(label.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)

        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))


def main():
    # use CUDA if it is available
    use_cuda = torch.cuda.is_available()
    torch.manual_seed(123)  # fix the random seed
    device = torch.device("cuda" if use_cuda else "cpu")

    # build the teacher model and load its trained weights
    teach_model = resnet18(pretrained=False)
    inchannel = teach_model.fc.in_features
    teach_model.fc = nn.Linear(inchannel, 10)
    teach_model.load_state_dict(torch.load("cnn_resnet18.pth", map_location=torch.device('cpu')))
    teach_model = teach_model.to(device)
    teach_model.eval()   # the teacher is only used for inference
    
    # build the student model and initialize the hyperparameters
    model = simpleNet( 10 , 3 ).to(device)

    optimizer = optim.Adam(model.parameters(),lr = 0.002)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 30 , gamma=0.5)

    epoches = 100
    train_batchsize = 64
    test_batchsize = 64
    
    # prepare the data loaders
    kwargs = {'num_workers': 2, 'pin_memory': True} if use_cuda else {}
    
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=  train_batchsize  , shuffle=True,  **kwargs )
    test_loader = torch.utils.data.DataLoader(testset, batch_size= test_batchsize , shuffle=False,  **kwargs ) 
 
    # run the training and evaluation loop
    for epoch in range(epoches):
        train(teach_model, model, device, train_loader, optimizer, epoch, scheduler)
        test(model, device, test_loader, train_loader, epoch)
       
if __name__ == "__main__":
    main()
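
        To reuse the distilled student later, its weights can be saved at the end of main(), for example (the filename is just a placeholder):

    torch.save(model.state_dict(), "simplenet_distilled.pth")   # hypothetical filename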

3. Distillation Results

        Training the simple network on its own reaches about 75% accuracy, while with distillation it reaches 76%+. The gain is real but modest, and obtaining it depends on tuning alpha and T well.
