用GRPO算法训练医疗AI模型

用GRPO算法训练医疗AI模型

大型语言模型(LLMs)与医疗保健的交叉点带来了令人兴奋的机会,但也带来了独特的挑战。在本教程中,我们将探讨如何使用分组相对策略优化(GRPO)——一种最近由DeepSeek团队引入的有前途的新强化学习技术——来适应阿里巴巴的Qwen-3B模型以用于医学推理。

为什么这很重要:

  • 🏥 患者安全第一:医学AI中的幻觉可能是危险的。
  • 💡 领域专业化:通用LLMs难以处理临床推理。
  • 效率:我们的3B参数模型可以在消费级GPU上运行。

推理模型如O3和DeepSeek R1在许多具有挑战性的基准测试中显示了前所未有的改进。它们改变了监督微调的趋势,转向实际的强化学习(RL)。我们在深度学习领域的许多突破都来自RL,例如AlphaGo,因为模型能够通过与不同的现实场景互动来学习,而这些场景在监督微调中往往难以提供示例。

DeepSeek R1在几个关键基准上的表现[1]

如果你想了解更多关于推理模型或更多历史细节,我强烈推荐Maarten的文章[2]。DeepSeek工作的美妙之处在于他们实现了一个实用的框架,用于使用GRPO对LLM进行微调。根据Maarten的文章:

这个算法背后的直觉是,它使所有导致正确或错误答案的选择更可能或更不可能。这些选择可以是令牌集也可以是推理步骤。

正如下面的图片所示:目标是激励模型生成响应,使其在正确的**块中以及我们能够轻松验证的最终正确答案中都能产生良好的结果(如数学问题)。

DeepSeek-R1-Zero使用的RL管道[2]

好了,背景知识就到这里,让我们开始动手吧。本文使用的代码作为colab笔记本提供,你可以轻松地使用T4免费资源运行。

1、安装Unsloth和TRL

开源软件已经取得了很大进展——在这个教程中,我们将使用两个出色的开源库:

  • Unsloth : 一个帮助我们从GPU中尽可能多地提取内存并提高训练性能的库。
  • TRL: 来自Hugging Face的一个开源库,帮助我们实现GRPO。

我们还将使用Qlora技术,帮助我们以更高效的方式微调模型。如果你想了解更多关于Qlora的信息,我强烈推荐Sebastian的文章[3]

!pip install unsloth vllm  # 内存高效的训练和推理  
!pip install trl@git+https://github.com/huggingface/trl  # GRPO实现
from unsloth import FastLanguageModel, PatchFastRL  
PatchFastRL("GRPO", FastLanguageModel)

2、下载并初始化模型

我们将首先下载模型,并利用50%的GPU容量以及vLLM推理来加速使用Qlora的GRPO训练。

from unsloth import is_bfloat16_supported  
import torch  
max_seq_length = 2048 # 可以增加以支持更长的推理跟踪  
lora_rank = 64 # 较大的rank表示更智能但更慢  
  
model, tokenizer = FastLanguageModel.from_pretrained(  
    model_name = "Qwen/Qwen2.5-3B-Instruct",  
    max_seq_length = max_seq_length,  
    load_in_4bit = True, # False for LoRA 16bit  
    fast_inference = True, # 启用vLLM快速推理  
    max_lora_rank = lora_rank,  
    gpu_memory_utilization = 0.5, # 减少如果出现内存不足  
)  
  
model = FastLanguageModel.get_peft_model(  
    model,  
    r = lora_rank, # 选择任何大于0的数字!建议8, 16, 32, 64, 128  
    target_modules = [  
        "q_proj", "k_proj", "v_proj", "o_proj",  
        "gate_proj", "up_proj", "down_proj",  
    ], # 如果出现内存不足则移除QKVO  
    lora_alpha = lora_rank,  
    use_gradient_checkpointing = "unsloth", # 启用长上下文微调  
    random_state = 3407,  
)

关键选择

  • 量化:启用16/24GB GPU训练(兼容T4/A10)
  • LoRA Rank 64:平衡性能与内存
  • vLLM集成:生成速度提高50%

3、数据策略:医学推理鸡尾酒

我们使用Hugging Face的interleave_datasets混合了三个关键数据集:

PubMedQA(占总数据的70%):

  • 临床问答,答案为yes/no/maybe
  • 过滤到<1024个token以提高内存效率

GSM8K

  • 数学文字问题,以保持数值推理能力

Health Benchmarks

  • 超过50个医学专科的多项选择题
  • 包括心脏病学到疫苗接种等类别

小贴士:权重应反映数据集复杂性——PubMedQA获得3倍更多的曝光以处理其复杂性。我们这里没有使用任何等待时间,但我们打乱了数据集,由于我们有三倍的PubMedQA样本,因此模型有三倍的机会展示这些示例。

import re  
from datasets import load_dataset, Dataset, interleave_datasets, concatenate_datasets  
  
# 加载和准备数据集  
SYSTEM_PROMPT = """  
响应格式如下:  
<reasoning>  
...  
</reasoning>  
<answer>  
...  
</answer>  
"""  
  
XML_COT_FORMAT = """\  
<reasoning>  
{reasoning}  
</reasoning>  
<answer>  
{answer}  
</answer>  
"""  
  
def extract_xml_answer(text: str) -> str:  
    answer = text.split("<answer>")[-1]  
    answer = answer.split("</answer>")[0]  
    return answer.strip()  
  
def extract_hash_answer(text: str) -> str | None:  
    if "####" not in text:  
        return None  
    return text.split("####")[1].strip()  
  
# 取消注释中间消息以进行一次提示  
def get_datasets(split = "train") -> Dataset:  
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore  
    data = data.map(lambda x: { # type: ignore  
        'prompt': [  
            {'role': 'system', 'content': SYSTEM_PROMPT},  
            {'role': 'user', 'content': x['question']}  
        ],  
        'answer': extract_hash_answer(x['answer']),  
        'db_set':'gsm8k'  
    }) # type: ignore  
    data = data.remove_columns(['question'])  
      
    data_qa = load_dataset("qiaojin/PubMedQA", "pqa_artificial")[split] # 两倍于其他数据集  
    data_qa = data_qa.filter(lambda x: len("\n".join(x['context']['contexts'])) < 1024) # 避免长跟踪  
    data_qa = data_qa.map(lambda x: { # type: ignore  
        'prompt': [  
            {'role': 'system', 'content': SYSTEM_PROMPT},  
            {  
                "role": "user",  
                "content": "Given the scientific context below:\n" +   
                          "\n".join(x['context']['contexts']) +   
                          "\n\nAnswer the following question:\n" +  
                          x['question'] +   
                          " with 'yes', 'no' or 'maybe'. You need to carefully review the context and reason before answering."  
            },  
        ],  
        'answer': x['final_decision'],  
        'db_set': 'pubmedqa'  
    }) # type: ignore  
    data_qa = data_qa.remove_columns(['pubid', 'question', 'context', 'long_answer', 'final_decision'])  
      
      
    categories =['Lab_Medicine', 'Wearables', 'Dermatology', 'Gastroenterology', 'Internal_Medicine', 'Oncology', 'Orthopedics', 'General_Surgery', 'Ophthalmology', 'Audiology', 'Head_Neck_Surgery', 'Elderly_Care', 'Pediatrics', 'Allergy_Immunology', 'Rheumatology', 'Pharmacy', 'Obstetrics_Gynecology', 'Microbiology', 'Dentistry', 'Physical_Medicine_and_Rehabilitation', 'Neurology', 'Psychiatry', 'Pathology', 'Genetics', 'Rare_Diseases', 'Hematology', 'Emergency', 'Endocrinology', 'Radiology', 'Cardiology', 'Pulmonology', 'Infectious_Diseases', 'Critical_Care', 'Pediatric_Surgery', 'Neuroscience', 'Epidemiology', 'Fitness_Sports', 'Health_Education', 'Health_Economics', 'Health_Entrepreneurship', 'Hospital_Management', 'Mental_Health', 'Nutrition', 'Palliative_Care', 'Preventive_Medicine', 'Public_Health', 'Social_Media_Addiction', 'Sleep', 'Supplements', 'Vaccination', 'Work_Health', 'Wellbeing']  
    data_mc = concatenate_datasets([load_dataset("yesilhealth/Health_Benchmarks",i)[i] for i in categories])  
    data_mc = data_mc.map(lambda x: { # type: ignore  
        'prompt': [  
            {'role': 'system', 'content': SYSTEM_PROMPT},  
            {  
                "role": "user",  
                "content": "\n\nAnswer the following question:\n" +  
                          x['Questions'] +   
                          "\n With 'A', 'B', 'C' or 'D'. You need to carefully review the context and reason before answering."  
            },  
        ],  
        'answer': x['Answers'],  
        'db_set': 'med_mc'  
    }) # type: ignore  
    data_mc = data_mc.remove_columns(['Answers', 'Questions'])  
      
    dataset = concatenate_datasets([data, data_qa, data_mc])  
    return dataset

4、秘密武器:奖励工程

我们的多奖励系统既教授推理结构也教授医学准确性(请参阅笔记本[4]以了解详细的奖励函数):

def correctness_reward(responses, answers):  
    # 给予完全匹配2.0分,部分匹配1.0分  
    return [2.0 if match else (1.0 if partial else 0.0)...]  
  
def format_reward(completions):  
    # 强制执行<reasoning>...</answer>结构  
    return [0.5 if re.match(XML_PATTERN) else 0.0...]

奖励层次

  1. 准确性(50%权重):与真实情况一致
  2. 格式化(30%):XML风格的推理跟踪
  3. 中间检查(20%):有效答案类型
类比:就像教导一位医学住院医生一样——既要表扬诊断准确度也要表扬适当的文档记录。

5、GRPO训练配置

这些参数主要是猜测工作,尚未优化,但在我的初步实验中似乎效果很好。您可以根据自己的用例进行调整和实验。

from trl import GRPOConfig, GRPOTrainer  
training_args = GRPOConfig(  
    use_vllm = True, # 使用vLLM进行快速推理!  
    learning_rate = 5e-6,  
    adam_beta1 = 0.9,  
    adam_beta2 = 0.99,  
    weight_decay = 0.1,  
    warmup_ratio = 0.1,  
    lr_scheduler_type = "cosine",  
    optim = "adamw_8bit",  
    logging_steps = 1,  
    bf16 = is_bfloat16_supported(),  
    fp16 = not is_bfloat16_supported(),  
    per_device_train_batch_size = 1,  
    gradient_accumulation_steps = 1, # 增加到4以获得更平滑的训练  
    num_generations = 6, # 出现内存不足时减少  
    max_prompt_length = 1024,  
    max_completion_length = 1024,  
    #num_train_epochs = 1, # 设置为1以进行完整的训练  
    max_steps = 750,  
    save_steps = 100,  
    max_grad_norm = 0.1,  
    report_to = "none", # 可以使用Weights & Biases  
    output_dir = "outputs",  
)  
  
trainer = GRPOTrainer(  
    model = model,  
    processing_class = tokenizer,  
    reward_funcs = [  
        xmlcount_reward_func,  
        soft_format_reward_func,  
        strict_format_reward_func,  
        int_reward_func,  
        correctness_reward_func,  
    ],  
    args = training_args,  
    train_dataset = train_dataset,  
    eval_dataset=test_dataset,  
)  
trainer.train()

就这样——由于我们的奖励函数质量很高——您应该很快开始看到奖励函数跳得很高(这是RL微调的第一步)。

模型开始发展推理能力

让我们看看模型在训练过程中生成的一些示例答案:

Prompt:

Given the scientific context below:

We have previously shown the detrimental effects of 19 sub-erythemal exposures to daily ultraviolet radiation (DUVR, which mimics non-extreme exposure conditions), delivered over 4 weeks to volunteers. This source had UVA (320-400 nm) to UVB (290-320 nm) irradiance ratio of 25, instead of that close to 10 that is typically the case with solar-simulated radiation (SSR) that represents summer global sunlight with a clear sky and quasi-zenith solar irradiance.
Here, we report on an extension of this previous study, in which we evaluated the photoprotection afforded by a broad-spectrum daily-care product with a low-sun protection factor (SPF 8, UVA-PF 7 and 3* rated UVA protection). We assessed cellular and molecular markers of photodamage that are relevant to skin cancer and photoageing.
This study shows that biological effects of repeated exposure to DUVR can be prevented by a broad-spectrum daily-care product and that the level of protection afforded varies with the studied endpoint.

Answer the following question:
Does a broad-spectrum sunscreen prevent cumulative damage from repeated exposure to sub-erythemal solar ultraviolet radiation representative of temperate latitudes? with 'yes', 'no' or 'maybe'. You need to carefully review the context and reason before answering. 

Model Response:

<reasoning>
The context provided discusses a study where repeated exposures to sub-erythemal UV radiation (DUVR) can cause detrimental effects over a period of 4 weeks. It also mentions that a broad-spectrum daily-care product with an SPF 8 and UVA protection factor with a 7 and 3* rating was used. The study aims to evaluate whether this broad-spectrum sunscreen can prevent cumulative damage from such repeated exposures. The mention of repeated exposure and the presence of a broad-spectrum product strongly suggests that the sunscreen can prevent cumulative damage. However, the actual effectiveness of the sunscreen in preventing cumulative damage is not explicitly stated in the provided context. Given the experimental setup and the corrective action (broad-spectrum product) taken, it can be inferred that the answer is likely 'yes'.

The sunscreen contains SPF 8, which means it provides protection to around 92% of the UVB radiation, and it has a UVA protection factor (UVA-PF) of 7, which likely provides varying degrees of protection to different wavelengths, including those of UVA. Such broad-spectrum protection would effectively prevent the cumulative damage associated with prolonged exposure to sub-erythemal UV radiation.

</reasoning>

<answer>
yes
</answer>

如你所见,模型开始将思考组织成结构化的推理块——现在你可以想象我们可以使用这样的技术显著改善专门模型的RAG。我们甚至可以教模型编写SQL代码、医学术语规范化等等,潜力无限!

完整笔记本在此!


原文链接:Build Your Own Medical Mini-DeepSeek R1 with Reinforcement Learning

汇智网翻译整理,转载请标明出处