在深度学习模型训练过程中,遇到loss变成NaN(Not a Number)是一个常见的问题。这种情况通常表明训练过程中出现了数值不稳定性,需要及时处理以避免模型训练失败。以下是这种现象的原因分析和解决方法。
过高的学习率可能导致优化过程不稳定,使梯度爆炸,最终导致loss变为NaN。
当梯度变得非常大时,权重更新会变得过大,导致模型参数迅速发散到无穷大或NaN。
输入数据中的异常值、缺失值或未标准化的数据可能导致计算过程中出现数值不稳定。
某些激活函数(如指数函数)在输入值过大时可能导致数值溢出。
计算过程中的除零操作会导致NaN。
某些网络结构可能天生不稳定,特别是在深层网络中。
python# 原来的学习率设置
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# 降低学习率
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 或更低
python# PyTorch中实现梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
python# 数据标准化
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
python# 替换不稳定的激活函数
# 例如,用ReLU替代Sigmoid在某些情况下更稳定
model = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(), # 使用ReLU而非Sigmoid
nn.Linear(hidden_size, output_size)
)
python# 在网络层之间添加BatchNorm
model = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.BatchNorm1d(hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, output_size)
)
python# 使用更合适的权重初始化方法
def init_weights(m):
if type(m) == nn.Linear:
torch.nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)
model.apply(init_weights)
python# 检查输入数据中的NaN
import numpy as np
print(np.isnan(X_train).any())
# 替换NaN值
X_train = np.nan_to_num(X_train)
python# PyTorch中使用混合精度训练
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for epoch in range(epochs):
for batch in dataloader:
optimizer.zero_grad()
# 使用autocast进行前向传播
with autocast():
outputs = model(inputs)
loss = criterion(outputs, targets)
# 使用scaler进行反向传播和优化
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
python# 在训练循环中添加检查
for epoch in range(epochs):
for batch in dataloader:
# 前向传播
outputs = model(inputs)
loss = criterion(outputs, targets)
# 检查loss是否为NaN
if torch.isnan(loss):
print("NaN loss detected! Skipping batch...")
continue
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
本文作者:Dong
本文链接:
版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!