Llama 3.2 Vision医学图像微调

MODEL ZOO Nov 25, 2024

你是否曾经想过 AI 模型如何学习理解医学图像?今天,我将带你了解一个令人兴奋的项目:微调 Meta 的 Llama 3.2 Vision 模型来分析放射图像。如果你不是技术专家,也不用担心——我会用简单的术语来解释。

1、这是什么?

想象一下,有一个可以查看 X 射线并提供详细医学描述的 AI 助手。这正是我们在这里构建的。我们正在采用 Meta 强大的 Llama 3.2 Vision 模型(一个拥有 110 亿个参数的 AI),并教它更好地理解医学图像。

2、前后对比

这就是让它变得有趣的原因:在训练之前,该模型给出了医学图像的通用、有些模糊的描述。但经过我们的微调过程后,它变得更加精确和专业,说话更像专业放射技师。

3、它是如何工作的?

这个过程就像通过例子教学生一样。我们使用一个名为“Radiology_mini”的数据集,其中包含与专家描述配对的 X 射线图像。我们反复向模型展示这些图像,它就会学会:

  • 识别特定的医学特征
  • 使用正确的医学术语
  • 像专业放射技师一样构建其响应

4、幕后的魔力

我们使用一种名为 LoRA(低秩自适应)的巧妙技术,即使在单个 GPU 上也可以训练这个庞大的模型。可以将其视为教导模型更好地完成工作,而无需重写其整个知识库。

5、结果

转变是显著的。在训练之前,该模型给出了一般的临床观察结果,例如“这张射线照片似乎是上下颌的全景图……”训练后,它提供了更有针对性和结构化的观察结果,例如“全景射线照相显示双侧动脉瘤性骨囊肿(ABC)”——对医疗专业人员来说更加精确和有用!

6、技术实施

让我们深入了解如何自己实现这一点。以下是包含代码的分步指南:

首先,安装所需的包:

pip install unsloth
export HF_TOKEN=xxxxxxxxxxxxx  # Your Hugging Face token

以下是完整的实现,分为逻辑部分:

import os
from unsloth import FastVisionModel
import torch
from datasets import load_dataset
from transformers import TextStreamer
from unsloth import is_bf16_supported
from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig

# Load the model
model, tokenizer = FastVisionModel.from_pretrained(
    "unsloth/Llama-3.2-11B-Vision-Instruct",
    load_in_4bit = True,
    use_gradient_checkpointing = "unsloth",
)

# Configure fine-tuning parameters
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers     = True,
    finetune_language_layers   = True,
    finetune_attention_modules = True,
    finetune_mlp_modules      = True,
    r = 16,
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
)

# Load and prepare the dataset
dataset = load_dataset("unsloth/Radiology_mini", split = "train")
instruction = "You are an expert radiographer. Describe accurately what you see in this image."

def convert_to_conversation(sample):
    conversation = [
        { "role": "user",
          "content" : [
            {"type" : "text",  "text"  : instruction},
            {"type" : "image", "image" : sample["image"]} ]
        },
        { "role" : "assistant",
          "content" : [
            {"type" : "text",  "text"  : sample["caption"]} ]
        },
    ]
    return { "messages" : conversation }

converted_dataset = [convert_to_conversation(sample) for sample in dataset]

# Configure the trainer
FastVisionModel.for_training(model)
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    data_collator = UnslothVisionDataCollator(model, tokenizer),
    train_dataset = converted_dataset,
    args = SFTConfig(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        max_steps = 30,
        learning_rate = 2e-4,
        fp16 = not is_bf16_supported(),
        bf16 = is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none",
        remove_unused_columns = False,
        dataset_text_field = "",
        dataset_kwargs = {"skip_prepare_dataset": True},
        dataset_num_proc = 4,
        max_seq_length = 2048,
    ),
)

# Train the model
trainer_stats = trainer.train()

# Test after training
print("\nAfter training:\n")
FastVisionModel.for_inference(model)
image = dataset[0]["image"]
instruction = "You are an expert radiographer. Describe accurately what you see in this image."

messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": instruction}
    ]}
]
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
inputs = tokenizer(
    image,
    input_text,
    add_special_tokens = False,
    return_tensors = "pt",
).to("cuda")

text_streamer = TextStreamer(tokenizer, skip_prompt = True)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
                   use_cache = True, temperature = 1.5, min_p = 0.1)

# Save and upload the model
model.save_pretrained("lora_model")
tokenizer.save_pretrained("lora_model")

model.save_pretrained_merged("your-username/Llama-3.2-11B-Vision-Radiology-mini", tokenizer,)
model.push_to_hub_merged("your-username/Llama-3.2-11B-Vision-Radiology-mini", 
                        tokenizer, 
                        save_method = "merged_16bit", 
                        token = os.environ.get("HF_TOKEN"))

模型加载:我们以 4 位精度加载预先训练的 Llama 3.2 Vision 模型以节省内存。

微调配置:我们支持对各种模型组件进行微调,包括视觉层、语言层和注意模块。

数据集准备:代码将放射图像及其描述转换为模型可以理解的对话格式。

训练配置:我们使用具有特定参数的 SFTTrainer:

  • 每台设备的批次大小为 2
  • 4 个梯度累积步骤
  • 30 个最大训练步骤
  • 学习率为 2e-4
  • 线性学习率调度程序

模型保存:训练后,我们保存 LoRA 权重和模型的合并版本。

7、为什么这很重要

这种技术在医疗保健领域可能非常有价值:

  • 帮助放射科医生更有效地工作
  • 在专家有限的地区提供初步筛查
  • 培训医学生
  • 提供第二意见

8、技术栈

对于技术好奇的人,我们使用了几种现代工具来实现这一点:

  • Unsloth:一个使微调更高效的库
  • PyTorch:用于底层机器学习操作
  • Hugging Face:用于管理和共享训练后的模型

9、展望未来

这只是一个开始。随着这些模型不断改进并变得更加专业化,它们可能会成为医疗保健环境中的宝贵工具。但是,重要的是要记住,它们旨在协助而不是取代人类医疗专业人员。

10、要求和资源

要运行此代码,你需要:

  • 具有至少 48GB VRAM 的 GPU(例如 RTX A6000)
  • Python 3.8+
  • Hugging Face 帐户和 API 令牌
  • 大约 5 分钟的训练时间

请记住:虽然这项技术令人兴奋,但它仍然是协助人类专业知识的工具,而不是取代它。医疗保健的未来在于人类专业人员和 AI 助手之间的和谐合作。

11、最后提示

  • 在训练期间始终监控你的 GPU 内存使用情况
  • 从少量训练步骤开始测试你的设置
  • 确保你的训练数据质量高且标记正确
  • 跟踪前后结果以衡量改进

请随意尝试超参数并根据你的特定用例调整代码。医疗保健领域的人工智能正在迅速发展,并且有足够的创新空间!


原文链接:Fine-tuning Llama 3.2 Vision: Making AI Better at Reading Medical Images

汇智网翻译整理,转载请标明出处

Tags