Bootstrap

pytorch保存模型的坑

pytorch中保存模型相关的函数有3个:

  • torch.save:利用python的pickle模块实现序列化并保存序列化后的object
  • torch.load:利用pickle将保存的object反序列化
  • torch.nn.Module.load_state_dict:通过反序列化得到的state_dict读取保存的训练参数

有两种方法保存模型:

1. torch.save(model, path) # 直接保存整个模型
2. torch.save(model.state_dict(), path) # 保存模型的参数

相应地有两种方法加载保存的模型:

1. model = torch.load(path) # 直接加载模型
2. model = Model()                         # 先初始化一个模型
   model.load_state_dict(torch.load(path)) # 再加载模型参数

看起来第一种方法更加简单

;