Bootstrap

大模型 - 知识蒸馏原理解析

知识蒸馏的详细过程和原理解析

知识蒸馏是一种通过将大型预训练模型(教师模型)的知识传递给较小模型(学生模型)的方法。这样可以在减少模型的复杂度和计算资源需求的同时,尽量保留模型的性能。以下是知识蒸馏的详细过程和每个步骤中用到的原理。

1. 输入数据

假设我们有一个图像分类任务,输入数据 x x x 是一张图像。这个图像同时馈送给教师模型和学生模型。

2. 教师模型

  • 教师模型是一个已经训练好的大模型,它对输入 x x x 进行预测。
  • 教师模型的输出经过一个带温度参数 T T T 的 softmax 函数,得到软标签(soft labels)。温度参数 T T T 用于平滑预测概率,使得输出概率分布更平缓。

具体来说,假设教师模型输出的 logits 为 [ 2.0 , 1.0 , 0.1 ] [2.0, 1.0, 0.1] [2.0,1.0,0.1],在温度 T = 2 T=2 T=2 下,softmax 计算如下:

softmax ( z i ; T = 2 ) = e z i / 2 ∑ j e z j / 2 \text{softmax}(z_i; T=2) = \frac{e^{z_i / 2}}{\sum_{j} e^{z_j / 2}} softmax(zi;T=2)=jezj/2ezi/2

计算得:
softmax ( 2.0 / 2 ) = e 1.0 e 1.0 + e 0.5 + e 0.05 = 0.504 \text{softmax}(2.0 / 2) = \frac{e^{1.0}}{e^{1.0} + e^{0.5} + e^{0.05}} = 0.504 softmax(2.0/2)=e1.0+e0.5+e0.05e1.0=0.504
softmax ( 1.0 / 2 ) = e 0.5 e 1.0 + e 0.5 + e 0.05 = 0.277 \text{softmax}(1.0 / 2) = \frac{e^{0.5}}{e^{1.0} + e^{0.5} + e^{0.05}} = 0.277 softmax(1.0/2)=e1.0+e0.5+e0.05e0.5=0.277
softmax ( 0.1 / 2 ) = e 0.05 e 1.0 + e 0.5 + e 0.05 = 0.219 \text{softmax}(0.1 / 2) = \frac{e^{0.05}}{e^{1.0} + e^{0.5} + e^{0.05}} = 0.219 softmax(0.1/2)=e1.0+e0.5+e0.05e0.05=0.219

软标签为 [ 0.504 , 0.277 , 0.219 ] [0.504, 0.277, 0.219] [0.504,0.277,0.219]

3. 学生模型

  • 学生模型是一个较小的模型,它也对输入 x x x 进行预测。
  • 学生模型的输出经过两个 softmax 函数处理,一个带温度 T T T 得到软预测(soft predictions),另一个带温度 T = 1 T=1 T=1 得到硬预测(hard predictions)。

假设学生模型输出的 logits 为 [ 1.8 , 0.9 , 0.4 ] [1.8, 0.9, 0.4] [1.8,0.9,0.4],在温度 T = 2 T=2 T=2 下,softmax 计算如下:

softmax ( 1.8 / 2 ) = e 0.9 e 0.9 + e 0.45 + e 0.2 = 0.474 \text{softmax}(1.8 / 2) = \frac{e^{0.9}}{e^{0.9} + e^{0.45} + e^{0.2}} = 0.474 softmax(1.8/2)=e0.9+e0.45+e0.2e0.9=0.474
softmax ( 0.9 / 2 ) = e 0.45 e 0.9 + e 0.45 + e 0.2 = 0.301 \text{softmax}(0.9 / 2) = \frac{e^{0.45}}{e^{0.9} + e^{0.45} + e^{0.2}} = 0.301 softmax(0.9/2)=e0.9+e0.45+e0.2e0.45=0.301
softmax ( 0.4 / 2 ) = e 0.2 e 0.9 + e 0.45 + e 0.2 = 0.225 \text{softmax}(0.4 / 2) = \frac{e^{0.2}}{e^{0.9} + e^{0.45} + e^{0.2}} = 0.225 softmax(0.4/2)=e0.9+e0.45+e0.2e0.2=0.225

软预测为 [ 0.474 , 0.301 , 0.225 ] [0.474, 0.301, 0.225] [0.474,0.301,0.225]

硬预测( T = 1 T=1 T=1)的 softmax 计算如下:
softmax ( 1.8 ) = e 1.8 e 1.8 + e 0.9 + e 0.4 = 0.659 \text{softmax}(1.8) = \frac{e^{1.8}}{e^{1.8} + e^{0.9} + e^{0.4}} = 0.659 softmax(1.8)=e1.8+e0.9+e0.4e1.8=0.659
softmax ( 0.9 ) = e 0.9 e 1.8 + e 0.9 + e 0.4 = 0.242 \text{softmax}(0.9) = \frac{e^{0.9}}{e^{1.8} + e^{0.9} + e^{0.4}} = 0.242 softmax(0.9)=e1.8+e0.9+e0.4e0.9=0.242
softmax ( 0.4 ) = e 0.4 e 1.8 + e 0.9 + e 0.4 = 0.099 \text{softmax}(0.4) = \frac{e^{0.4}}{e^{1.8} + e^{0.9} + e^{0.4}} = 0.099 softmax(0.4)=e1.8+e0.9+e0.4e0.4=0.099

硬预测为 [ 0.659 , 0.242 , 0.099 ] [0.659, 0.242, 0.099] [0.659,0.242,0.099]

4. 蒸馏损失(Distillation Loss)

  • 蒸馏损失是教师模型的软标签和学生模型的软预测之间的差异,通常使用 KL 散度(Kullback-Leibler Divergence)作为损失函数。

D K L ( P ∥ Q ) = ∑ x ∈ X P ( x ) log ⁡ ( P ( x ) Q ( x ) ) D_{KL}(P \parallel Q) = \sum_{x \in X} P(x) \log \left( \frac{P(x)}{Q(x)} \right) DKL(PQ)=xXP(x)log(Q(x)P(x))

假设软标签 P P P [ 0.504 , 0.277 , 0.219 ] [0.504, 0.277, 0.219] [0.504,0.277,0.219],软预测 Q Q Q [ 0.474 , 0.301 , 0.225 ] [0.474, 0.301, 0.225] [0.474,0.301,0.225]
D K L ( P ∥ Q ) = 0.504 log ⁡ ( 0.504 0.474 ) + 0.277 log ⁡ ( 0.277 0.301 ) + 0.219 log ⁡ ( 0.219 0.225 ) D_{KL}(P \parallel Q) = 0.504 \log \left( \frac{0.504}{0.474} \right) + 0.277 \log \left( \frac{0.277}{0.301} \right) + 0.219 \log \left( \frac{0.219}{0.225} \right) DKL(PQ)=0.504log(0.4740.504)+0.277log(0.3010.277)+0.219log(0.2250.219)
计算得:
D K L ( P ∥ Q ) = 0.504 ⋅ 0.0623 + 0.277 ⋅ − 0.0848 + 0.219 ⋅ − 0.0267 D_{KL}(P \parallel Q) = 0.504 \cdot 0.0623 + 0.277 \cdot -0.0848 + 0.219 \cdot -0.0267 DKL(PQ)=0.5040.0623+0.2770.0848+0.2190.0267
= 0.0314 − 0.0235 − 0.0058 = 0.0314 - 0.0235 - 0.0058 =0.03140.02350.0058
= 0.0021 = 0.0021 =0.0021

5. 学生损失(Student Loss)

  • 学生损失是学生模型的硬预测和真实标签(硬标签)之间的差异,通常使用交叉熵损失函数。

假设真实标签 y y y 为类别 1,则 one-hot 编码为 [ 1 , 0 , 0 ] [1, 0, 0] [1,0,0],硬预测为 [ 0.659 , 0.242 , 0.099 ] [0.659, 0.242, 0.099] [0.659,0.242,0.099],交叉熵损失为:

H ( y , y ^ ) = − ∑ i y i log ⁡ ( y ^ i ) H(y, \hat{y}) = - \sum_{i} y_i \log(\hat{y}_i) H(y,y^)=iyilog(y^i)
H ( y , y ^ ) = − ( 1 ⋅ log ⁡ ( 0.659 ) + 0 ⋅ log ⁡ ( 0.242 ) + 0 ⋅ log ⁡ ( 0.099 ) ) H(y, \hat{y}) = - (1 \cdot \log(0.659) + 0 \cdot \log(0.242) + 0 \cdot \log(0.099)) H(y,y^)=(1log(0.659)+0log(0.242)+0log(0.099))
= − log ⁡ ( 0.659 ) = 0.416 = - \log(0.659) = 0.416 =log(0.659)=0.416

6. 总损失(Total Loss)

  • 总损失是蒸馏损失和学生损失的加权和:
    Total Loss = α × Student Loss + β × Distillation Loss \text{Total Loss} = \alpha \times \text{Student Loss} + \beta \times \text{Distillation Loss} Total Loss=α×Student Loss+β×Distillation Loss

假设 α = 1 \alpha = 1 α=1 β = 0.5 \beta = 0.5 β=0.5,则总损失为:
Total Loss = 1 × 0.416 + 0.5 × 0.0021 = 0.417 \text{Total Loss} = 1 \times 0.416 + 0.5 \times 0.0021 = 0.417 Total Loss=1×0.416+0.5×0.0021=0.417

代码示例

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# 定义教师模型和学生模型
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.fc = nn.Linear(784, 10)
    
    def forward(self, x):
        return self.fc(x)

class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.fc = nn.Linear(784, 10)
    
    def forward(self, x):
        return self.fc(x)

# 定义蒸馏损失函数
def distillation_loss(soft_labels, soft_predictions, T):
    soft_labels = F.softmax(soft_labels / T, dim=1)
    soft_predictions = F.log_softmax(soft_predictions / T, dim=1)
    loss = F.kl_div(soft_predictions, soft_labels, reduction='batchmean') * (T ** 2)
    return loss

# 定义学生损失函数
def student_loss(hard_labels, hard_predictions):
    return F.cross_entropy(hard_predictions, hard_labels)

# 超参数
alpha = 1.0
beta = 0.5
temperature = 2.0
learning_rate = 0.001
num_epochs = 10

# 数据加载器(使用MNIST数据集作为示例)
from torchvision import datasets, transforms
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('.', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=64, shuffle=True)

# 初始化模型、优化器
teacher_model = TeacherModel()
student_model = StudentModel()
optimizer = optim.Adam(student_model.parameters(), lr=learning_rate)

# 假设教师模型已经预训练好,这里直接加载预训练权重
# teacher_model.load_state_dict(torch.load('teacher_model.pth'))

# 训练过程
teacher_model.eval()  # 教师模型设为评估模式,不进行训练
student_model.train()  # 学生模型设为训练模式

for epoch in range(num_epochs):
    total_loss = 0
    for data, target in train_loader:
        data = data.view(data.size(0), -1)  # 展开图像数据
        
        # 教师模型预测
        with torch.no_grad():
            teacher_output = teacher_model(data)
        
        # 学生模型预测
        student_output = student_model(data)
        soft_predictions = student_output / temperature
        hard_predictions = student_output
        
        # 计算蒸馏损失和学生损失
        dist_loss = distillation_loss(teacher_output, student_output, temperature)
        stud_loss = student_loss(target, hard_predictions)
        
        # 计算总损失
        loss = alpha * stud_loss + beta * dist_loss
        
        # 优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader)}')

# 保存学生模型
torch.save(student_model.state_dict(), 'student_model.pth')

代码解释

  1. 模型定义:定义了一个简单的全连接层的教师模型和学生模型。

  2. 蒸馏损失和学生损失函数

    • distillation_loss 计算KL散度作为蒸馏损失。
    • student_loss 计算交叉熵损失作为学生损失。
  3. 超参数

    • alphabeta 分别是学生损失和蒸馏损失的权重。
    • temperature 是温度参数,用于平滑教师模型的输出。
  4. 数据加载:使用MNIST数据集作为示例。

  5. 模型初始化:初始化教师模型和学生模型,并定义优化器。

  6. 训练过程

    • 教师模型设为评估模式,学生模型设为训练模式。
    • 在每个训练周期中,对每个批次数据进行预测,计算损失,并进行优化。
  7. 保存模型:在训练结束后保存学生模型的权重。

该代码示例展示了如何通过PyTorch实现模型蒸馏的训练过程。如果有其他需求或需要进一步解释的地方,请告诉我。

总结

知识蒸馏通过教师模型提供的软标签引导学生模型,使得学生模型不仅关注硬标签的分类准确性,还能从软标签中学习更丰富的类别间关系,从而在模型压缩的同时尽量保留性能。这种方法特别适用于在资源受限的环境中部署高效的深度学习模型。

;