在大模型训练过程中,意外中断(如服务器重启、显存溢出等)是常见问题。如何优雅地恢复训练进度,避免从头再来?LLamaFactory 基于 HuggingFace Transformers,天然支持断点续训(resume from checkpoint)。本文以 Qwen2VL 为例,详细介绍其断点续训机制、配置方法及源码实现位置。
LLamaFactory 的训练流程底层调用了 HuggingFace 的 Trainer
/Seq2SeqTrainer
,其 train()
方法原生支持 resume_from_checkpoint 参数。只要在训练时定期保存 checkpoint,意外中断后即可从最近的 checkpoint 恢复训练,继续未完成的 epoch 和 step。
以 examples/train_full/qwen2_5vl_full_sft.yaml
为例:
yaml展开代码output_dir: saves/qwen2_5vl-7b/full/sft # checkpoint 保存目录
save_steps: 500 # 每500步保存一次
resume_from_checkpoint: null # 断点续训参数,默认null
只需配置好 output_dir
和 save_steps
,训练过程中会自动生成如 checkpoint-500
、checkpoint-1000
等子目录。
假设训练中断,最新 checkpoint 路径为 saves/qwen2_5vl-7b/full/sft/checkpoint-1500
,只需修改配置文件:
yaml展开代码resume_from_checkpoint: saves/qwen2_5vl-7b/full/sft/checkpoint-1500
或者直接写
yaml展开代码resume_from_checkpoint: true
(此时系统会自动查找 output_dir 下最新的 checkpoint)
然后用和之前一样的命令重新启动训练即可,系统会自动加载模型、优化器、调度器等全部状态,继续未完成的训练。
examples/train_full/qwen2_5vl_full_sft.yaml
src/llamafactory/hparams/parser.py
src/llamafactory/train/sft/workflow.py
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
src/llamafactory/train/sft/trainer.py
CustomSeq2SeqTrainer
继承自 HuggingFace 的 Seq2SeqTrainer
,直接复用其断点续训机制本文作者:Dong
本文链接:
版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!