pytorch中保存模型相关的函数有3个:
torch.save
:利用python的pickle模块实现序列化并保存序列化后的objecttorch.load
:利用pickle将保存的object反序列化torch.nn.Module.load_state_dict
:通过反序列化得到的state_dict
读取保存的训练参数
有两种方法保存模型:
1. torch.save(model, path) # 直接保存整个模型
2. torch.save(model.state_dict(), path) # 保存模型的参数
相应地有两种方法加载保存的模型:
1. model = torch.load(path) # 直接加载模型
2. model = Model() # 先初始化一个模型
model.load_state_dict(torch.load(path)) # 再加载模型参数
看起来第一种方法更加简单