Bootstrap

知识蒸馏 示例代码实现及下载

知识蒸馏 代码实现

论文《Distilling the Knowledge in a Neural Network》

* 源码以Github为准

Github链接:https://github.com/yeqiwang/KnowledgeDistilling

1. 数据集

本文使用fashion_mnist数据集,输入图像大小为28*28,共分为10类。

通过tensoflow加载数据,并对label进行one hot编码。

import tensorflow as tf
from tensorflow import keras
import numpy as np

fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
train_images = train_images/255
test_images = test_images/255
train_labels = tf.one_hot(train_labels, depth=10)
test_labels = tf.one_hot(test_labels, depth=10)

2. 教师模型

本文中使用一个4层MLP来作为教师模型。

训练过程中,模型最后使用softmax层来计算损失值。

训练结束后,更改最后的softmax层,以便生成软标签,其中T=2。同时,为了防止误操作,将教师模型冻结。

需要注意的是,虽然更改后教师模型不再进行训练,但仍需要使用compile函数进行配置,否则无法调用predict函数。

# 构建并训练教师模型
inputs = keras.layers.Input(shape=(28,28))
x = keras.layers.Flatten()(inputs)
x = keras.layers.Dense(128, activation='relu')(x)
x = keras.layers.Dense(128, activation='relu')(x)
x = keras.layers
;