概要
前文已经简单介绍梯度,本文主要介绍大语言模型中使用数值梯度的方法实现 损失值 L L L 对模型权重矩阵的梯度计算,而不是传统的链式法则进行梯度计算。如果想要理解整体计算方式,先明白损失值 L L L的计算方式,通过公式了解其和权重矩阵 W V W_V WV的关系。然后再理解损失值 L L L对权重矩阵 W V W_V WV的梯度计算。
1. 数值梯度的公式
数值梯度通过有限差分法近似计算梯度,对权重矩阵
W
V
W_V
WV 中每个元素的梯度
∂
L
∂
W
V
i
j
\frac{\partial L}{\partial W_{V_{ij}}}
∂WVij∂L:
∇
L
W
V
i
j
=
L
p
l
u
s
−
L
c
u
r
r
e
n
t
h
\nabla L_{W_{V_{ij}}} = \frac{L_{plus}-L_{current}}{h}
∇LWVij=hLplus−Lcurrent
其中,每个参数的含义在下文中有讲解。
2. 数值梯度计算过程
(1) 初始化
- 给定权重矩阵 W V ∈ F m × n W_V \in \mathbb{F}^{m \times n} WV∈Fm×n,与 W V W_V WV大小相同的梯度矩阵 ∇ L W V = zeros ( m , n ) \nabla L_{W_V} = \text{zeros}(m, n) ∇LWV=zeros(m,n)。
- 确定增量 h h h 的值(如 h = 1 0 − 5 h=10^{−5} h=10−5)。
(2) 遍历权重矩阵的每个元素
对于
W
V
W_V
WV中的每个元素
W
V
i
j
W_{V_{ij}}
WVij:
- 创建一个单位矩阵 E i j E_{ij} Eij,大小与 W V W_V WV相同,且 E i j = 1 E_{ij}=1 Eij=1。
- 计算损失值:
-
L
p
l
u
s
=
L
(
W
v
+
h
∗
E
i
j
)
L_{plus}=L(W_v+h*E_{ij})
Lplus=L(Wv+h∗Eij):
- 在 W V W_V WV的第 ( i , j ) (i,j) (i,j) 元素增加一个微小值 h h h,得到新的权重矩阵,然后计算损失值 L p l u s L_{plus} Lplus.
-
L
c
u
r
r
e
n
t
=
L
(
W
v
)
L_{current}=L(W_v)
Lcurrent=L(Wv):
- 使用当前的权重矩阵
W
V
W_V
WV计算损失值
L
c
u
r
r
e
n
t
L_{current}
Lcurrent。
- 使用当前的权重矩阵
W
V
W_V
WV计算损失值
L
c
u
r
r
e
n
t
L_{current}
Lcurrent。
(3) 梯度估算
通过有限差分公式,计算第
(
i
,
j
)
(i,j)
(i,j)元素的梯度:
∇
L
W
V
i
j
=
L
p
l
u
s
−
L
c
u
r
r
e
n
t
h
\nabla L_{W_{V_{ij}}} = \frac{L_{plus}-L_{current}}{h}
∇LWVij=hLplus−Lcurrent
这个公式的含义是:通过观察
W
V
i
j
W_{V_{ij}}
WVij 增加
h
h
h 后损失函数的变化,我们可以估算出损失函数对该参数的敏感程度(梯度)。
3. 数值梯度的特点
优点:
- 简单直观:无需解析推导梯度公式,直接利用损失函数计算。
- 适合验证解析梯度:可以作为解析梯度的参考标准,用于检测实现是否正确。
缺点:
- 计算效率低:
- 对于权重矩阵 W V ∈ F m × n W_V \in \mathbb{F}^{m \times n} WV∈Fm×n,需要计算 m × n m×n m×n 次损失。
- 如果网络规模较大,数值梯度的计算会非常耗时。
- 数值误差:
- 梯度近似的精度取决于 h h h 的选择。
- h h h 太大会导致误差较大, h h h 太小可能引入浮点数精度问题。