Bootstrap

使用Sum计算Loss和解决梯度累积(Gradient Accumulation)的Bug

使用Sum计算Loss和解决梯度累积的Bug

学习 https://unsloth.ai/blog/gradient:Bugs in LLM Training - Gradient Accumulation Fix 这篇文章的记录。

在深度学习训练过程中,尤其是在大批量(large batch)训练中,如何高效地计算损失函数(Loss)并避免内存溢出一直是一个重要的问题。在许多情况下,为了实现更高效的训练,我们会使用梯度累积(Gradient Accumulation)策略,模拟大批量训练的效果,但同时减少对显存的需求。然而,这种方法可能引发一些数值计算错误,尤其是在损失计算时。因此,理解如何处理梯度累积并确保计算正确性就变得尤为重要。

为什么要使用Sum计算Loss?

在代码中,reduce_loss="sum"表示我们选择将每个小批量(mini-batch)的损失值相加而不是取平均。这种做法有几个目的:

  1. 权重每个样本的影响一致:在使用大批量训练时,如果直接取每个样本的平均损失,可能会使得样本数目更多的批次对模型训练产生不必要的影响。而通过将损失值求和,确保了每个样本都以相同的权重参与模型更新。

  2. 避免梯度累积带来的偏差:如果只使用小批量训练的平均损失进行梯度累积(梯度更新),随着梯度累积步骤(gradient accumulation step)增多,损失的计算会变得不准确。因为简单地将每个小批量的损失相加,可能会导致最终计算的损失过大。

梯度累积中的问题

当我们使用梯度累积时,目的是在显存限制下模拟大批量训练的效果。假设每个小批量的计算后,我们累积其梯度,最后再进行一次反向传播。这种方法可以显著减少显存的使用,但也带来了数值误差,特别是在损失的计算上。

梯度累积是否和完整批次训练数学上等价?

答案是: 不完全等价,尤其是如果我们不正确地处理梯度的累积。在数学上,损失函数通常表示为:

L = 1 n ∑ i = 1 n L i L = \frac{1}{n} \sum_{i=1}^{n} L_i L=n1i=1nLi

其中,( L i L_i Li ) 是每个样本的损失,( n n n ) 是样本总数。我们通常会对每个小批量的损失进行归一化处理,使得每个样本对损失的贡献相等。

但是,在梯度累积时,我们的处理方法是:

L accumulated = ∑ k = 1 G L k L_{\text{accumulated}} = \sum_{k=1}^{G} L_k Laccumulated=k=1GLk

其中,( G G G ) 是梯度累积的步数, ( L k L_k Lk ) 是每个小批量的损失值。这导致累积损失和原本应计算的损失不同,因为没有进行正确的归一化操作。

为什么梯度累积会导致问题?

由于梯度累积在计算损失时没有正确处理每个小批量的权重,最终的损失值可能比实际应有的损失大,特别是当批次大小不一致时。此时,我们需要对每个小批量的损失进行缩放,以使最终的结果与大批量训练的损失相符。

解决方案:修正梯度累积中的损失计算

为了解决梯度累积中损失计算的偏差,我们可以在每次计算梯度时对每个小批量的损失进行缩放。具体来说,假设我们在训练过程中设置了 ( G G G ) 步梯度累积,那么每次计算损失时,我们需要将每个小批量的损失除以 ( G G G ),以确保最终的梯度更新符合预期。
数学公式如下:

L final = 1 G ∑ k = 1 G L k m k L_{\text{final}} = \frac{1}{G} \sum_{k=1}^{G} \frac{L_k}{m_k} Lfinal=G1k=1GmkLk

其中,( m k m_k mk ) 是第 ( k k k ) 个小批量的有效样本数(即去除填充后的token数),( G G G ) 是梯度累积的步数。这样做的目的是在每次梯度累积时确保损失值的加权平均,而不是简单地将其相加。

小结

  1. Loss的求和:在使用 reduce_loss="sum" 时,我们通过求和的方式来计算损失,而不是简单的平均。这可以确保每个样本对损失计算的贡献一致,避免了梯度累积中的偏差。

  2. 梯度累积的误差:梯度累积在计算损失时容易出现数值误差,尤其是在不同小批量的损失和有效样本数不一致的情况下。为了解决这一问题,我们需要对每个小批量的损失进行缩放处理,使得最终的损失与大批量训练的结果一致。

通过正确处理梯度累积的损失计算,我们可以在不增加内存消耗的情况下模拟大批量训练的效果,从而提高训练效率。

附录:代码解析

下面是open instruct框架提供的sum计算 loss的代码,它实现了一个损失计算和梯度累积的处理流程。

if args.reduce_loss == "mean":
    loss = outputs.loss
else:
    # reduce loss is sum
    # this ensures that we weight all tokens in the dataset equally,
    # rather than weighting each overall example equally when
    # using high amounts of gradient accumulation.
    # this can result in > 5 point improvements in AlpacaEval
    # see https://github.com/huggingface/transformers/issues/24725 for
    # more discussion and details.
    logits = outputs.logits
    labels = batch["labels"]
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = torch.nn.CrossEntropyLoss(reduction="sum")
    shift_logits = shift_logits.view(-1, embedding_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels)
    if args.load_balancing_loss:
        aux_loss = args.load_balancing_weight * outputs.aux_loss
        loss += aux_loss
# We keep track of the loss at each logged step
total_loss += loss.detach().float()
accelerator.backward(loss)
1. Loss的选择reduce_loss参数)

代码的第一部分是根据args.reduce_loss参数来决定使用平均(mean)还是求和(sum)的损失计算方式。如果选择的是mean,则直接使用outputs.loss(通常是模型自带的损失)。如果选择的是sum,则使用后续的处理方式。

2. 处理sum模式的损失计算

如果reduce_losssum,代码会进入求和模式,具体步骤如下:

  • 获取logits和labels:从模型的输出outputs中获取logits(预测值)和batch["labels"](标签)。

  • Shift操作:为了计算交叉熵损失,模型的输出logits会做一个移位操作,使得每个token的预测目标对应下一个token。这是因为语言模型任务中的预测是基于上一个token的。

    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    

    这里logits[..., :-1, :]表示除去最后一个token的logits,而labels[..., 1:]表示去掉第一个token的标签。这样处理的目的是确保每个token的预测值与下一个token的真实标签进行比较。

  • Flatten操作:对处理后的logitslabels进行flatten,展平为一维向量,以便交叉熵损失函数计算。

    shift_logits = shift_logits.view(-1, embedding_size)
    shift_labels = shift_labels.view(-1)
    

    这一步是为了将shift_logitsshift_labels转换为适合交叉熵损失函数的格式,即shift_logits变为 [batch_size * sequence_length, embedding_size] 维度,shift_labels变为 [batch_size * sequence_length] 维度。

  • 交叉熵损失函数:使用torch.nn.CrossEntropyLoss计算损失。这里的reduction="sum"表示对所有token的损失求和,而不是取平均。这样可以确保每个token的贡献相同。

    loss_fct = torch.nn.CrossEntropyLoss(reduction="sum")
    loss = loss_fct(shift_logits, shift_labels)
    

    这段代码计算了每个token的损失,并将所有token的损失求和。对于大批量训练,这样做是为了避免梯度累积时损失计算的不准确。

  • 负载平衡损失(Load Balancing Loss):如果args.load_balancing_loss为True,表示需要加上额外的负载平衡损失。负载平衡损失用于调整模型在不同任务之间的负载,以确保训练过程的稳定性。

    if args.load_balancing_loss:
        aux_loss = args.load_balancing_weight * outputs.aux_loss
        loss += aux_loss
    

    这部分代码将outputs.aux_loss(辅助损失)加权后加到主损失上,以增强训练的稳定性。

3. 反向传播

最后,代码通过accelerator.backward(loss)执行反向传播,计算梯度。total_loss用于累积损失,以便记录每个步骤的损失值。

total_loss += loss.detach().float()
accelerator.backward(loss)

loss.detach().float()表示将损失从计算图中分离,并将其转换为浮动类型,以便进行记录和后续的反向传播。

小结
  • Loss的计算方式:如果选择sum,会对每个token的损失进行求和,而不是取平均,以确保每个token对梯度更新的贡献相等。这对于大批量训练非常重要,能够避免在梯度累积过程中出现误差。

  • 梯度累积:通过逐步计算小批量的梯度并累积,能够在不增加显存使用的情况下模拟大批量训练。需要确保在每次累积时正确计算损失,并进行适当的缩放。

这段代码的设计和实现考虑到了梯度累积和损失计算中的数值稳定性问题,确保在大批量训练时能够正确更新模型参数。

梯度累积公式解释

在梯度累积的公式中,( L k L_k Lk ) 表示第 ( k k k ) 步梯度累积中的损失值。具体来说,假设我们将训练数据分成 ( G G G ) 个小批量(mini-batches)进行梯度累积,那么在第 ( k k k ) 步(即第 ( k k k ) 个小批量)中,我们计算的损失为 ( L k L_k Lk )。这些损失是通过模型对第 ( k k k ) 小批量的数据进行前向传播后得到的。

梯度累积的背景

梯度累积的目的是在多个小批量上计算梯度,并在一定数量的步数之后再进行一次反向传播。这样,我们可以模拟更大的批量大小,从而提高训练稳定性和效率,而不需要一次性处理过大的批量数据。

解释公式

公式中的 ( L k L_k Lk ) 是每个小批量的损失值,具体来说,它是模型在第 ( k k k ) 步计算得到的损失。公式可以表示为:

L accumulated = ∑ k = 1 G L k L_{\text{accumulated}} = \sum_{k=1}^{G} L_k Laccumulated=k=1GLk

这意味着我们将每个小批量(即每个 ( k k k ))的损失 ( L k L_k Lk ) 累积起来,直到进行 ( G G G ) 步梯度累积。此时的累积损失 ( L accumulated L_{\text{accumulated}} Laccumulated ) 还没有进行任何梯度更新操作,它只是所有小批量损失的总和。

举个例子

假设我们有一个批量大小为 8 的训练数据集,并且我们希望模拟一个批量大小为 32 的训练过程,但由于显存限制,我们决定使用梯度累积来分 4 个小批量来计算梯度:

  1. 第一步:计算第一小批量(batch 1)的损失 ( L 1 L_1 L1 )。
  2. 第二步:计算第二小批量(batch 2)的损失 ( L 2 L_2 L2 )。
  3. 第三步:计算第三小批量(batch 3)的损失 ( L 3 L_3 L3 )。
  4. 第四步:计算第四小批量(batch 4)的损失 ( L 4 L_4 L4 )。

此时,梯度累积的总损失就可以表示为:

L accumulated = L 1 + L 2 + L 3 + L 4 L_{\text{accumulated}} = L_1 + L_2 + L_3 + L_4 Laccumulated=L1+L2+L3+L4

然后,我们对这个累积损失进行一次反向传播,更新模型参数。

为什么要进行梯度累积?

梯度累积的好处是可以模拟大批量训练而不会导致显存溢出。通过分步计算损失并累积梯度,我们可以在不增加显存需求的情况下使用较大的有效批量大小进行训练。

损失值求和的例子

通过一个简单的例子来说明为什么在使用大批量训练时,损失值求和的方式比损失值平均更能保证每个样本都以相同的权重参与模型更新。

背景说明

在大批量训练中,损失的平均损失的求和对训练结果的影响不同。具体来说,如果我们采用损失平均,每个样本对模型更新的贡献会受到批次大小的影响;而如果我们采用损失求和,每个样本对更新的贡献将保持一致,无论批次大小有多大。

例子解释

假设我们有两个不同的批次,分别包含不同数量的样本:

  • 批次 A:包含 4 个样本。
  • 批次 B:包含 8 个样本。

假设这两个批次的损失分别如下:

  • 批次 A 的损失:每个样本的损失分别为 2, 3, 1, 4,合计为 10。
  • 批次 B 的损失:每个样本的损失分别为 1, 2, 3, 1, 2, 1, 4, 3,合计为 20。
如果使用损失的平均值
  • 批次 A:平均损失为 ( 10 4 = 2.5 \frac{10}{4} = 2.5 410=2.5 )。
  • 批次 B:平均损失为 ( 20 8 = 2.5 \frac{20}{8} = 2.5 820=2.5 )。

如果我们采用损失平均的方式,两个批次的损失平均值相同。因此,无论是批次 A 还是批次 B,它们对模型的训练贡献是相同的。也就是说,每个样本的贡献被"平等化"了。这个方法在样本数差异较大的情况下可能会引入偏差,特别是当批次大小不一致时,每个批次的影响可能不成比例。

如果使用损失的求和
  • 批次 A:总损失为 10。
  • 批次 B:总损失为 20。

在这种情况下,两个批次的损失对模型的影响是按批次大小加权的。具体来说,批次 B 有更多的样本,所以它的损失总和较大,也就意味着在梯度更新时,批次 B 会对模型参数的更新产生更大的影响。如果我们将这两个批次的损失求和,最终总损失会更大,这会导致模型在训练过程中更多地"倾向"于批次 B,尤其是当使用梯度累积时,批次 B 会对模型产生更大的影响。

为什么损失求和更有意义?

通过将损失求和,确保每个样本的损失对模型更新的影响是等量的。而使用平均损失时,由于批次大小不同,可能会出现以下问题:

  • 样本较多的批次(如批次 B)会占据更大的权重,因为平均值会"稀释"了损失的影响,使得较大的批次对整体损失的影响降低。
  • 样本较少的批次(如批次 A)则可能被忽视,因为它对最终的损失计算贡献较小。

梯度累积中的影响

在梯度累积中,采用损失求和可以避免不同批次间因批次大小差异而导致的权重不均的问题。这样,在进行多次梯度累积时,不同批次的损失就可以通过求和确保每个样本都具有相同的影响力。

举个具体的例子:

假设我们有 2 个小批量,分别为 A 和 B,批次大小分别为 4 和 8。每个批次的总损失分别为 10 和 20。如果我们直接进行平均计算,得到的损失为:

L avg = 10 + 20 4 + 8 = 30 12 = 2.5 L_{\text{avg}} = \frac{10 + 20}{4 + 8} = \frac{30}{12} = 2.5 Lavg=4+810+20=1230=2.5

而如果我们进行损失求和,得到的损失为:

L sum = 10 + 20 = 30 L_{\text{sum}} = 10 + 20 = 30 Lsum=10+20=30

在进行梯度累积时,采用求和损失可以避免批次大小的不均衡影响,从而确保每个样本的贡献是等量的。

如果batch size都一样的话,还会有这个问题吗?

如果batch size都一样的话,还会有这个问题吗?现实中大规模训练batch size开的不一样吗?

如果batch size 都一样,那么使用损失的平均或者求和不会产生太大的差异。因为在这种情况下,每个批次对模型更新的贡献是均衡的,无论是计算损失的平均值还是求和,最终对模型参数的影响都是一致的。

如果 batch size 相同:

  • 平均损失:损失的平均值会反映每个样本的损失在训练中的影响,因为每个样本的数量是相同的,计算平均值不会受到影响。
  • 求和损失:由于批次大小相同,求和后的损失也是一个与批次大小成正比的值,而在训练过程中,更新的频率和力度会基于每个批次的大小,进而影响训练的稳定性和收敛性。

所以,如果每个批次大小一样,通常来说,不会出现上述问题。两种方式(平均损失和求和损失)的训练效果是差不多的。


现实中大规模训练的情况

在现实中的大规模训练中,批次大小(batch size)通常是不同的,这取决于多个因素,例如:

  1. 显存限制:不同的硬件设备(如 GPU、TPU)的显存大小不同,支持的 batch size 也不同。在显存较小的设备上,可能无法一次性处理大量数据,因此会选择较小的批次进行梯度累积,而在显存较大的设备上则可以使用较大的 batch size 进行训练。

  2. 分布式训练:在多卡(multi-GPU)或者多机(multi-node)训练的情况下,为了充分利用计算资源,批次大小也会有所不同。通常,跨多个 GPU 或节点分布训练时,为了平衡每个卡上的负载,可能会采用不同的 batch size。

  3. 性能优化:有时为了加速训练,训练人员可能会根据任务的需要进行优化,选择不同的批次大小以提升计算效率。例如,在某些任务中,增大批次大小可以提高并行效率,但也可能导致训练不稳定,这时就需要进行梯度累积。

  4. 模型设计:有些模型架构(比如需要大输入的模型)可能会需要更大的批次以适应内存需求,或者为了更高效地训练。


在大批量训练中为什么使用梯度累积?

即便batch size 不同,如果我们需要模拟一个大批量训练的效果,并且又不能增加显存的使用,我们通常会使用梯度累积来应对这个问题。

  1. 模拟大批量效果:如果显存不足,无法处理大批量数据,我们可以通过梯度累积模拟大批量训练的效果。通过在多个小批量上计算损失并累积梯度,我们可以达到类似大批量训练的效果,避免了显存溢出的问题。

  2. 不同批次大小的影响:如果不同批次的大小不一致,在计算损失时,损失求和的方式就显得更为重要了,因为批次大小较大的批次对总损失的贡献较大,而平均损失的计算可能会削弱这种差异。而损失求和方式则会保证较大批次对总损失的影响相对较大,从而保持了每个样本的影响力一致。


总结

  • 如果batch size 相同,那么损失求和和损失平均不会有太大差异。
  • 现实中大规模训练确实可能会有不同的 batch size,尤其是在显存和计算资源有限的情况下。为了应对这个问题,梯度累积是一个常见的解决方案,用于模拟大批量训练的效果。
  • 在使用不同大小的 batch size 时,损失求和能够保证每个样本在训练中的权重一致,而损失平均可能会导致较大批次的样本对训练过程的影响被稀释。

为什么要除以 m k m_k mk,直接用 L k L_k Lk不行吗?

这是一个非常重要且值得深入思考的问题。让我们仔细分析为什么在梯度累积时,每个小批量的损失还需要除以 ( m k m_k mk),即有效token数。

1. 有效token数量的影响

首先需要理解的是,损失 ( L k L_k Lk) 通常是基于每个小批量的有效token计算的。每个小批量中的有效token数量(去除填充后的部分)可能是不同的,因此直接用 ( L k L_k Lk) 来累积会导致训练时某些小批量(包含更多有效token的批次)对最终梯度的贡献过大。

例如:

  • 假设有两个批次:批次 1 包含 3 个有效token,批次 2 包含 6 个有效token。
  • 如果我们直接将 ( L 1 L_1 L1) 和 ( L 2 L_2 L2) 累加,批次 2 的损失会占据过多权重,因为它包含更多有效token。这样会导致更新偏向于包含更多token的批次,从而影响梯度的准确性。

2. 损失求和和平均的逻辑

不除以 ( m k m_k mk) 的问题:
  • 如果不除以有效token数 ( m k m_k mk),我们实际上在做损失累积时忽视了每个批次的“贡献”大小,特别是当批次中有效token数目差异较大时。这样会使得损失计算不均衡。
  • 损失 ( L k L_k Lk) 并不直接等于每个token的损失之和,它实际上是对有效token的损失的总和。所以,批次大小的不同会导致损失的总和不成比例地反映出每个批次的有效学习。
除以 ( m k m_k mk) 的好处:
  • 通过除以每个小批量的有效token数 ( m k m_k mk),我们实际上是在对每个小批量的损失进行“加权平均”。
  • 这样做的目的就是保证每个token对最终梯度更新的贡献是公平的,不会因为某个小批量有更多有效token而对最终梯度产生不合理的影响。

3. 为什么需要这样做?

假设你使用了梯度累积,目标是模拟大批量训练的效果。如果每个小批量中有效token的数量不同,直接将损失加起来会导致更大的小批量(含有更多有效token的批次)对最终梯度更新产生更大的影响。为了避免这种偏差,我们需要对每个小批量的损失进行标准化处理,即通过除以 ( m k m_k mk) 来校准每个小批量的损失。

数学解释

  • 公式:

    L final = 1 G ∑ k = 1 G L k m k L_{\text{final}} = \frac{1}{G} \sum_{k=1}^{G} \frac{L_k}{m_k} Lfinal=G1k=1GmkLk

    这里的 ( L k m k \frac{L_k}{m_k} mkLk ) 表示对每个小批量的损失进行标准化,确保每个小批量在梯度累积中的贡献是按有效token数量来加权的,而不是简单地将所有小批量的损失直接加起来。

举个简单的例子

假设我们有两个批次:

  • 批次 1:3 个有效token,损失 ( L 1 = 1.5 L_1 = 1.5 L1=1.5)
  • 批次 2:6 个有效token,损失 ( L 2 = 3.0 L_2 = 3.0 L2=3.0)

如果我们直接累加损失:
L accumulated = L 1 + L 2 = 1.5 + 3.0 = 4.5 L_{\text{accumulated}} = L_1 + L_2 = 1.5 + 3.0 = 4.5 Laccumulated=L1+L2=1.5+3.0=4.5

但是,如果我们将损失除以每个批次的有效token数:
L final = 1 2 ( L 1 m 1 + L 2 m 2 ) = 1 2 ( 1.5 3 + 3.0 6 ) = 1 2 ( 0.5 + 0.5 ) = 0.5 L_{\text{final}} = \frac{1}{2} \left( \frac{L_1}{m_1} + \frac{L_2}{m_2} \right) = \frac{1}{2} \left( \frac{1.5}{3} + \frac{3.0}{6} \right) = \frac{1}{2} \left( 0.5 + 0.5 \right) = 0.5 Lfinal=21(m1L1+m2L2)=21(31.5+63.0)=21(0.5+0.5)=0.5

这样做的目的是确保每个批次(无论它的大小如何)对损失的贡献是公平的。如果没有这种标准化,损失较大的批次会导致最终的梯度更新不准确,偏向于更多token的批次。

总结:

除以 ( m k m_k mk) 的目的是为了对损失进行标准化处理,确保每个小批量的有效token对梯度累积的贡献是均衡的,特别是在批次之间有效token数量差异较大的情况下。这样可以避免某些批次因包含更多有效token而在梯度更新中占据不成比例的权重,从而使训练更加稳定和准确。

后记

2025年1月18日15点45分于上海, 在OpenAI o1大模型辅助下完成。

;