Bootstrap

PyTorch学习笔记:data.BatchSampler——修改batch的封装策略

PyTorch学习笔记:data.BatchSampler——修改batch的封装策略

标准函数

torch.utils.data.BatchSampler(sampler, batch_size, drop_last)

功能:包装输入的采样器,从而产生小批量数据(mini-batch)

输入:

  • sampler:基础的采样器,可以是任何迭代对象,数据类型为sampler 或者iterable
  • batch_size:batch大小,数据类型为int
  • drop_last:设为True时,如果迭代到最后剩余的样本数不足以构成一个batch,则会丢弃最后的数据,否则不丢弃,数据类型为bool

代码案例

一般用法

from torch.utils.data.sampler import BatchSampler

batch = BatchSampler(range(10), batch_size=3, drop_last=False)
for i in batch:
    print(i)

输出

[0, 1, 2]
[3, 4, 5]
[6, 7, 8]
# 如果drop_last设为True,则不会输出最后一个9
[9]

自定义BatchSampler

  如果在训练过程中对封装的batch有额外的需求时(如Faster RCNN,在采样时图像高宽比例位于同一区间的需要被封装到一个batch里),可以通过定义一个新类实现,该类需要继承BatchSampler,主要通过修改迭代方法__iter__(self)来实现。

代码案例

  假设现有十个数据,并且每个数据都对应一个属性data,现在要求属性是否可以整除2来封装batch,即对应属性可以整除2的数据被分到一个batch里,对应属性不能整除2的数据被分到另一个batch里,首先自定义类:

from torch.utils.data.sampler import BatchSampler

class GroupedBatchSampler(BatchSampler):
    def __init__(self, sampler, batch_size, data):
        # 之前定义的采样器
        self.sampler = sampler
        # batch size
        self.batch_size = batch_size
        # data可以表示数据对应的属性
        self.data = data
    def __iter__(self):
        # 用于储存可以整除2的数据索引
        group_div = []
        # 用于储存不能整除2的数据索引
        group = []
        # 按采样器,遍历数据,相当于做一个采样
        for idx in self.sampler:
            # 得到被采样的数据属性
            data = self.data[idx]
            # 如果当前索引对应数据的属性可以整除2,则在group_div中添加该索引
            if data % 2 == 0:
                group_div.append(idx)
            # 否则在group中添加该索引
            else:
                group.append(idx)
            # 如果遍历结果满足一个batch了
            # 则利用yield封装成一个迭代对象
            # 并且返回该对象,执行一系列操作
            if len(group_div) == self.batch_size:
                yield group_div
                # 执行完之后再回到该for循环,此时初始化索引集合group_div
                group_div = []
            # 这里同理
            elif len(group) == self.batch_size:
                yield group
                group = []

注意:

  • 在遍历采样器sampler的过程中,如果所存储的数据索引量已经满足一个batch,则需要利用yield方法将其封装成一个迭代对象
  • 一般该方法会与for循环相结合,通过for循环指令的驱动来执行采样、封装batch的操作。因此当该类封装完一个迭代对象时,程序会暂时停止采样,即停止对self.sampler的遍历,执行主程序中对迭代对象的操作,执行完之后再回来继续采样遍历。比如:在遍历数据集的时候,当采样得到的数据满足一个batch后,程序会暂停采样,之后将该批次数据打包,执行后续的一系列操作,如:训练、测试、亦或者是下面的print(i),执行完之后再回来继续采样遍历,直到遍历结束。

遍历采样器操作

`from torch.utils.data.sampler import RandomSampler

# 定义随机采样器
sampler = RandomSampler(range(80))
# 定义数据对应的属性,属性与索引值一致
# 这里属性是随便定义的,具体可以根据任务要求来定义的
data = [i for i in range(80)]
# 将得到批量采样器,按特定的功能进行采样
batch_sampler = GroupedBatchSampler(sampler, 8, data)

# 循环遍历,相当于训练数据集时遍历数据集进行训练
for i in batch_sampler:
    print(i)

输出

[70, 6, 22, 36, 12, 40, 42, 64]
[7, 77, 47, 9, 57, 49, 61, 35]
[53, 31, 67, 45, 21, 37, 59, 63]
[76, 4, 18, 72, 44, 58, 32, 14]
[11, 79, 55, 15, 75, 25, 23, 71]
[8, 74, 34, 38, 20, 52, 30, 24]
[29, 19, 39, 5, 17, 41, 69, 1]
[50, 60, 56, 78, 10, 28, 2, 16]
[13, 27, 43, 3, 73, 33, 65, 51]
[0, 48, 68, 66, 26, 54, 46, 62]

官方文档

data.BatchSampler:https://pytorch.org/docs/stable/data.html#torch.utils.data.BatchSampler

;