我们一般使用for循环来训练神经网络,在每次的迭代过程,从DataLoader中取出batchsize的数据,然后前向传播反向传播一次,更新参数一次
在加载batch数据的时候,torch创建一个可迭代的Dataset对象(需要重写__getitem__()和__len__()两个方法),然后与DataLoader一起使用;
DataLoader: 构造一个整数索引的采样器来获取Dataset的数据
Dataset
创建Dataset对象:
需要重写 getitem 方法和 len 方法。
前者通过提供索引返回数据,也就是提供 DataLoader获取数据的方式;后者返回数据集的长度,DataLoader依据 len 确定自身索引采样器的长度。
from torch.utils.data import Dataset
# 输入形式:[{'x':['token_id',..],'y':[label]},..],[(['token_id',..],[label]),..]</