import matplotlib.pyplot as plt
# Show each original image next to its enhanced version, one pair per row.
num_images = 50  # number of rows in the comparison grid

# plt.subplots(rows, columns, figsize=(width, height) in inches).
# Each row gets 5 inches of height so the pairs stay readable.
# squeeze=False keeps `axes` 2-D even when num_images == 1; otherwise
# matplotlib returns a 1-D array and axes[i][0] indexing would fail.
fig, axes = plt.subplots(
    num_images, 2, figsize=(10, 5 * num_images), squeeze=False
)
for i, (ax_original, ax_enhanced) in enumerate(axes):
    # NOTE(review): every row draws the same `image` / `enhanced_image`;
    # presumably these should vary per row (e.g. images[i]) — confirm.
    ax_original.imshow(image)
    ax_enhanced.imshow(enhanced_image)
    # Hide ticks and frames; only the pixels matter for the comparison.
    ax_original.axis('off')
    ax_enhanced.axis('off')
plt.show()
# 保存模型恢复训练 (Save a model checkpoint and resume training from it)
import torch
# Save only the model's learned parameters (its state_dict).
torch.save(model.state_dict(), 'model_checkpoint.pth')

# Later (e.g. in a new session): rebuild the model and restore its parameters.
model = Model()
# map_location='cpu' lets a checkpoint that was written from GPU tensors load
# on a CPU-only machine; load_state_dict then copies the values into the
# freshly built model's own parameters.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (newer PyTorch versions accept weights_only=True to restrict unpickling).
model.load_state_dict(torch.load('model_checkpoint.pth', map_location='cpu'))

# Continue training with a fresh optimizer over the restored parameters.
# NOTE(review): only model weights are checkpointed. If a stateful optimizer
# (e.g. SGD with momentum, Adam) or an epoch counter matters for resuming,
# also save and restore optimizer.state_dict() — confirm what is needed.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)