Bootstrap

使用 PyTorch 全连接网络对 MNIST 数据集进行训练与测试

from torch.nn import CrossEntropyLoss
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
import torch.optim as optim

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download (if needed) and load MNIST; images become float tensors in [0, 1].
train_data = datasets.MNIST(root="./data", train=True, transform=transforms.ToTensor(), download=True)
test_data = datasets.MNIST(root="./data", train=False, transform=transforms.ToTensor(), download=True)
# Mini-batches of 100; shuffle only the training set so evaluation order is stable.
train_loader = DataLoader(dataset=train_data, batch_size=100, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=100, shuffle=False)


class MnistNet(nn.Module):
    """Fully connected classifier for 28x28 MNIST digits.

    Five Linear layers (784 -> 1000 -> 600 -> 300 -> 100 -> 10) with ReLU
    between them; the final layer emits raw logits for the 10 classes.
    """

    def __init__(self):
        super().__init__()
        # Build the stack from a width list instead of spelling out each
        # layer. Modules are appended as Linear, ReLU, Linear, ReLU, ...
        # so nn.Sequential indices (and state_dict keys) match the
        # hand-written equivalent exactly.
        widths = [28 * 28 * 1, 1000, 600, 300, 100, 10]
        stack = []
        last = len(widths) - 2
        for i, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            stack.append(nn.Linear(fan_in, fan_out))
            if i != last:  # no activation after the output layer
                stack.append(nn.ReLU())
        self.layer = nn.Sequential(*stack)

    def forward(self, x):
        # Flatten everything after the batch dimension to a 784-vector.
        flat = x.reshape(x.size(0), -1)
        return self.layer(flat)


# Model, loss, and optimizer are module-level singletons shared by
# train() and test() below.
net = MnistNet().to(device)
loss_fn = nn.CrossEntropyLoss()  # expects raw logits + integer class labels
# NOTE(review): lr=0.01 is high for Adam (common default is 1e-3) — confirm
# training converges reliably at this rate.
optimizer = optim.Adam(net.parameters(), lr=0.01)


def train(epoches=5):
    """Train `net` on `train_loader` for the given number of epochs.

    Args:
        epoches: number of full passes over the training set.

    Prints the average per-batch loss after each epoch. Uses the
    module-level `net`, `loss_fn`, `optimizer`, and `device`.
    """
    for epoch in range(epoches):
        total_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            logit = net(images)
            loss = loss_fn(logit, labels)
            optimizer.zero_grad()  # clear gradients left by the previous step
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Report the mean loss per batch; the original printed the raw sum
        # with misplaced brackets in the message.
        avg_loss = total_loss / max(len(train_loader), 1)
        print(f"Epoch [{epoch + 1}/{epoches}], Loss={avg_loss:.4f}")


def test():
    """Evaluate `net` on `test_loader` and print accuracy as a percentage.

    Uses the module-level `net`, `test_loader`, and `device`. Runs under
    torch.no_grad() since no gradients are needed for evaluation.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            logit = net(images)
            # argmax over the class dimension; the original unpacked
            # torch.max and discarded the unused max values.
            predicted = logit.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Guard against an empty loader; also fixes the "Accurary" typo.
    accuracy = 100 * correct / total if total else 0.0
    print(f"Accuracy:{accuracy:.2f}%")


# Run training then evaluation only when executed as a script, so that
# importing this module does not kick off a full training run.
if __name__ == "__main__":
    train(5)
    test()

解释:

1. 首先使用torchvision中的datasets获取数据

2. 然后通过torch.utils.data中的dataloader对数据进行分批处理

3. 使用for循环迭代训练数据,得到单批次的数据,并对模型进行训练。每轮训练结束后打印该轮的损失值

4. 使用for循环迭代测试数据,测试模型,再拿模型对结果的打分和真实值进行比较,求出正确率
