作者 | asher
许多关于 DeepSeek R1 的复现文章,主要聚焦在“rewards的设计、训练指标的变化、benchmark测评”这些内容,但是对于“本地训练”这个开启深度探索的关键前置步骤,却很少有人深挖。
可能有人觉得,照着readme操作就能轻松训练了吧?太天真啦!实际动手就会发现,和自家的环境各种水土不服,大模型不是训不起来就是训的太慢,问题多到让人头大。
为了解决本地训练的适配性问题,今天挑选HuggingFace的开源项目open-r1,为大家带来一场全流程实操演示,从怎么在8卡A100(40G)上跑通基于Qwen-14B的DeepSeek R1复现,到分享超实用的环境镜像,还有满满踩坑经验,再到手把手教你改造代码适配自己的任务数据,助大家光速开启DeepSeek R1在自定义数据上的训练探索之旅。
一、 环境搭建不求人
1. 显卡驱动与CUDA适配要点
open-r1明确要求cuda12.4,得先瞅瞅自己机器的显卡驱动版本(如下图),要是版本太老,那可就得升级才能适配适配cuda12.4,我亲测,显卡驱动版本为470以上就能正常运行,我的版本是535。
# 查看自己的显卡版本与cuda是否适配 import torch print(torch.cuda.is_available()) # True就可以
2. 快速搞定环境安装
与readme里的uv相比,我还是习惯使用conda管理虚拟环境:
复制1. conda create -n openr1 python=3.11 2. pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124 3. pip install vllm==0.7.2 4. pip install flash-attn 5. 切换到open-r1目录执行pip install -e ".[dev]"
二、训练踩坑大避雷
1. 导致OOM的原因有这么多
以grpo训练为例,使用Qwen-14B在A100上训练很容易报错OOM,原因有多种,让我来为大家一一分析:grpo任务可以分为两部分:一部分是模型训练(7卡),一部分是模型推理(1卡),OOM报错的原因就来自这两部分。
- 训练报错oom:7张A100卡无法实现14B模型的训练。解决方法:修改recipes/accelerate_configs/zero3.yaml,开启offload
- 推理报错oom:如果vllm版本在0.7.3以下,很容易发生oom,需要修改recipes/Qwen2.5-14B-Instruct/grpo/config_simple_rl.yaml,调低vllm_gpu_memory_utilization参数值,14B模型可以改为0.2,7B模型可以改为0.5。
- 推理报错oom:指定vllm推理的max_model_len太长,导致kv caceh需要占用的显存太多。解决方法:修改recipes/Qwen2.5-14B-Instruct/grpo/config_simple_rl.yaml,调低vllm_max_model_len,注意这个参数是指prompt+模型输出长度,不宜过短,可以调整为4k-8k。默认值是读取基座模型config,比如Qwen-14B默认是32768。
那么如何识别自己的OOM报错是出自训练还是推理呢?直接看报错的GPU卡号,因为默认是最后一张卡用于推理,如下图既然是GPU 7 内存不足,那就推理出了问题,只需要调整上述提到的两个参数即可。
针对Qwen-14B在8卡A100(40G)训练对应的配置文件,我已经调教好了放在本文最后,供大家参考。
2. reward函数的形参命名有讲究
在设计reward函数,有个注意:reward函数声明的形参很重要,不是随便起的,要求与dataset的列名是一致的。比如下面这个reawrd函数的两个形参,completions表示模型生成的内容,ground_truth表示dataset中”ground_truth“列的值,这里的形参ground_truth就是要求与dataset列名字对齐。
复制import re def reward_func(completions, ground_truth, **kwargs): # Regular expression to capture content inside \boxed{} matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions] contents = [match.group(1) if match else "" for match in matches] # Reward 1 if the content is the same as the ground truth, 0 otherwise return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]
三、DeepSeek R1训练快速开启不迷路
1. 数据先行!准备业务数据要点
离线构造业务数据集data.json,注意字段名为problem与solution,与官方给的示例数据字段名一致,这样可以少去很多改代码的麻烦:
复制
{"problem": "Classify the text into neutral, negative, or positive\nText: I think the food was okay.\nSentiment:\n", "solution": "positive"} {"problem": "Classify the text into neutral, negative, or positive\nText: I think the food was shit.\nSentiment:\n", "solution": "negative"}
2. 巧妙变身!轻松更改数据读取方式
修改grpo.py中数据读取方式,由读取hub数据改为读取离线数据:
复制
dataset = load_dataset("json", data_files=XXX/data.json) dataset = dataset["train"].train_test_split(test_size=0.02)
个性定制!手把手自定义reward函数 注意这里函数声明中solution形参要与dataset的字段保持一致:
复制def accuracy_reward_ours(completions, solution, **kwargs): """Reward function that checks if the completion is the same as the ground truth.""" contents = [completion[0]["content"] for completion in completions] rewards = [] for content, sol in zip(contents, solution): gold_parsed = sol # 从数据集中读取ground-truth if len(gold_parsed) != 0: # We require the answer to be provided in correct latex (no malformed operators) answer_parsed = re.findall("<answer>(.*?)</answer>",content) # 从模型输出文本中提取预测答案 if len(answer_parsed)>0: answer_parsed = answer_parsed[0] reward = float(1 if answer_parsed==gold_parsed else 0) # 判断预测结果与真实结果是否一致 else: reward = float(0) else: # If the gold solution is not parseable, we reward 1 to skip this example reward = 1.0 print("Failed to parse gold solution: ", sol) rewards.append(reward) return rewards
3. 一键启动!畅爽开启DeepSeek R1训练
复制ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \ --num_processes=7 src/open_r1/grpo.py \ --config recipes/Qwen2.5-14B-Instruct/grpo/config_simple_rl.yaml \ &> /workspace/user_code/Qwen2.5-14B-Instruct.log
四、能让14B模型在A100上丝滑跑通R1的配置参数大公开
recipes/accelerate_configs/zero3.yaml
复制compute_environment: LOCAL_MACHINE debug: false deepspeed_config: deepspeed_multinode_launcher: standard offload_optimizer_device: "cpu" offload_param_device: "cpu" zero3_init_flag: true zero3_save_16bit_model: true zero_stage: 3 distributed_type: DEEPSPEED downcast_bf16: 'no' machine_rank: 0 main_training_function: main mixed_precision: bf16 num_machines: 1 num_processes: 8 rdzv_backend: static same_network: true tpu_env: [] tpu_use_cluster: false tpu_use_sudo: false use_cpu: false
recipes/Qwen2.5-14B-Instruct/grpo/config_simple_rl.yaml
复制# Model arguments model_name_or_path: XXX/models/Qwen2.5-14B-Instruct model_revision: main torch_dtype: bfloat16 attn_implementation: flash_attention_2 # Data training arguments dataset_name: XXX/dataset/data.json # Num processes is less by 1 as vLLM is using 1 GPU num_processes: 7 # GRPO trainer config reward_funcs: - accuracy_ours - format bf16: true use_vllm: true vllm_device: cuda:7 vllm_gpu_memory_utilization: 0.2 # vllm版本在0.7.3以下 vllm_max_model_len: 8000 do_eval: true eval_strategy: steps eval_steps: 100 gradient_accumulation_steps: 8 gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: false hub_model_id: Qwen-2.5-7B-Simple-RL hub_strategy: every_save learning_rate: 3.0e-06 log_level: info logging_steps: 5 logging_strategy: steps lr_scheduler_type: cosine max_prompt_length: 512 max_completion_length: 1024 max_steps: -1 num_generations: 7 num_train_epochs: 1 output_dir: XXX/Qwen-2.5-7B-Instruct-RL overwrite_output_dir: true per_device_eval_batch_size: 8 per_device_train_batch_size: 8 push_to_hub: false report_to: "none" save_strategy: "steps" save_steps: 100 save_total_limit: 2 seed: 42 warmup_ratio: 0.1