Bootstrap

大语言模型---通过数值梯度的方式计算损失值L对模型权重矩阵W的梯度;数值梯度的公式;数值梯度计算过程

概要

前文已经简单介绍梯度,本文主要介绍大语言模型中使用数值梯度的方法实现 损失值 L L L 对模型权重矩阵的梯度计算,而不是传统的链式法则进行梯度计算。如果想要理解整体计算方式,先明白损失值 L L L的计算方式,通过公式了解其和权重矩阵 W V W_V WV的关系。然后再理解损失值 L L L对权重矩阵 W V W_V WV的梯度计算。

1. 数值梯度的公式

数值梯度通过有限差分法近似计算梯度,对权重矩阵 W V W_V WV 中每个元素的梯度 ∂ L ∂ W V i j \frac{\partial L}{\partial W_{V_{ij}}} WVijL
∇ L W V i j = L p l u s − L c u r r e n t h \nabla L_{W_{V_{ij}}} = \frac{L_{plus}-L_{current}}{h} LWVij=hLplusLcurrent

其中,每个参数的含义在下文中有讲解。

2. 数值梯度计算过程

(1) 初始化

  • 给定权重矩阵 W V ∈ F m × n W_V \in \mathbb{F}^{m \times n} WVFm×n,与 W V W_V WV大小相同的梯度矩阵 ∇ L W V = zeros ( m , n ) \nabla L_{W_V} = \text{zeros}(m, n) LWV=zeros(m,n)
  • 确定增量 h h h 的值(如 h = 1 0 − 5 h=10^{−5} h=105)。

(2) 遍历权重矩阵的每个元素
对于 W V W_V WV中的每个元素 W V i j W_{V_{ij}} WVij

  1. 创建一个单位矩阵 E i j E_{ij} Eij,大小与 W V W_V WV相同,且 E i j = 1 E_{ij}=1 Eij=1
  2. 计算损失值:
  • L p l u s = L ( W v + h ∗ E i j ) L_{plus}=L(W_v+h*E_{ij}) Lplus=L(Wv+hEij)
    • W V W_V WV的第 ( i , j ) (i,j) (i,j) 元素增加一个微小值 h h h,得到新的权重矩阵,然后计算损失值 L p l u s L_{plus} Lplus.
  • L c u r r e n t = L ( W v ) L_{current}=L(W_v) Lcurrent=L(Wv):
    • 使用当前的权重矩阵 W V W_V WV计算损失值 L c u r r e n t L_{current} Lcurrent

(3) 梯度估算
通过有限差分公式,计算第 ( i , j ) (i,j) (i,j)元素的梯度:
∇ L W V i j = L p l u s − L c u r r e n t h \nabla L_{W_{V_{ij}}} = \frac{L_{plus}-L_{current}}{h} LWVij=hLplusLcurrent
这个公式的含义是:通过观察 W V i j W_{V_{ij}} WVij 增加 h h h 后损失函数的变化,我们可以估算出损失函数对该参数的敏感程度(梯度)。

3. 数值梯度的特点

优点:

  • 简单直观:无需解析推导梯度公式,直接利用损失函数计算。
  • 适合验证解析梯度:可以作为解析梯度的参考标准,用于检测实现是否正确。

缺点:

  1. 计算效率低
  • 对于权重矩阵 W V ∈ F m × n W_V \in \mathbb{F}^{m \times n} WVFm×n,需要计算 m × n m×n m×n 次损失。
  • 如果网络规模较大,数值梯度的计算会非常耗时。
  1. 数值误差:
  • 梯度近似的精度取决于 h h h 的选择。
  • h h h 太大会导致误差较大, h h h 太小可能引入浮点数精度问题。
;