PyTorch学习笔记:data.BatchSampler——修改batch的封装策略
标准函数
torch.utils.data.BatchSampler(sampler, batch_size, drop_last)
功能:包装输入的采样器,从而产生小批量数据(mini-batch)
输入:
sampler
:基础的采样器,可以是任何迭代对象,数据类型为sampler 或者iterablebatch_size
:batch大小,数据类型为intdrop_last
:设为True
时,如果迭代到最后剩余的样本数不足以构成一个batch,则会丢弃最后的数据,否则不丢弃,数据类型为bool
代码案例
一般用法
from torch.utils.data.sampler import BatchSampler
batch = BatchSampler(range(10), batch_size=3, drop_last=False)
for i in batch:
print(i)
输出
[0, 1, 2]
[3, 4, 5]
[6, 7, 8]
# 如果drop_last设为True,则不会输出最后一个9
[9]
自定义BatchSampler
如果在训练过程中对封装的batch有额外的需求时(如Faster RCNN,在采样时图像高宽比例位于同一区间的需要被封装到一个batch里),可以通过定义一个新类实现,该类需要继承BatchSampler,主要通过修改迭代方法__iter__(self)来实现。
代码案例
假设现有十个数据,并且每个数据都对应一个属性data
,现在要求属性是否可以整除2来封装batch,即对应属性可以整除2的数据被分到一个batch里,对应属性不能整除2的数据被分到另一个batch里,首先自定义类:
from torch.utils.data.sampler import BatchSampler
class GroupedBatchSampler(BatchSampler):
def __init__(self, sampler, batch_size, data):
# 之前定义的采样器
self.sampler = sampler
# batch size
self.batch_size = batch_size
# data可以表示数据对应的属性
self.data = data
def __iter__(self):
# 用于储存可以整除2的数据索引
group_div = []
# 用于储存不能整除2的数据索引
group = []
# 按采样器,遍历数据,相当于做一个采样
for idx in self.sampler:
# 得到被采样的数据属性
data = self.data[idx]
# 如果当前索引对应数据的属性可以整除2,则在group_div中添加该索引
if data % 2 == 0:
group_div.append(idx)
# 否则在group中添加该索引
else:
group.append(idx)
# 如果遍历结果满足一个batch了
# 则利用yield封装成一个迭代对象
# 并且返回该对象,执行一系列操作
if len(group_div) == self.batch_size:
yield group_div
# 执行完之后再回到该for循环,此时初始化索引集合group_div
group_div = []
# 这里同理
elif len(group) == self.batch_size:
yield group
group = []
注意:
- 在遍历采样器
sampler
的过程中,如果所存储的数据索引量已经满足一个batch,则需要利用yield
方法将其封装成一个迭代对象 - 一般该方法会与
for
循环相结合,通过for
循环指令的驱动来执行采样、封装batch的操作。因此当该类封装完一个迭代对象时,程序会暂时停止采样,即停止对self.sampler
的遍历,执行主程序中对迭代对象的操作,执行完之后再回来继续采样遍历。比如:在遍历数据集的时候,当采样得到的数据满足一个batch后,程序会暂停采样,之后将该批次数据打包,执行后续的一系列操作,如:训练、测试、亦或者是下面的print(i)
,执行完之后再回来继续采样遍历,直到遍历结束。
遍历采样器操作
`from torch.utils.data.sampler import RandomSampler
# 定义随机采样器
sampler = RandomSampler(range(80))
# 定义数据对应的属性,属性与索引值一致
# 这里属性是随便定义的,具体可以根据任务要求来定义的
data = [i for i in range(80)]
# 将得到批量采样器,按特定的功能进行采样
batch_sampler = GroupedBatchSampler(sampler, 8, data)
# 循环遍历,相当于训练数据集时遍历数据集进行训练
for i in batch_sampler:
print(i)
输出
[70, 6, 22, 36, 12, 40, 42, 64]
[7, 77, 47, 9, 57, 49, 61, 35]
[53, 31, 67, 45, 21, 37, 59, 63]
[76, 4, 18, 72, 44, 58, 32, 14]
[11, 79, 55, 15, 75, 25, 23, 71]
[8, 74, 34, 38, 20, 52, 30, 24]
[29, 19, 39, 5, 17, 41, 69, 1]
[50, 60, 56, 78, 10, 28, 2, 16]
[13, 27, 43, 3, 73, 33, 65, 51]
[0, 48, 68, 66, 26, 54, 46, 62]
官方文档
data.BatchSampler:https://pytorch.org/docs/stable/data.html#torch.utils.data.BatchSampler