Bootstrap

python调用数据集mnist_MNIST数据集下载训练测试Python,mnist,pytorch

1、下载

可以使用

#train_set = mnist.MNIST('./data', train=True, download=True)

但是速度慢一般无法下载,官网下载也较慢

使用方法

└── MNIST

├── processed

│   ├── test.pt

│   └── training.pt

└── raw

├── t10k-images-idx3-ubyte

├── t10k-labels-idx1-ubyte

├── train-images-idx3-ubyte

└── train-labels-idx1-ubyte

将压缩包放在mnist/row,文件夹下,运行train_set = mnist.MNIST('./data', train=True, download=True)即可解压,之后改为download=False即可调用

2、训练代码

读取数据集后,使用dataloader进行分组载入,利用pytorch构建网络,使用梯度下降法训练,对数据集训练了20次

import numpy as np

import torch

from torchvision.datasets import mnist

from torch import nn

from torch.autograd import Variable

#train_set = mnist.MNIST('./data', train=True, download=False)

;