1、下载
可以使用
#train_set = mnist.MNIST('./data', train=True, download=True)
但是速度慢一般无法下载,官网下载也较慢
使用方法
└── MNIST
├── processed
│ ├── test.pt
│ └── training.pt
└── raw
├── t10k-images-idx3-ubyte
├── t10k-labels-idx1-ubyte
├── train-images-idx3-ubyte
└── train-labels-idx1-ubyte
将压缩包放在mnist/row,文件夹下,运行train_set = mnist.MNIST('./data', train=True, download=True)即可解压,之后改为download=False即可调用
2、训练代码
读取数据集后,使用dataloader进行分组载入,利用pytorch构建网络,使用梯度下降法训练,对数据集训练了20次
import numpy as np
import torch
from torchvision.datasets import mnist
from torch import nn
from torch.autograd import Variable
#train_set = mnist.MNIST('./data', train=True, download=False)