1. 使用pytorch的prune工具进行剪枝
使用pytorch自带的prune函数进行剪枝,剪枝后被剪掉的参数为0,应将为0的参数剔除运算,否则为虚假的剪枝,速度甚至更慢
参考代码如下,model为训练后的模型,经循环得到剪枝模型,需微调回复精度,感觉没太大作用,不做具体演示。
from torch.nn.utils import prune
# 使用named_modules可以得到每一个最小层,使用named_children仅能得到较大的块
for n,module in model.named_modules():
# 对线性层剪枝
if isinstance(module,torch.nn.Linear):
# 可以选择多种裁剪方式,此处选择了随机裁剪;
# 其中name代表是对哪个参数进行裁剪,如果对偏置进行裁剪则该参数值为'bias';
# amount是指裁剪比例
prune.random_unstructured(module,name = 'weight', amount = 0.3)
# 此时model.weight被替换为model.weight_orig和model.weight_mask
# 使用list(module.named_buffers())可以查看
prune.remove(module,'weight')
2. 使用微软的nni工具进行剪枝
需要安装nni库,从nni.algorithms.compression.pytorch.pruning中选择想要的剪枝方法
演示demo: