微调PaliGemma优化目标检测

PaliGemma 是 Google 于 2024 年 5 月发布的大型多模态模型 (LMM)。你可以使用 PaliGemma 进行视觉问答 (VQA),检测图像上的物体,甚至生成分割蒙版。

虽然 PaliGemma 具有零样本能力(这意味着模型无需微调即可识别物体),但这种能力是有限的。Google 强烈建议对模型进行微调,以在特定领域获得最佳性能。

基础模型通常表现不佳的一个领域是医学成像。在本指南中,我们将介绍如何微调 PaliGemma 以检测 X 射线图像中的骨折。为此,我们将使用 Roboflow Universe 上可用的数据集之一。

JAX/FLAX PaliGemma 3B 有三个不同的版本,输入图像分辨率(224、448 和 896)和输入文本序列长度(分别为 128、512 和 512 个标记)不同。

为了限制 GPU 内存消耗并在 Google Colab 中启用微调,我们将在本教程中使用最小版本 paligemma-3b-pt-224。你将需要一个具有至少 12GB 可用 RAM 的 GPU 运行时,而配备 NVIDIA T4 的 Google Colab 就足够了。

为了微调 PaliGemma,我们将:

  • 以 PaliGemma JSONL 格式下载对象检测数据集;
  • 安装所需的依赖项;
  • 从 Kaggle 下载预先训练的 PaliGemma 权重和标记器;
  • 使用 JAX 微调 PaliGemma;
  • 保存我们的模型以供日后使用。
该模型将使用 Google 的低级深度学习框架 JAX 进行微调。因此,用于加载数据、训练模型和评估结果的代码片段可能很长,并且不会完整包含在这篇博文中。完整代码可在随附的笔记本中找到。

事不宜迟,让我们开始吧!

1、下载对象检测数据集

要微调 PaliGemma 进行对象检测,你需要一个 PaliGemma JSONL 格式的数据集。此格式通常不用于训练像 YOLO 这样的传统计算机视觉模型,但通常用于训练语言模型。JSONL 格式的数据集的每一行都是一个单独的 JSON 对象,就像单个记录的列表一样。

在我们的例子中,每个记录都包含关联图像的名称、将传递给模型的前缀(提示)以及来自模型的后缀(预期响应)。以下是我们数据集中的一个对象:

{'image': 'n_0_2513_png_jpg.rf.1f679ff5dec5332cf06f6b9593c8437b.jpg', 'prefix': 'detect fracture', 'suffix': '<loc0390><loc0241><loc0472><loc0440> fracture'}

在提示中,请注意关键字 detect后跟我们要“检测”的类列表,以分号分隔。预期的检测结果由 <loc{Y1}><loc{X1}><loc{Y2}><loc{X2}>中的边界框和类名描述。值X1、Y1、X2和Y2描述边界框的位置,标准化为1024x1024的图像大小。每个值应该有4位数字;如果坐标较短,则用零填充。

Roboflow 完全支持 PaliGemma JSONL 格式,可用于导出 Roboflow Universe 上的 250,000 多个数据集中的任何一个。

首先,安装下载和解析数据集所需的依赖项:

pip install roboflow supervision

对于本指南,我们将使用 Roboflow API 密钥下载骨折检测数据集:

from google.colab import userdata
from roboflow import Roboflow

ROBOFLOW_API_KEY = userdata.get('ROBOFLOW_API_KEY')

rf = Roboflow(api_key=ROBOFLOW_API_KEY)
project = rf.workspace("srinithi-s-tzdkb").project("fracture-detection-rhud5")
version = project.version(4)
dataset = version.download("PaliGemma")

在开始微调之前,让我们通过可视化数据集中的一个示例来确保数据集的格式正确:

from PIL import Image
import json

first = json.loads(open(f"{dataset.location}/dataset/_annotations.train.jsonl").readline())
print(first)

image = Image.open(f"{dataset.location}/dataset/{first.get('image')}")
CLASSES = first.get('prefix').replace("detect ", "").split(" ; ")
detections = from_pali_gemma(first.get('suffix'), image.size, CLASSES)

sv.BoundingBoxAnnotator().annotate(image, detections)

现在我们知道我们的标注已正确显示,我们可以设置我们的 Python 环境并开始微调。本节中的大部分代码来自 PaliGemma 团队发布的官方 Google Colab

2、模型设置

为了训练用于对象检测的 PaliGemma 模型,我们将使用 Google Research 维护的 big_vision 项目。我们可以使用以下代码安装此项目:

import os
import sys

# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
  raise "It seems you are using Colab with remote TPUs which is not supported."

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")

# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"

安装 big_vision 后,接下来需要下载 PaliGemma 模型权重。这些权重可在 Kaggle 上获得。你需要一个 Kaggle 帐户才能下载权重。必须同意 Kaggle 中的 PaliGemma 服务条款才能使用模型权重。

设置 Kaggle 帐户并同意服务条款后,可以使用以下代码下载 PaliGemma 权重:

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate or make your credentials available in ~/.kaggle/kaggle.json

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

import os
import kagglehub

MODEL_PATH = "./PaliGemma-3b-pt-224.f16.npz"
if not os.path.exists(MODEL_PATH):
  print("Downloading the checkpoint from Kaggle, this could take a few minutes....")
  # Note: kaggle archive contains the same checkpoint in multiple formats.
  # Download only the float16 model.
  MODEL_PATH = kagglehub.model_download('google/PaliGemma/jax/PaliGemma-3b-pt-224', MODEL_PATH)
  print(f"Model path: {MODEL_PATH}")

TOKENIZER_PATH = "./PaliGemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
  print("Downloading the model tokenizer...")
  !gsutil cp gs://big_vision/PaliGemma_tokenizer.model {TOKENIZER_PATH}
  print(f"Tokenizer path: {TOKENIZER_PATH}")

3、训练 PaliGemma 模型进行对象检测

下载模型权重后,我们现在可以在自定义对象检测数据集上训练 PaliGemma 模型。此步骤的代码很长,因此本指南将不包含代码。按照随附的笔记本获取训练模型所需的所有代码。

训练模型需要遵循的步骤是:

  • 导入所有必需的依赖项
  • 使用 ml_collections 库构建模型。
  • 将模型权重加载到 RAM 中以供训练使用。
  • 将参数移动到 GPU/TPU 内存以供训练使用。
  • 定义图像和标记的预处理函数。
  • 使用 PaliGemma jsonl 格式定义一个训练循环,该循环将遍历所有训练和验证示例。
  • 运行具有指定学习率和示例数量的训练循环来微调模型。

所有这些步骤都记录在随本文附带的 Colab 笔记本中。

在我们的 Colab 中,我们将批处理大小设置为 8,将学习率设置为 0.01,并将训练和评估步骤的数量定义为:

BATCH_SIZE = 8
TRAIN_EXAMPLES = 512
LEARNING_RATE = 0.01

TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE
EVAL_STEPS = TRAIN_STEPS // 8

有了训练好的模型,我们现在可以测试它了。

4、测试经过微调的对象检测模型

在我们的 Colab 笔记本中,我们声明了一个名为 make_predictions 的函数,该函数接受一个遍历图像并对每个图像运行推理的函数。

我们可以使用这个函数来测试我们经过微调的对象检测模型:

html_out = ""
for image, caption in make_predictions(validation_data_iterator(), batch_size=4):
  html_out += render_example(image, caption)
display(HTML(html_out))

以下是我们的模型在项目的验证数据集上运行时产生的部分结果:

此图像中有来自验证集的图像,其中粉色边界框与模型的检测结果相对应,右侧的文本标签告诉我们识别出的类别(“骨折”)。

可以使用以下代码保存模型以供日后使用:

flat, _ = big_vision.utils.tree_flatten_with_names(params)
with open("/content/fine-tuned-PaliGemma-3b-pt-224.f16.npz", "wb") as f:
  np.savez(f, **{k: v for k, v in flat})

5、上传模型权重

准备好模型权重后,可以使用 Roboflow Inference 将其部署到硬件上。

可以使用以下代码将模型权重上传到 Roboflow:

import roboflow
rf = Roboflow(api_key="API_KEY")
project = rf.workspace("workspace-id").project("project-id")
version = project.version(VERSION)
version.deploy(model_type="paligemma-3b-pt-224", model_path="/content/paligemma-lora")

以上,替换:

  • API_KEY 为你的 Roboflow API 密钥。
  • workspace-id 和 project-id 为你的工作区和项目 ID。
  • VERSION 为你的项目版本。

如果不使用我们的笔记本,请将 /content/paligemma-lora 替换为你保存模型权重的目录。

当运行上述代码时,模型将上传到 Roboflow。模型需要几分钟才能处理完毕,然后才能使用。

6、部署 PaliGemma 模型

当你的模型准备就绪后,可以在任何想要部署模型的设备上从 Roboflow 下载它。为此,可以使用开源计算机视觉推理服务器 Roboflow Inference。

首先,安装包:

pip install inference

然后,创建一个新的 Python 文件并添加以下代码:

import os
from inference import get_model
from PIL import Image
import json

lora_model = get_model("model-id/version-id", api_key="KEY")

image = Image.open("image.jpg")
response = lora_model.infer(image)
print(response)

以上,替换:

  • model-id 为你的 Roboflow 模型 ID;
  • version-id 为你的项目版本,以及;
  • KEY 为你的 Roboflow API 密钥。

当运行上述代码时,你将收到您提供的图像上的模型预测值。

然后,你可以使用 supervision 包可视化模型的结果。首先,安装包:

pip install supervision

然后,将以下代码添加到你的 Python 文件中:

import supervision as sv

detections = sv.Detections.from_inference(response)

box_annotator = sv.BoxAnnotator()
annotated_frame = box_annotator.annotate(
    scene=image.copy(),
    detections=detections
)

sv.plot_image(image=annotated_frame, size=(16, 16))

7、结束语

PaliGemma 是 Google 开发的多模态视觉模型。PaliGemma 可用于识别图像中物体的位置,并识别与图像中特定物体相对应的分割蒙版。

在本指南中,我们介绍了如何使用自定义数据集对 PaliGemma 进行物体检测微调,并参考了改编自 Google 官方 PaliGemma 微调笔记本的笔记本。

我们从 Roboflow Universe 下载了一个兼容的数据集,目视检查以确保注释以 PaliGemma 格式正确存储,然后在 Google Colab 上运行训练作业。然后,我们使用项目的相应验证数据集测试了我们的模型,取得了很好的效果。


原文链接:How to Fine-tune PaliGemma for Object Detection Tasks

汇智网翻译整理,转载请标明出处