展开代码实际batch_size = per_device_train_batch_size × gradient_accumulation_steps × 设备数量 您的设置: - per_device_train_batch_size = 12 - gradient_accumulation_steps = 2 - 假设8张卡:实际batch_size = 12 × 2 × 8 = 192
python展开代码# gradient_accumulation_steps = 1 (默认)
for batch in dataloader:
loss = model(batch)
loss.backward() # 每次都通信
optimizer.step() # 每次都同步梯度
# gradient_accumulation_steps = 4
for i in range(4):
loss = model(batch[i])
loss.backward() # 只累积,不通信
optimizer.step() # 4次累积后才通信一次
gradient_accumulation_steps
个批次才同步一次本文作者:Dong
本文链接:
版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!