前面加载数据都是直接调用的pytorch封装好的datasets.MNIST或则datasets.FashionMNIST等,学习阶段使用这样形式的数据没有什么影响,但是实际情况是我们要处理自己的数据。本节主要使用torch.utils.data.Dataset读取MNIST数据的csv文件
from torchvision import transforms
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
import pandas as pd
import numpy as np
from PIL import Image
class DatasetFromCSV(Dataset):
def __init__(self,path,h,w,train=None,transforms=None):
self.data = pd.read_csv(path)
self.labels = np.array(self.data.iloc[:,0])
self.height = h
self.width = w
self.train = train
self.transforms = transforms
def __getitem__(self,idx):
if self.train is not False:
single_image_label = self.labels[idx]
img_as_np = np.asarray(self.data.iloc[idx][1:]).reshape(28,28).astype(float)
img_as_img = Image.fromarray(img_as_np)
img_as_