Bootstrap

模型压缩(2)——模型剪枝

1. 使用pytorch的prune工具进行剪枝

使用pytorch自带的prune函数进行剪枝,剪枝后被剪掉的参数为0,应将为0的参数剔除运算,否则为虚假的剪枝,速度甚至更慢

参考代码如下,model为训练后的模型,经循环得到剪枝模型,需微调回复精度,感觉没太大作用,不做具体演示。

from torch.nn.utils import prune
# 使用named_modules可以得到每一个最小层,使用named_children仅能得到较大的块
for n,module in model.named_modules():
    # 对线性层剪枝
    if isinstance(module,torch.nn.Linear):
        # 可以选择多种裁剪方式,此处选择了随机裁剪;
        # 其中name代表是对哪个参数进行裁剪,如果对偏置进行裁剪则该参数值为'bias';
        # amount是指裁剪比例
        prune.random_unstructured(module,name = 'weight', amount = 0.3)
        # 此时model.weight被替换为model.weight_orig和model.weight_mask 
        # 使用list(module.named_buffers())可以查看
        prune.remove(module,'weight')

2. 使用微软的nni工具进行剪枝

需要安装nni库,从nni.algorithms.compression.pytorch.pruning中选择想要的剪枝方法

教程链接:Pruning — An open source AutoML toolkit for neural architecture search, model compression and hyper-parameter tuning (NNI v2.6.1)https://nni.readthedocs.io/en/stable/Compression/pruning.html

演示demo:

;