2025-03-03
深度学习
00

目录

基础环境
下载预训练模型
GRPO训练

VLM-R1 是deepseek R1 的GRPO训练方式在VLM中的实现。项目地址:https://github.com/om-ai-lab/VLM-R1

基础环境

构建docker环境:

bash
# docker pull pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel docker pull pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel # docker run -it --gpus all pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel bash docker run -it --gpus all pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel bash apt update apt install vim git -y git clone https://github.com/om-ai-lab/VLM-R1.git cd VLM-R1/ bash setup.sh pip install transformers==4.49.0

此docker环境已经被我上传为 kevinchina/deeplearning:2.5.1-cuda12.4-cudnn9-devel-vlmr1

使用此镜像:

docker run -it --gpus '"device=0,1,2,3,4,5,6,7"' --shm-size=64g -v /data/xiedong:/data/xiedong --net host kevinchina/deeplearning:2.5.1-cuda12.4-cudnn9-devel-vlmr1 bash

下载预训练模型

pip install modelscope # 下载到这个目录 /data/xiedong# modelscope download --model Qwen/Qwen2.5-VL-3B-Instruct --local_dir Qwen2.5-VL-3B-Instruct

GRPO训练

准备数据:

  1. wget https://huggingface.co/datasets/omlab/VLM-R1/resolve/main/train2014.zip

  2. wget https://huggingface.co/datasets/omlab/VLM-R1/resolve/main/rec_jsons_processed.zip

  3. 在src/open-r1-multimodal/data_config/rec.yaml中写入数据路径

vim src/open-r1-multimodal/data_config/rec.yaml

datasets: - json_path: /data/xiedong/rec_jsons_processed/refcoco_train.json - json_path: /data/xiedong/rec_jsons_processed/refcocop_train.json - json_path: /data/xiedong/rec_jsons_processed/refcocog_train.json
  1. 运行训练bash src/open-r1-multimodal/run_scripts/run_grpo_rec.sh

这文件里面长这样: vim src/open-r1-multimodal/run_scripts/run_grpo_rec.sh

cd src/open-r1-multimodal export DEBUG_MODE="true" # export CUDA_VISIBLE_DEVICES=4,5,6,7 RUN_NAME="Qwen2.5-VL-3B-GRPO-REC" export LOG_PATH="./debug_log_$RUN_NAME.txt" torchrun --nproc_per_node="8" \ --nnodes="1" \ --node_rank="0" \ --master_addr="127.0.0.1" \ --master_port="12346" \ src/open_r1/grpo_rec.py \ --deepspeed local_scripts/zero3.json \ --output_dir output/$RUN_NAME \ --model_name_or_path /data/xiedong/Qwen2.5-VL-3B-Instruct \ --dataset_name data_config/rec.yaml \ --image_root /data/xiedong \ --max_prompt_length 1024 \ --num_generations 8 \ --per_device_train_batch_size 1 \ --gradient_accumulation_steps 2 \ --logging_steps 1 \ --bf16 \ --torch_dtype bfloat16 \ --data_seed 42 \ --report_to wandb \ --gradient_checkpointing false \ --attn_implementation flash_attention_2 \ --num_train_epochs 2 \ --run_name $RUN_NAME \ --save_steps 100 \ --save_only_model true

训练时候:

image.png

  1. torchrun: 一个用于分布式训练的 PyTorch 脚本运行工具。

  2. --nproc_per_node="8": 指定每个节点上的进程数为 8。

  3. --nnodes="1": 指定训练使用的节点数为 1。

  4. --node_rank="0": 指定当前节点的排名,这里是0号节点。

  5. --master_addr="127.0.0.1": 指定 master 节点的地址,用于分布式训练的通信。

  6. --master_port="12346": 指定 master 节点的端口号。

  7. src/open_r1/grpo_rec.py: 运行的 python 脚本。

  8. --deepspeed local_scripts/zero3.json: 使用 DeepSpeed 优化器并指定配置文件 local_scripts/zero3.json

  9. --output_dir output/$RUN_NAME: 指定输出目录,训练结果将保存到 output/$RUN_NAME

  10. --model_name_or_path /data/xiedong/Qwen2.5-VL-3B-Instruct: 指定预训练模型的路径。

  11. --dataset_name data_config/rec.yaml: 指定数据集配置文件的路径。

  12. --image_root /data/xiedong: 指定图像数据根目录。

  13. --max_prompt_length 1024: 指定生成模型输入的最大 prompt 长度。

  14. --num_generations 8: 指定每个输入生成的输出序列数。

  15. --per_device_train_batch_size 1: 指定每个设备上的训练批次大小为 1。

  16. --gradient_accumulation_steps 2: 指定梯度积累的步数,即在执行反向传播之前累积的梯度步数。

  17. --logging_steps 1: 每隔多少步进行一次日志记录。

  18. --bf16--torch_dtype bfloat16: 指定使用 bfloat16 数据类型进行训练,以更好地利用 GPU 的计算能力。

  19. --data_seed 42: 指定用于数据乱序的随机种子。

  20. --report_to wandb: 使用 Weights & Biases 工具进行实验跟踪和报告。

  21. --gradient_checkpointing false: 指定是否使用梯度检查点技术以减少显存使用。这里设置为 false。

  22. --attn_implementation flash_attention_2: 指定注意力机制的实现方法。

  23. --num_train_epochs 2: 指定训练的总轮数为 2。

  24. --run_name $RUN_NAME: 指定运行名称,用于标识当前实验。

  25. --save_steps 100: 指定每隔 100 步保存一次模型。

  26. --save_only_model true: 指定仅保存模型权重,不保存优化器状态等信息。

多个数据文件和图像文件夹可以使用":"作为分隔符指定:

--data_file_paths /path/to/data1.jsonl:/path/to/data2.jsonl \ --image_folders /path/to/images1/:/path/to/images2/
如果对你有用的话,可以打赏哦
打赏
ali pay
wechat pay

本文作者:Dong

本文链接:

版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!