交叉熵损失函数（nn.CrossEntropyLoss）：
# Cross-entropy loss for classification. reduction="mean" is the library
# default, written out here so it is visible that the per-sample losses are
# averaged into a single scalar.
loss = nn.CrossEntropyLoss(reduction="mean")
优化器（随机梯度下降 SGD，学习率 0.01）：
# Plain stochastic gradient descent over the network's parameters, lr = 0.01.
# NOTE(review): `MyNet` is not defined in this snippet. `.parameters()` must be
# called on a model *instance* (e.g. `MyNet = SomeModule()`), not on an
# nn.Module subclass itself, so confirm what `MyNet` is bound to.
optim = torch.optim.SGD(params=MyNet.parameters(), lr=0.01)
代码：
import torchvision
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
# Preprocessing pipeline for the dataset samples. It presumably gets passed as
# `transform=` to the CIFAR10 dataset on the following line, whose call is cut
# off in this view. ToTensor turns a PIL image / HxWxC ndarray into a CxHxW
# float tensor; uint8 inputs are scaled to [0, 1]. More steps can be appended
# to the list.
dataset_transform = torchvision.transforms.Compose(
    transforms=[
        torchvision.transforms.ToTensor(),
    ]
)
test_data = torchvision.datasets.CIFAR10(root="./test10_dataset", train