选自oxen.ai作者:Greg Schoeninger编译:陈陈、泽南RTX 3080 挪动版能练习哪种年夜模子?本文为那些 GPU 资本无限时应用 GRPO 练习的开辟者供给了可贵的领导。自 DeepSeek-R1 宣布以来,群组绝对战略优化(GRPO)因其无效性跟易于练习而成为年夜型言语模子强化进修的热点话题。R1 论文展现了怎样应用 GRPO 从遵守 LLM(DeepSeek-v3)的基础指令改变为推理模子(DeepSeek-R1)。GRPO 是一种在线进修算法(online learning algorithm),它经由过程应用练习进程中由练习模子本身天生的数据来停止迭代改良。GRPO 的目的是最年夜化天生补全(completions)的上风函数(advantage),同时确保模子坚持在参考战略(reference policy)邻近。本文的目标是帮你节俭一些时光,让你依据硬件估算抉择适合的模子巨细。在开端微调时,你必需做出的主要决议是抉择模子巨细,以及你是履行完整微调仍是参数高效微调(PEFT)。文章作者来自 AI 公司 Oxen.ai 的 CEO Greg Schoeninger。原文链接:https://www.oxen.ai/blog/grpo-vram-requirements-for-the-gpu-poor作者表现,他发明 trl 库中曾经有一个易于应用的 GRPO 实现,便破刻开端了练习,应用的硬件是装备了 16GB 显存的 Nvidia GeForce RTX 3080 的小型条记本电脑。正如各人可能碰到的成绩,作者发明示例代码中的参数设置招致了一个宏大的显存缺乏(OOM,out of memory )过错。torch.OutOfMemoryError: CUDA out of memory.Tried to allocate 1.90 GiB. GPU 0 has a total capacity of 15.73 GiB of which 1.28 GiB is free. Including non-PyTorch memory, this process has 14.43 GiB memory in use. Of the allocated memory 11.82 GiB is allocated by PyTorch, and 2.41 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)现实应用情形作者表现,他们停止了一系列试验,以断定练习种种巨细的模子所需的显存(VRAM)请求。参数数目从 5 亿到 140 亿不等,他们比拟了权重的完整微调与参数高效微调(应用 LoRA),全部练习运转都在英伟达 H100 上实现,因而这里的 OOM 象征着 80GB 的 VRAM。在表格中,你能够找到 GSM8K 数据集上练习的前 100 步中的峰值内存应用情形。用于试验的模子是:全部试验均应用 Shadeform 的 GPU 市场实现,因而每次试验只要要破费多少美元 H100。试验成果标明,内存需要跟着模子巨细跟练习方法的差别而明显变更。比方,全参数微调比 PEFT 须要更多的内存。为什么 GRPO 对内存需要较高这要从 GRPO 的道理提及,这是它的流程图。GRPO 对内存需要较高的起因在于,其外部波及多个模子,而且在练习数据中每个查问会发生多个输出。上图中的战略模子、参考模子跟嘉奖模子各自都是一个须要停止推理的 LLM。(只管从技巧上讲,嘉奖模子可能不须要参数化,能够只是一个 Python 函数或正则表白式,但不影响 GRPO 对内存的高需要。)为什么 8-Bit 优化跟梯度检讨点有助于增加内存占用?平日来讲,练习一个年夜型言语模子须要在内存中存储三种重要范例的信息:模子参数、模子进修所需的梯度、优化器的跟踪数据。对上述内容咱们能够如许懂得:假如模子的参数占用了 X 的空间,那么梯度也会占用大概雷同的空间。而后,像 AdamW 如许的优化器须要更多的空间,由于它们就像一个记载员,跟踪近来的更新汗青,以便更好地决议将来的优化。为了加重这种内存累赘,平日采取两种技巧:起首,能够应用像 AdamW 如许的 8-bit 优化器版本,它们能更高效地存储跟踪数据,同时仍坚持精良的机能 —— 相似于紧缩照片能够节俭空间,同时保存年夜局部图像品质;其次,应用梯度检讨点技巧,这就像在练习进程中拍摄快照,而不是记载全部内容。固然这会使练习速率减慢约 20-30%,但它明显增加了内存应用。联合这些技巧,即便对 GPU 资本无限的人来说,也可能练习更年夜的模子。代码示例像 trl 如许的库曾经开端支撑 GRPO,使得微调由 transformers 形成的 LLM 变得十分简略。代码也十分简练,只要将练习器调换为 GRPOTrainer 并界说一些嘉奖即可。GRPO 的最小代码量大概只有 99 行,假如你应用的是像 meta-llama/Llama-3.2-1B-Instruct 如许的小型模子跟像 openai/GSM8K 如许的数据集,能够十分疾速地启动。trl 名目地点:https://github.com/huggingface/trl?ref=ghost.oxen.aiimport torchfrom datasets import load_dataset, Datasetfrom transformers import AutoTokenizer, AutoModelForCausalLMfrom trl import GRPOConfig, GRPOTrainerimport reSYSTEM_PROMPT = Respond in the following format: reasoning ... /reasoning answer ... /answer def extract_hash_answer(text: str) - str | None: if #### not in text: return None return text.split( #### )[1].strip()def get_gsm8k_questions(split = train ) - Dataset: data = load_dataset( openai/gsm8k , main )[split] data = data.map(lambda x: { prompt : [ { role : system , content : SYSTEM_PROMPT}, { role : user , content : x[ question ]} ], answer : extract_hash_answer(x[ answer ]) }) return datadef extract_xml_answer(text: str) - str: answer = text.split( answer )[-1] answer = answer.split( /answer )[0] return answer.strip()def format_reward_func(completions, **kwargs) - list[float]: Reward function that checks if the completion has a specific format. pattern = r ^ reasoning \n.*?\n /reasoning \n answer \n.*?\n /answer \n$ responses = [completion[0][ content ] for completion in completions] matches = [re.match(pattern, r) for r in responses] return [0.5 if match else 0.0 for match in matches]def accuracy_reward_func(prompts, completions, answer, **kwargs) - list[float]: Reward function that extracts the answer from the xml tags and compares it to the correct answer. responses = [completion[0][ content ] for completion in completions] extracted_responses = [extract_xml_answer(r) for r in responses] return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]def main(): dataset = get_gsm8k_questions() model_name = meta-llama/Llama-3.2-1B-Instruct model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.bfloat16, attn_implementation= flash_attention_2 , device_map=None ).to( cuda ) tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.pad_token = tokenizer.eos_token training_args = GRPOConfig( output_dir= output , learning_rate=5e-6, adam_beta1=0.9, adam_beta2=0.99, weight_decay=0.1, warmup_ratio=0.1, lr_scheduler_type= cosine , logging_steps=1, bf16=True, per_device_train_batch_size=1, gradient_accumulation_steps=4, num_generations=4, max_prompt_length=256, max_completion_length=786, num_train_epochs=1, save_steps=100, save_total_limit=1, max_grad_norm=0.1, log_on_each_node=False, ) trainer = GRPOTrainer( model=model, processing_ > reward_funcs=[ format_reward_func, accuracy_reward_func ], args=training_args, train_dataset=dataset, ) trainer.train()if __name__ == __main__ : main()Num Generations 有什么用Num Generations 是一个超参数,它决议了咱们将在练习数据中对每个查问采样几多个补全。但是,这会明显增添 VRAM 的耗费。现在有一个开放的 GitHub 成绩,可能会辅助处理内存瓶颈成绩,能够参考如下链接地点:https://github.com/huggingface/trl/issues/2709?ref=ghost.oxen.ai对 num_completions=8,16,64 (DeepSeekMath 论文应用的 64),作者表现,不必再次盘算上述全部值,而是应用了 1B 参数模子停止了测试,以表现内存增加。不外,作者仍是倡议各人在内存瓶颈失掉修复之前应用 num_generations=4,也能取得不错的机能。影响 VRAM 的一些要素要对全部影响显存(VRAM)应用的要素停止片面的超参数验证,须要停止大批的试验。简略起见,这里只指出了须要留神的设置,以及试验中应用的详细数值。batch_size=1,因为 GRPO 为每个查问天生多个呼应,batch size 会敏捷掉控。gradient_accumulation_steps=4,优化器是另一个占用大批 VRAM 的处所。此参数决议了咱们将存储的梯度以辅助优化器停止其「登山」进程。num_completions=4,DeepSeekMath 论文中应用了 64。这完整超越了有些人的盘算估算。max_prompt_length=256,假如你想练习模子领有更年夜高低文的推理才能,将不得不增添 VRAM。GSM8K 的提醒绝对较小,合适此测试。max_completion_length=786,同样,因为盘算留神力的内存无限,推理链在这里遭到限度。高低文或天生的 token 越多,须要的内存就越年夜。LoRA target_modules=[ q_proj , k_proj , o_proj , up_proj , down_proj ] 在这方面能够实验多少种差别的迭代。target_modules= all-linear 是一种风行的方法,能够从你的 LoRA 中挤出最多的机能(就正确性而言)。对 VRAM 应用的大略预算假如你正在应用 FP16 精度停止练习,以下是一些简略的预算方式,能够辅助你懂得内存重要用在了哪些处所:模子参数:每个参数占用 2 字节。参考模子参数:每个参数占用 2 字节。梯度:每个参数占用 2 字节。优化器状况:每个参数占用 8 字节。8 位优化器:每个参数占用 4 字节。PEFT:有助于增加梯度的显存占用。最后是对于正确率的。作者实现了一个 10 亿参数的 Llama 3.2 模子的完全练习。在利用 GRPO 之前,该模子在保存测试集上到达了约 19% 的正确率,而在经由一个练习周期后,模子的正确率飙升至约 40.5%。固然这离 SOTA 程度还差得很远,但这展现了 GRPO 的强盛潜力。