一、网络模型的保存与读取方式1
方法讲解
保存模型
import torch
import torchvision
model = torchvision.models.vgg16(weights='DEFAULT')
#保存模型和参数
torch.save(model,"save_method1.pth")
读取模型
import torch
model = torch.load("save_method1.pth")
print(model)
输出:
比较坑人的点
使用 torch.save 必须将该模型的架构引入到该文件中(可以使用from A import B
的方式来解决),这里举一个例子来说明
保存模型
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
#保存模型和参数
class Mary(nn.Module):
def __init__(self):
super(Mary,self).__init__()
self.model1 = nn.Sequential(
Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self,x):
x = self.model1(x)
return x
Yorelee = Mary()
torch.save(Yorelee,"save_method1_question.pth")
读取模型
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
model = torch.load("save_method1_question.pth")
print(model)
报错如下
说明我们还要把 Mary 这个框架复制到读取模型的.py文件中
重新更正后的读取模型代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
class Mary(nn.Module):
def __init__(self):
super(Mary,self).__init__()
self.model1 = nn.Sequential(
Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self,x):
x = self.model1(x)
return x
model = torch.load("save_method1_question.pth")
print(model)
或者
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
from torch_save import Mary #这里仅举一个例子
model = torch.load("save_method1_question.pth")
print(model)
二、网络模型的保存与读取方式2
保存模型参数
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
vgg_model = torchvision.models.vgg16(weights='DEFAULT')
#保存参数
torch.save(vgg_model.state_dict(),"save_method2.pth")
读取模型参数
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
vgg_model = torchvision.models.vgg16(weights='DEFAULT')
parameter = torch.load("save_method2.pth")
vgg_model.load_state_dict(parameter)
print(vgg_model)