Bootstrap

pytorch学习之路——05

前面加载数据都是直接调用的pytorch封装好的datasets.MNIST或则datasets.FashionMNIST等,学习阶段使用这样形式的数据没有什么影响,但是实际情况是我们要处理自己的数据。本节主要使用torch.utils.data.Dataset读取MNIST数据的csv文件

from torchvision import transforms
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader

import pandas as pd
import numpy as np
from PIL import Image

class DatasetFromCSV(Dataset):
    def __init__(self,path,h,w,train=None,transforms=None):
        self.data = pd.read_csv(path)
        self.labels = np.array(self.data.iloc[:,0])
        self.height = h
        self.width = w
        self.train = train
        self.transforms = transforms

    def __getitem__(self,idx):
        if self.train is not False:
            single_image_label = self.labels[idx]
            img_as_np = np.asarray(self.data.iloc[idx][1:]).reshape(28,28).astype(float)
            img_as_img = Image.fromarray(img_as_np)
            img_as_
;