@[toc]【深度学习 】训练过程中loss出现nan
训练过程中loss出现nan
在深度学习中,loss 出现 NaN 通常是由数值不稳定或计算错误引起的。
1. 学习率过高
原因: 学习率
过大可能导致权重更新幅度过大
,引发数值不稳定。
解决方法: 降低学习率
,或使用学习率调度器逐步调整。
2. 数据问题
原因: 输入数据
包含 NaN 或 inf,或数据范围过大。
解决方法: 检查数据预处理
,确保数据标准化或归一化,并移除异常值。
3. 梯度爆炸
原因: 梯度值过大
,导致权重更新后出现 NaN。
解决方法: 使用梯度裁剪
(gradient clipping)限制梯度范围。
4. 损失函数问题
原因: 某些损失函数
(如对数损失)在输入接近零时可能产生 NaN。
解决方法: 检查损失函数输入
,避免极端值,或添加微小常数(如 1e-8)防止除零。
5. 权重初始化不当
原因: 权重初始化
不合适可能导致数值不稳定。
解决方法: 使用合适的初始化方法
(如 Xavier 或 He 初始化)。
6. 数值精度问题
原因: 使用低精度浮点数
(如 float16)可能引发数值不稳定。
解决方法: 尝试使用 float32 或 float64
提高精度。
7. 特定模块问题
原因: 某些模块可能由于输入
或参数问题导致 NaN。
解决方法: 检查这些模块
的输入和参数,确保数值合理。
8. 调试步骤
检查数据: 确保输入数据无异常。
检查损失函数: 确认输入值在合理范围内。
检查梯度: 使用调试工具(如 torch.autograd.gradcheck)检查梯度计算。
逐步调试: 逐层检查网络输出,定位问题模块。
9. 代码示例
import torch
import torch.nn as nn
import torch.optim as optim
# 示例模型
model = nn.Sequential(
nn.Linear(10, 50),
nn.ReLU(),
nn.Linear(50, 1)
)
# 示例数据
inputs = torch.randn(32, 10)
targets = torch.randn(32, 1)
# 损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 训练步骤
outputs = model(inputs)
loss = criterion(outputs, targets)
# 检查 loss 是否为 NaN
if torch.isnan(loss):
print("Loss is NaN. Checking gradients and inputs...")
# 进一步调试
optimizer.zero_grad()
loss.backward()
optimizer.step()