Donut模型微调实用指南

MODEL ZOO Jan 7, 2025
“精通始于基础的结束。”

我个人认为,这是机器学习中的一句口头禅,尤其是在使用 Donut(文档理解转换器)等模型时。

如果你正在阅读本文,可能知道预训练模型非常强大,但让我告诉你,它们不是魔杖。

我曾参与过一些项目,其中现成的模型在处理特定领域的文档(例如高度风格化的发票或多语言合同)时失败了。

事情是这样的:Donut 在 OCR(光学字符识别)和信息提取任务方面表现出色,尤其是在处理复杂的文档布局时。

但是,当你的数据集与预训练数据集有很大差异时,微调就变得至关重要。我在处理法律文档流程时亲眼目睹了这一点。

预训练的 Donut 模型错过了法律术语和格式特有的关键注释,但对其进行微调带来了显着的改进。

在本指南中,我将向你介绍我用来微调 Donut 模型的确切步骤,以便你可以根据特定领域的需求对其进行调整。

在本指南结束时,你不仅会拥有一个经过微调的模型,而且还会有信心为未来的项目定制它。

1、先决条件

当我开始使用 Donut 模型时,我意识到这不仅仅是运行脚本,而是准备工作。以下是你开始使用所需的条件:

1.1 技术技能

你应该已经熟悉 Python、PyTorch 和 Hugging Face Transformers。如果你之前已经在生产中部署过模型,那么就没问题了。如果没有,请不要担心——我将指导你完成实际方面。

1.2 环境设置

根据我的经验,正确的设置可以为你节省数小时的调试时间。以下是我推荐的:

硬件:具有至少 16GB 内存的 GPU 效果最佳。我个人使用了 NVIDIA RTX 3090,它完美地完成了任务。AWS 或 Google Colab Pro+ 等云设置也是不错的选择。

软件:

  • Python (≥3.8)
  • PyTorch (≥1.12)
  • Hugging Face Transformers (≥4.30)

这是一个快速安装脚本:

# Install Python dependencies
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117
pip install transformers datasets accelerate

# Install additional libraries
pip install pillow albumentations

1.3 数据集要求

现在,让我们谈谈数据。微调的好坏取决于你输入的数据集。在一个项目中,我处理了多种语言的扫描收据。准备这个数据集是一个挑战,但它教会了我一些东西。

  • 结构:你的数据集应包括 JSON 格式的文档图像和相应的标注。 Donut 需要类似 COCO 的结构,其中每个标注都映射到文档中的特定字段。
  • 预处理工具:用于初始 OCR 的 Tesseract 和用于布局分析的 LayoutLMv3 等工具可以在预处理过程中起到救星作用。就我个人而言,我依靠 Python 库(如 Pillow)进行图像处理,使用 json 进行标注处理。

以下是构建数据集的代码片段:

import os
import json
from PIL import Image

# Define dataset paths
image_dir = "./dataset/images"
annotations_path = "./dataset/annotations.json"

# Verify dataset format
for file_name in os.listdir(image_dir):
    img_path = os.path.join(image_dir, file_name)
    if img_path.endswith(".png") or img_path.endswith(".jpg"):
        try:
            # Load image
            img = Image.open(img_path)
            print(f"Loaded {file_name}: {img.size}")
        except Exception as e:
            print(f"Error loading {file_name}: {e}")

# Example annotation structure
annotation_example = {
    "image_id": "0001",
    "annotations": [
        {"field": "Invoice Number", "value": "12345", "bbox": [100, 200, 300, 400]},
        {"field": "Date", "value": "2025-01-05", "bbox": [400, 200, 600, 400]},
    ],
}
with open(annotations_path, "w") as f:
    json.dump(annotation_example, f)

这些都是先决条件。设置好环境并准备好数据集后,就可以开始进行微调过程了,我们将在下文中介绍。相信我,当你的模型开始提供满足你需求的结果时,你在这里付出的努力将得到回报。

2、准备数据集

“好的数据胜过好的算法。”

我在使用机器学习模型的工作中一次又一次地看到了这一点,尤其是像 Donut 模型这样微妙的东西。

当我第一次对 Donut 进行微调时,数据集准备阶段改变了游戏规则。以下是如何以正确的方式设置数据,避免我遇到的陷阱。

2.1 数据集格式

我做的第一件事是确保我的数据集的格式与 Donut 兼容。该模型需要 JSON 结构,并为文档图像提供类似 COCO 的标注。

这种格式提供了图像及其相应的标签,使模型易于解析和学习。

这是我用来验证数据集兼容性的 Python 代码片段:

import json

# Load dataset annotation file
annotation_file = "./annotations.json"

with open(annotation_file, "r") as f:
    annotations = json.load(f)

# Check if annotations match the expected structure
if "images" in annotations and "annotations" in annotations:
    print("Dataset structure is valid.")
else:
    raise ValueError("Dataset format is incompatible with Donut.")

如果你的数据集不遵循此格式,则需要对其进行转换。我不得不为我的一个项目编写自定义脚本,以将“发票号”或“总金额”等字段映射到正确的 JSON 格式。相信我,在这里投入时间将为你节省训练期间的麻烦。

2.2 预处理步骤

事情是这样的:预处理不仅仅是清理数据——它是为了最大限度地提高模型性能而对其进行调整。

在一个项目中,我必须处理文本歪斜或嘈杂的多语言发票。关键是要提前处理这些变化。

  • 降噪:高斯模糊等技术可以减少图像伪影。
  • 多语言处理:确保你的标注和数据支持日语或阿拉伯语等语言的 Unicode。

以下是将图像转换为张量并对其进行规范化的实际示例:

from PIL import Image
import torch
from torchvision import transforms

# Define transformation pipeline
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Convert to 3 channels
    transforms.Resize((256, 256)),               # Resize to uniform dimensions
    transforms.ToTensor(),                       # Convert to tensor
    transforms.Normalize((0.5,), (0.5,))         # Normalize
])

# Load and preprocess an image
image_path = "./dataset/images/sample.jpg"
image = Image.open(image_path)
tensor_image = transform(image)

print(f"Image Tensor Shape: {tensor_image.shape}")

2.3 数据增强

我学到的一件事是,增强文档图像可以显著提高泛化能力。

我使用过 Albumentations 等库来添加噪点、旋转图像和裁剪边缘 — — 这些都是现实世界中常见的变化。

以下是如何实现增强管道:

import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2

# Define augmentation pipeline
augmentation = A.Compose([
    A.Rotate(limit=15, p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.GaussianBlur(p=0.3),
    A.Normalize(mean=(0.5,), std=(0.5,)),
    ToTensorV2()
])

# Apply augmentation to an image
image_path = "./dataset/images/sample.jpg"
image = cv2.imread(image_path)
augmented = augmentation(image=image)
augmented_image = augmented["image"]

print(f"Augmented Image Shape: {augmented_image.shape}")

我使用了这种精确的方法在一个包含财务文件的项目中,它明显减少了过度拟合。

2.4 拆分数据

正确拆分数据集至关重要。我个人曾遇到过不恰当的拆分导致评估指标出现偏差的情况,所以现在我总是确保训练、验证和测试集之间的平衡分布。

以下是使用 sklearn 的简单而可靠的方法:

from sklearn.model_selection import train_test_split

# Example dataset
images = ["img1.jpg", "img2.jpg", "img3.jpg"]  # Replace with your actual dataset
annotations = ["ann1.json", "ann2.json", "ann3.json"]

# Split data into train, validation, and test sets
train_imgs, val_test_imgs, train_anns, val_test_anns = train_test_split(
    images, annotations, test_size=0.4, random_state=42
)
val_imgs, test_imgs, val_anns, test_anns = train_test_split(
    val_test_imgs, val_test_anns, test_size=0.5, random_state=42
)

print(f"Training set: {len(train_imgs)} images")
print(f"Validation set: {len(val_imgs)} images")
print(f"Test set: {len(test_imgs)} images")

我遵循的一个很好的经验法则是 70% 的训练、15% 的验证和 15% 的测试。它在训练稳健性和评估可靠性之间取得了适当的平衡。

这就是准备数据集的全部内容!通过正确格式化、预处理和拆分数据,你将为有效的微调奠定基础。

相信我,这项基础工作值得每一分钟——这就是在现实场景中挣扎的模型与表现出色的模型的区别所在。接下来:微调 Donut 模型!

3、微调 Donut 模型

微调 Donut 不仅仅是一项技术任务;这是一个过程,从模型选择到超参数调整,你做出的每个决定都会影响结果。

我已经经历过多次这样的旅程,让我告诉你——这是艺术与科学的融合。让我们一步一步来分析。

3.1 选择预训练模型

当我第一次开始使用 Donut 时,最关键的问题是:我应该使用哪个预训练检查点?

Donut 在 Hugging Face 模型中心提供了多个版本,每个版本都针对 OCR 或信息提取等特定任务进行了优化。

这是我所寻找的:

  • 领域相关性:如果你的任务涉及发票,请从针对财务文件进行训练的模型开始。
  • 模型大小:像 Donut-Small 这样的较大模型可以处理复杂的布局,但需要更多的计算能力。

你可以在此处浏览 Hugging Face 中心上的预训练模型。就我个人而言,我发现 naver-clova-ix/donut-base 是通用任务的绝佳起点。

3.2 模型配置

现在到了有趣的部分——调整模型以适合你的数据集。在我的一个项目中,我注意到默认学习率过于激进,导致训练早期过度拟合。调整超参数会产生很大的不同。

以下是修改配置的方法:

from transformers import DonutConfig

# Load pre-trained configuration
config = DonutConfig.from_pretrained("naver-clova-ix/donut-base")

# Adjust hyperparameters
config.learning_rate = 5e-5  # Reduce learning rate for fine-tuning
config.batch_size = 16       # Set batch size based on GPU capacity
config.num_train_epochs = 10 # Number of epochs

print(config)

3.3 加载模型

使用 Hugging Face 加载预训练模型非常简单。但我学到了一点:在微调之前,一定要在样本输入上测试模型。这种健全性检查为我节省了数小时的调试时间。

以下是初始化 Donut 的代码片段:

from transformers import DonutForDocumentClassification

# Load the pre-trained model
model = DonutForDocumentClassification.from_pretrained(
    "naver-clova-ix/donut-base", 
    config=config
)

# Check model output on a sample input
sample_input = {"pixel_values": torch.rand(1, 3, 256, 256)}  # Replace with your preprocessed data
outputs = model(**sample_input)
print(outputs)

3.4 训练管道

你可能想知道,“微调 Donut 的最佳方法是什么?”我个人认为 Hugging Face 的 Trainer API 是一个改变游戏规则的工具。

它简化了训练,同时允许你自定义关键组件,如损失函数和评估指标。

以下是一个详细示例:

from transformers import TrainingArguments, Trainer
from datasets import load_dataset

# Load dataset
dataset = load_dataset("your_custom_dataset")

# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    num_train_epochs=10,
    evaluation_strategy="steps",
    save_steps=500,
    logging_dir="./logs"
)

# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=None,  # Donut doesn't use tokenizers
)

# Start training
trainer.train()

我喜欢这种方法的地方在于它如何记录所有内容(从损失值到评估指标),以便你可以实时监控进度。

3.5 处理大型数据集

如果你正在处理大量数据集,这里有一个我通过艰苦努力学到的技巧:不要尝试一次将所有内容加载到内存中。

相反,使用梯度累积和混合精度训练等技术来优化资源使用。

以下是如何实现梯度累积:

training_args = TrainingArguments(
    output_dir="./results",
    gradient_accumulation_steps=4,  # Accumulate gradients over 4 steps
    fp16=True,                      # Enable mixed precision
    per_device_train_batch_size=2,  # Reduce batch size to fit in memory
    num_train_epochs=10,
    learning_rate=5e-5,
)

对于分布式训练,Hugging Face 的 accelerate库可以帮助你跨多个 GPU 进行扩展。相信我,这在对大型数据集进行微调时非常有价值。

微调就是将通用模型变成真正适合你需求的模型。

这个过程一开始可能让人望而生畏,但一旦你看到结果——一个比其他任何东西都更能理解你的数据的模型——你就会觉得每一刻的努力都是值得的。

接下来,我们将深入评估你的微调模型,并确保它已准备好投入生产。

4、评估微调模型

当我第一次开始评估我的微调模型时,我认为这个过程很简单:运行一些测试,获取数字,然后继续。

但问题是:评估不仅仅是数字——它是为了了解你的模型在做什么是正确的,更重要的是,它在哪里出了问题。

4.1 要使用的指标

根据我的经验,选择正确的指标可以成就或破坏你的评估过程。对于涉及文档理解的任务,这些是我依赖的首选指标:

  • 字符错误率 (CER):非常适合衡量 OCR 性能。它告诉你模型相对于基本事实有多少个字符是错误的。
  • 单词错误率 (WER):当重点关注文本级准确性时,这是我使用的指标。​​
  • BLEU 分数:非常适合结构化文本输出很重要的任务。

以下是计算 CER 的快速代码片段:

def calculate_cer(prediction, ground_truth):
    from jiwer import cer
    return cer(ground_truth, prediction)

# Example
pred = "Hello Wrold"
gt = "Hello World"
print(f"Character Error Rate: {calculate_cer(pred, gt)}")

4.2 评估管道

在评估经过微调的 Donut 模型时,我喜欢自动化整个管道。

手动对验证集进行测试非常耗时,而且容易出错。这是我使用的管道:

from transformers import pipeline
import json

# Load model
model_path = "./results/checkpoint"
donut_pipeline = pipeline("document-question-answering", model=model_path)

# Load validation dataset
with open("./validation_data.json", "r") as f:
    validation_data = json.load(f)

# Evaluate model
results = []
for data in validation_data:
    image = data["image"]
    ground_truth = data["annotations"]
    prediction = donut_pipeline(image)
    cer_score = calculate_cer(prediction, ground_truth)
    results.append({"prediction": prediction, "cer": cer_score})

# Print average CER
avg_cer = sum(r["cer"] for r in results) / len(results)
print(f"Average CER: {avg_cer}")

4.3 可视化

数字很棒,但没有什么比看到模型在实际示例上的表现更好。当我调试或展示结果时,我会在预测和基本事实之间生成视觉比较。

以下是可视化方法:

import matplotlib.pyplot as plt
from PIL import Image

# Load an image and its annotations
image_path = "./sample_image.jpg"
prediction = "Total: $123.45"
ground_truth = "Total: $123.40"

# Plot the image with annotations
image = Image.open(image_path)
plt.imshow(image)
plt.title(f"Prediction: {prediction}\nGround Truth: {ground_truth}")
plt.axis("off")
plt.show()

我发现这对于向利益相关者解释模型行为非常有用——或者当事情没有按预期进行时,仅向我自己解释!

5、调试和优化

如果说我学到了一件事,那就是:调试模型既是一门艺术,也是一门科学。我参与的每个项目都存在不少问题,下面是我解决这些问题的方法。

5.1 常见的挑战

过拟合:在一个项目中,我的模型在训练集上表现完美,但在验证数据上却惨遭失败。罪魁祸首是什么?过度拟合。

解决方法:使用 dropout、数据增强或提前停止等技术。

欠拟合:当你的模型没有从数据中学到足够的知识时,就会发生这种情况。

解决方法:尝试更深层次的架构或训练更多时期。

嘈杂的数据:嘈杂的标注不止一次破坏了我的训练过程。

解决方法:清理数据集并使用工具标记异常值。

5.2 优化技术

你可能想知道,“微调超参数的最佳方法是什么?”我已经成功使用了动态学习率计划和权重衰减。以下是如何实现这些的:

from transformers import get_scheduler

# Define optimizer and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
scheduler = get_scheduler(
    "linear", optimizer=optimizer, num_warmup_steps=500, num_training_steps=10000
)

5.3 错误分析

当出现问题时(这种情况经常发生),我会依靠工具深入挖掘错误。例如,生成混淆矩阵对于了解模型遇到的问题非常有帮助。

以下是生成详细错误日志的方法:

import pandas as pd

# Example error analysis
error_log = [{"image": "img1.jpg", "pred": "Total: $123", "gt": "Total: $124"}]

# Save as CSV
df = pd.DataFrame(error_log)
df.to_csv("error_log.csv", index=False)
print("Error log saved!")

调试和优化不仅仅是流程中的步骤 — — 它们是您真正了解模型的地方。

一旦你解决了常见的挑战、微调了超参数并分析了错误,就会看到你的模型表现得比想象的更好。

接下来,让我们谈谈如何部署经过微调的模型,以便它在现实世界中大放异彩!

6、部署经过微调的模型

这是我逐渐意识到的一个事实:训练模型只是旅程的一半;部署是奇迹发生的地方。

我曾经参与过一些项目,其中经过微调的模型在受控测试中表现良好,但在实际应用中却表现不佳。

这就是为什么这一步至关重要——让我们确保你经过微调的 Donut 模型已准备好部署。

6.1 导出模型

在部署之前,你需要保存经过微调的模型。我经常使用 Hugging Face 的 save_pretrained 方法,因为它可以确保跨各种应用程序的兼容性。

以下是导出模型的代码片段:

from transformers import DonutForDocumentClassification

# Save the fine-tuned model
output_dir = "./fine_tuned_donut"
model.save_pretrained(output_dir)
print(f"Model saved to {output_dir}")

这种方法的优点在于它同时打包了模型权重和配置,使重新加载变得无缝。

6.2 集成到应用程序中

在部署模型时,我发现 FastAPI 是救星。它轻量级、易于设置,并且非常适合为 ML 模型提供服务。这是我为文档分析构建的简单 API:

from fastapi import FastAPI, File, UploadFile
from transformers import pipeline
from PIL import Image

app = FastAPI()

# Load your fine-tuned model
donut_pipeline = pipeline("document-question-answering", model="./fine_tuned_donut")

@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    # Read and process the uploaded file
    image = Image.open(file.file)
    prediction = donut_pipeline(image)
    return {"prediction": prediction}

运行此应用程序,你将拥有一个可用的 API 来为您的模型提供服务。在我的一个项目中,此 API 成为一个大型系统的支柱,该系统可自动为客户处理发票。

6.3 推理优化

情况是这样的:部署模型是一回事,但优化模型以进行实时推理又是另一回事。根据我的经验,ONNX 和 TensorRT 等工具可以大大减少延迟。

以下是我将 PyTorch 模型转换为 ONNX 以加快推理速度的方法:

import torch

# Export to ONNX
dummy_input = torch.rand(1, 3, 256, 256)  # Replace with your input size
onnx_path = "./donut_model.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    export_params=True,
    opset_version=11,
    input_names=["input"],
    output_names=["output"],
)
print(f"Model exported to {onnx_path}")

通过这种方式,我成功部署了能够以最小延迟处理每分钟数千个请求的模型。

7、高级主题

如果你和我一样,那么你总是在寻找突破界限的方法。以下是我用来将 Donut 模型提升到新水平的一些高级技术。

7.1 自定义分词器

Donut 不依赖于传统的分词器,但有时你可能需要调整其输入处理以适应特定领域的任务。

例如,在一个项目中,我需要模型来处理特定的发票布局。自定义分词器至关重要。

以下是调整 Donut 分词器的示例:

from transformers import AutoTokenizer

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("naver-clova-ix/donut-base")

# Customize tokenizer
special_tokens = ["[CUSTOM_FIELD]", "[ANOTHER_FIELD]"]
tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
model.resize_token_embeddings(len(tokenizer))
print("Tokenizer customized!")

7.2 多模态微调

你可能想知道,“Donut 可以同时处理文本和图像数据吗?”答案是可以的,只要发挥一些创造力。在一次实验中,我将 OCR 输出与表格元数据相结合,以微调 Donut 以完成混合任务。

这是一种简化的方法:

  • 分别预处理文本和图像。
  • 创建一个将元数据与图像对齐的数据集。
  • 通过修改输入管道以接受两者来微调模型。

7.3 迁移学习见解

这可能会让你感到惊讶,但经过微调的 Donut 模型通常可以适应类似的任务,并且只需进行最少的再训练。

例如,我曾经使用在发票上经过微调的模型来处理采购订单。关键是保持数据集的结构和格式相似。

以下是加载预先训练的微调模型并继续训练的方法:

# Load the previously fine-tuned model
model = DonutForDocumentClassification.from_pretrained("./fine_tuned_donut")

# Continue training with new data
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=new_dataset["train"],
    eval_dataset=new_dataset["validation"],
)
trainer.train()

这种方法在某些项目中为我节省了数周的工作时间。

部署和优化 Donut 模型不仅仅是一个技术步骤 — — 这是你的工作开始对现实世界产生影响的地方。这些高级技术呢?

它们是你保持领先地位的方法。在下一节中,让我们总结一下,并看看一些技巧,以确保你的微调模型取得长期成功!

8、结束语

他们说,“结束只是开始”,我非常同意,尤其是当涉及到像 Donut 这样的微调模型时。

在本指南的过程中,我们经历了整个过程 — — 从设置你的环境到部署微调模型,甚至探索高级技术。但问题是:这只是你的起点。

就我个人而言,使用 Donut 是一次有益的经历,我希望本指南能帮助你取得类似的成功。在你前进的过程中,请记住每个模型、数据集和用例都是不同的。

适用于一个项目的方法可能需要针对另一个项目进行调整。关键是保持好奇心,不断尝试,永远不要满足于“足够好”。

那么,你将使用 Donut 创造什么?可能性无穷无尽,我迫不及待地想看看你的旅程会带您去哪里。让我们开始工作吧!


原文链接:Fine-Tune the Donut Model: A Practical Guide

汇智网翻译整理,转载请标明出处

Tags