Bootstrap

mmdetection自定义损失

MMDetection 为用户提供了不同的损失函数。但是默认配置可能不适用于不同的数据集或模型,因此用户可能希望修改特定的损失以适应新的情况。

本教程首先阐述了损失的计算管道,然后给出了如何修改每个步骤的说明。修改可分为调整和加权。

损失的计算管道
给定输入预测和目标以及权重,损失函数将输入张量映射到最终损失标量。映射可以分为四个步骤:

  1. 设置采样方法为对正负样本进行采样。

  2. 通过损失核函数获得逐元素或逐样本的损失。

  3. 使用权重张量元素加权损失。

  4. 将损失张量减少到一个标量。

  5. 用标量加权损失。

    设置采样方法(步骤 1)
    对于一些损失函数,需要采样策略来避免正负样本之间的不平衡。

比如CrossEntropyLoss在RPN head中使用时,我们需要RandomSampler在train_cfg

train_cfg=dict(
    rpn=dict(
        sampler=dict(
            type='RandomSampler',
            num=256,
            pos_fraction=0.5,
            neg_pos_ub=-1,
            add_gt_as_proposals=False))

对于其他一些具有正负样本平衡机制的损失,例如 Focal Loss、GHMC 和 QualityFocalLoss,不再需要采样器。

调整损失
调整损失与第 2、4、5 步更相关,大多数修改都可以在 config.xml 文件中指定。这里我们以Focal Loss (FL)为例。下面的代码狙击分别是FL的构造方法和配置,实际上是一一对应的。

@LOSSES.register_module()
class FocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.25,
                 reduction='mean',
                 loss_weight=1.0):
loss_cls=dict(
    type='FocalLoss',
    use_sigmoid=True,
    gamma=2.0,
    alpha=0.25,
    loss_weight=1.0)

调整超参数(第 2 步)
gamma和beta是 Focal Loss 中的两个超参数。假设我们要将 的值更改gamma为 1.5 和alpha0.5,那么我们可以在配置中指定它们,如下所示:

loss_cls=dict(
    type='FocalLoss',
    use_sigmoid=True,
    gamma=1.5,
    alpha=0.5,
    loss_weight=1.0)

调整还原方式(第 3 步)
默认的减少方式是meanFL。假设我们想将减少量从mean改为sum,我们可以在配置中指定如下:

loss_cls=dict(
    type='FocalLoss',
    use_sigmoid=True,
    gamma=2.0,
    alpha=0.25,
    loss_weight=1.0,
    reduction='sum')

调整减重(第 5 步)
这里的损失权重是一个标量,用于控制多任务学习中不同损失的权重,例如分类损失和回归损失。假设我们想把classification loss的loss weight改为0.5,我们可以在config中指定如下:

loss_cls=dict(
    type='FocalLoss',
    use_sigmoid=True,
    gamma=2.0,
    alpha=0.25,
    loss_weight=0.5)

减重(第 3 步)
加权损失意味着我们明智地重新加权损失元素。更具体地说,我们将损失张量与具有相同形状的权重张量相乘。因此,损失的不同条目可以进行不同的缩放,也就是所谓的逐元素缩放。损失权重因不同模型而异,并且与上下文高度相关,但总体而言有两种损失权重,label_weights分类损失和bbox_weightsbbox 回归损失。你可以get_target在对应头部的方法中找到它们。这里我们以ATSSHead为例,它继承了AnchorHead,但是覆盖了它的get_targets方法,产生不同的label_weights和bbox_weights。

class ATSSHead(AnchorHead):

    ...

    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True):

悦读

道可道,非常道;名可名,非常名。 无名,天地之始,有名,万物之母。 故常无欲,以观其妙,常有欲,以观其徼。 此两者,同出而异名,同谓之玄,玄之又玄,众妙之门。

;