Donut模型微调实用指南
“精通始于基础的结束。”
我个人认为,这是机器学习中的一句口头禅,尤其是在使用 Donut(文档理解转换器)等模型时。
如果你正在阅读本文,可能知道预训练模型非常强大,但让我告诉你,它们不是魔杖。
我曾参与过一些项目,其中现成的模型在处理特定领域的文档(例如高度风格化的发票或多语言合同)时失败了。
事情是这样的:Donut 在 OCR(光学字符识别)和信息提取任务方面表现出色,尤其是在处理复杂的文档布局时。
但是,当你的数据集与预训练数据集有很大差异时,微调就变得至关重要。我在处理法律文档流程时亲眼目睹了这一点。
预训练的 Donut 模型错过了法律术语和格式特有的关键注释,但对其进行微调带来了显着的改进。
在本指南中,我将向你介绍我用来微调 Donut 模型的确切步骤,以便你可以根据特定领域的需求对其进行调整。
在本指南结束时,你不仅会拥有一个经过微调的模型,而且还会有信心为未来的项目定制它。
1、先决条件
当我开始使用 Donut 模型时,我意识到这不仅仅是运行脚本,而是准备工作。以下是你开始使用所需的条件:
1.1 技术技能
你应该已经熟悉 Python、PyTorch 和 Hugging Face Transformers。如果你之前已经在生产中部署过模型,那么就没问题了。如果没有,请不要担心——我将指导你完成实际方面。
1.2 环境设置
根据我的经验,正确的设置可以为你节省数小时的调试时间。以下是我推荐的:
硬件:具有至少 16GB 内存的 GPU 效果最佳。我个人使用了 NVIDIA RTX 3090,它完美地完成了任务。AWS 或 Google Colab Pro+ 等云设置也是不错的选择。
软件:
- Python (≥3.8)
- PyTorch (≥1.12)
- Hugging Face Transformers (≥4.30)
这是一个快速安装脚本:
# Install Python dependencies
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117
pip install transformers datasets accelerate
# Install additional libraries
pip install pillow albumentations
1.3 数据集要求
现在,让我们谈谈数据。微调的好坏取决于你输入的数据集。在一个项目中,我处理了多种语言的扫描收据。准备这个数据集是一个挑战,但它教会了我一些东西。
- 结构:你的数据集应包括 JSON 格式的文档图像和相应的标注。 Donut 需要类似 COCO 的结构,其中每个标注都映射到文档中的特定字段。
- 预处理工具:用于初始 OCR 的 Tesseract 和用于布局分析的 LayoutLMv3 等工具可以在预处理过程中起到救星作用。就我个人而言,我依靠 Python 库(如 Pillow)进行图像处理,使用 json 进行标注处理。
以下是构建数据集的代码片段:
import os
import json
from PIL import Image
# Define dataset paths
image_dir = "./dataset/images"
annotations_path = "./dataset/annotations.json"
# Verify dataset format
for file_name in os.listdir(image_dir):
img_path = os.path.join(image_dir, file_name)
if img_path.endswith(".png") or img_path.endswith(".jpg"):
try:
# Load image
img = Image.open(img_path)
print(f"Loaded {file_name}: {img.size}")
except Exception as e:
print(f"Error loading {file_name}: {e}")
# Example annotation structure
annotation_example = {
"image_id": "0001",
"annotations": [
{"field": "Invoice Number", "value": "12345", "bbox": [100, 200, 300, 400]},
{"field": "Date", "value": "2025-01-05", "bbox": [400, 200, 600, 400]},
],
}
with open(annotations_path, "w") as f:
json.dump(annotation_example, f)
这些都是先决条件。设置好环境并准备好数据集后,就可以开始进行微调过程了,我们将在下文中介绍。相信我,当你的模型开始提供满足你需求的结果时,你在这里付出的努力将得到回报。
2、准备数据集
“好的数据胜过好的算法。”
我在使用机器学习模型的工作中一次又一次地看到了这一点,尤其是像 Donut 模型这样微妙的东西。
当我第一次对 Donut 进行微调时,数据集准备阶段改变了游戏规则。以下是如何以正确的方式设置数据,避免我遇到的陷阱。
2.1 数据集格式
我做的第一件事是确保我的数据集的格式与 Donut 兼容。该模型需要 JSON 结构,并为文档图像提供类似 COCO 的标注。
这种格式提供了图像及其相应的标签,使模型易于解析和学习。
这是我用来验证数据集兼容性的 Python 代码片段:
import json
# Load dataset annotation file
annotation_file = "./annotations.json"
with open(annotation_file, "r") as f:
annotations = json.load(f)
# Check if annotations match the expected structure
if "images" in annotations and "annotations" in annotations:
print("Dataset structure is valid.")
else:
raise ValueError("Dataset format is incompatible with Donut.")
如果你的数据集不遵循此格式,则需要对其进行转换。我不得不为我的一个项目编写自定义脚本,以将“发票号”或“总金额”等字段映射到正确的 JSON 格式。相信我,在这里投入时间将为你节省训练期间的麻烦。
2.2 预处理步骤
事情是这样的:预处理不仅仅是清理数据——它是为了最大限度地提高模型性能而对其进行调整。
在一个项目中,我必须处理文本歪斜或嘈杂的多语言发票。关键是要提前处理这些变化。
- 降噪:高斯模糊等技术可以减少图像伪影。
- 多语言处理:确保你的标注和数据支持日语或阿拉伯语等语言的 Unicode。
以下是将图像转换为张量并对其进行规范化的实际示例:
from PIL import Image
import torch
from torchvision import transforms
# Define transformation pipeline
transform = transforms.Compose([
transforms.Grayscale(num_output_channels=3), # Convert to 3 channels
transforms.Resize((256, 256)), # Resize to uniform dimensions
transforms.ToTensor(), # Convert to tensor
transforms.Normalize((0.5,), (0.5,)) # Normalize
])
# Load and preprocess an image
image_path = "./dataset/images/sample.jpg"
image = Image.open(image_path)
tensor_image = transform(image)
print(f"Image Tensor Shape: {tensor_image.shape}")
2.3 数据增强
我学到的一件事是,增强文档图像可以显著提高泛化能力。
我使用过 Albumentations 等库来添加噪点、旋转图像和裁剪边缘 — — 这些都是现实世界中常见的变化。
以下是如何实现增强管道:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
# Define augmentation pipeline
augmentation = A.Compose([
A.Rotate(limit=15, p=0.5),
A.RandomBrightnessContrast(p=0.5),
A.GaussianBlur(p=0.3),
A.Normalize(mean=(0.5,), std=(0.5,)),
ToTensorV2()
])
# Apply augmentation to an image
image_path = "./dataset/images/sample.jpg"
image = cv2.imread(image_path)
augmented = augmentation(image=image)
augmented_image = augmented["image"]
print(f"Augmented Image Shape: {augmented_image.shape}")
我使用了这种精确的方法在一个包含财务文件的项目中,它明显减少了过度拟合。
2.4 拆分数据
正确拆分数据集至关重要。我个人曾遇到过不恰当的拆分导致评估指标出现偏差的情况,所以现在我总是确保训练、验证和测试集之间的平衡分布。
以下是使用 sklearn 的简单而可靠的方法:
from sklearn.model_selection import train_test_split
# Example dataset
images = ["img1.jpg", "img2.jpg", "img3.jpg"] # Replace with your actual dataset
annotations = ["ann1.json", "ann2.json", "ann3.json"]
# Split data into train, validation, and test sets
train_imgs, val_test_imgs, train_anns, val_test_anns = train_test_split(
images, annotations, test_size=0.4, random_state=42
)
val_imgs, test_imgs, val_anns, test_anns = train_test_split(
val_test_imgs, val_test_anns, test_size=0.5, random_state=42
)
print(f"Training set: {len(train_imgs)} images")
print(f"Validation set: {len(val_imgs)} images")
print(f"Test set: {len(test_imgs)} images")
我遵循的一个很好的经验法则是 70% 的训练、15% 的验证和 15% 的测试。它在训练稳健性和评估可靠性之间取得了适当的平衡。
这就是准备数据集的全部内容!通过正确格式化、预处理和拆分数据,你将为有效的微调奠定基础。
相信我,这项基础工作值得每一分钟——这就是在现实场景中挣扎的模型与表现出色的模型的区别所在。接下来:微调 Donut 模型!
3、微调 Donut 模型
微调 Donut 不仅仅是一项技术任务;这是一个过程,从模型选择到超参数调整,你做出的每个决定都会影响结果。
我已经经历过多次这样的旅程,让我告诉你——这是艺术与科学的融合。让我们一步一步来分析。
3.1 选择预训练模型
当我第一次开始使用 Donut 时,最关键的问题是:我应该使用哪个预训练检查点?
Donut 在 Hugging Face 模型中心提供了多个版本,每个版本都针对 OCR 或信息提取等特定任务进行了优化。
这是我所寻找的:
- 领域相关性:如果你的任务涉及发票,请从针对财务文件进行训练的模型开始。
- 模型大小:像 Donut-Small 这样的较大模型可以处理复杂的布局,但需要更多的计算能力。
你可以在此处浏览 Hugging Face 中心上的预训练模型。就我个人而言,我发现 naver-clova-ix/donut-base
是通用任务的绝佳起点。
3.2 模型配置
现在到了有趣的部分——调整模型以适合你的数据集。在我的一个项目中,我注意到默认学习率过于激进,导致训练早期过度拟合。调整超参数会产生很大的不同。
以下是修改配置的方法:
from transformers import DonutConfig
# Load pre-trained configuration
config = DonutConfig.from_pretrained("naver-clova-ix/donut-base")
# Adjust hyperparameters
config.learning_rate = 5e-5 # Reduce learning rate for fine-tuning
config.batch_size = 16 # Set batch size based on GPU capacity
config.num_train_epochs = 10 # Number of epochs
print(config)
3.3 加载模型
使用 Hugging Face 加载预训练模型非常简单。但我学到了一点:在微调之前,一定要在样本输入上测试模型。这种健全性检查为我节省了数小时的调试时间。
以下是初始化 Donut 的代码片段:
from transformers import DonutForDocumentClassification
# Load the pre-trained model
model = DonutForDocumentClassification.from_pretrained(
"naver-clova-ix/donut-base",
config=config
)
# Check model output on a sample input
sample_input = {"pixel_values": torch.rand(1, 3, 256, 256)} # Replace with your preprocessed data
outputs = model(**sample_input)
print(outputs)
3.4 训练管道
你可能想知道,“微调 Donut 的最佳方法是什么?”我个人认为 Hugging Face 的 Trainer API 是一个改变游戏规则的工具。
它简化了训练,同时允许你自定义关键组件,如损失函数和评估指标。
以下是一个详细示例:
from transformers import TrainingArguments, Trainer
from datasets import load_dataset
# Load dataset
dataset = load_dataset("your_custom_dataset")
# Define training arguments
training_args = TrainingArguments(
output_dir="./results",
learning_rate=5e-5,
per_device_train_batch_size=8,
num_train_epochs=10,
evaluation_strategy="steps",
save_steps=500,
logging_dir="./logs"
)
# Initialize trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["validation"],
tokenizer=None, # Donut doesn't use tokenizers
)
# Start training
trainer.train()
我喜欢这种方法的地方在于它如何记录所有内容(从损失值到评估指标),以便你可以实时监控进度。
3.5 处理大型数据集
如果你正在处理大量数据集,这里有一个我通过艰苦努力学到的技巧:不要尝试一次将所有内容加载到内存中。
相反,使用梯度累积和混合精度训练等技术来优化资源使用。
以下是如何实现梯度累积:
training_args = TrainingArguments(
output_dir="./results",
gradient_accumulation_steps=4, # Accumulate gradients over 4 steps
fp16=True, # Enable mixed precision
per_device_train_batch_size=2, # Reduce batch size to fit in memory
num_train_epochs=10,
learning_rate=5e-5,
)
对于分布式训练,Hugging Face 的 accelerate
库可以帮助你跨多个 GPU 进行扩展。相信我,这在对大型数据集进行微调时非常有价值。
微调就是将通用模型变成真正适合你需求的模型。
这个过程一开始可能让人望而生畏,但一旦你看到结果——一个比其他任何东西都更能理解你的数据的模型——你就会觉得每一刻的努力都是值得的。
接下来,我们将深入评估你的微调模型,并确保它已准备好投入生产。
4、评估微调模型
当我第一次开始评估我的微调模型时,我认为这个过程很简单:运行一些测试,获取数字,然后继续。
但问题是:评估不仅仅是数字——它是为了了解你的模型在做什么是正确的,更重要的是,它在哪里出了问题。
4.1 要使用的指标
根据我的经验,选择正确的指标可以成就或破坏你的评估过程。对于涉及文档理解的任务,这些是我依赖的首选指标:
- 字符错误率 (CER):非常适合衡量 OCR 性能。它告诉你模型相对于基本事实有多少个字符是错误的。
- 单词错误率 (WER):当重点关注文本级准确性时,这是我使用的指标。
- BLEU 分数:非常适合结构化文本输出很重要的任务。
以下是计算 CER 的快速代码片段:
def calculate_cer(prediction, ground_truth):
from jiwer import cer
return cer(ground_truth, prediction)
# Example
pred = "Hello Wrold"
gt = "Hello World"
print(f"Character Error Rate: {calculate_cer(pred, gt)}")
4.2 评估管道
在评估经过微调的 Donut 模型时,我喜欢自动化整个管道。
手动对验证集进行测试非常耗时,而且容易出错。这是我使用的管道:
from transformers import pipeline
import json
# Load model
model_path = "./results/checkpoint"
donut_pipeline = pipeline("document-question-answering", model=model_path)
# Load validation dataset
with open("./validation_data.json", "r") as f:
validation_data = json.load(f)
# Evaluate model
results = []
for data in validation_data:
image = data["image"]
ground_truth = data["annotations"]
prediction = donut_pipeline(image)
cer_score = calculate_cer(prediction, ground_truth)
results.append({"prediction": prediction, "cer": cer_score})
# Print average CER
avg_cer = sum(r["cer"] for r in results) / len(results)
print(f"Average CER: {avg_cer}")
4.3 可视化
数字很棒,但没有什么比看到模型在实际示例上的表现更好。当我调试或展示结果时,我会在预测和基本事实之间生成视觉比较。
以下是可视化方法:
import matplotlib.pyplot as plt
from PIL import Image
# Load an image and its annotations
image_path = "./sample_image.jpg"
prediction = "Total: $123.45"
ground_truth = "Total: $123.40"
# Plot the image with annotations
image = Image.open(image_path)
plt.imshow(image)
plt.title(f"Prediction: {prediction}\nGround Truth: {ground_truth}")
plt.axis("off")
plt.show()
我发现这对于向利益相关者解释模型行为非常有用——或者当事情没有按预期进行时,仅向我自己解释!
5、调试和优化
如果说我学到了一件事,那就是:调试模型既是一门艺术,也是一门科学。我参与的每个项目都存在不少问题,下面是我解决这些问题的方法。
5.1 常见的挑战
过拟合:在一个项目中,我的模型在训练集上表现完美,但在验证数据上却惨遭失败。罪魁祸首是什么?过度拟合。
解决方法:使用 dropout、数据增强或提前停止等技术。
欠拟合:当你的模型没有从数据中学到足够的知识时,就会发生这种情况。
解决方法:尝试更深层次的架构或训练更多时期。
嘈杂的数据:嘈杂的标注不止一次破坏了我的训练过程。
解决方法:清理数据集并使用工具标记异常值。
5.2 优化技术
你可能想知道,“微调超参数的最佳方法是什么?”我已经成功使用了动态学习率计划和权重衰减。以下是如何实现这些的:
from transformers import get_scheduler
# Define optimizer and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
scheduler = get_scheduler(
"linear", optimizer=optimizer, num_warmup_steps=500, num_training_steps=10000
)
5.3 错误分析
当出现问题时(这种情况经常发生),我会依靠工具深入挖掘错误。例如,生成混淆矩阵对于了解模型遇到的问题非常有帮助。
以下是生成详细错误日志的方法:
import pandas as pd
# Example error analysis
error_log = [{"image": "img1.jpg", "pred": "Total: $123", "gt": "Total: $124"}]
# Save as CSV
df = pd.DataFrame(error_log)
df.to_csv("error_log.csv", index=False)
print("Error log saved!")
调试和优化不仅仅是流程中的步骤 — — 它们是您真正了解模型的地方。
一旦你解决了常见的挑战、微调了超参数并分析了错误,就会看到你的模型表现得比想象的更好。
接下来,让我们谈谈如何部署经过微调的模型,以便它在现实世界中大放异彩!
6、部署经过微调的模型
这是我逐渐意识到的一个事实:训练模型只是旅程的一半;部署是奇迹发生的地方。
我曾经参与过一些项目,其中经过微调的模型在受控测试中表现良好,但在实际应用中却表现不佳。
这就是为什么这一步至关重要——让我们确保你经过微调的 Donut 模型已准备好部署。
6.1 导出模型
在部署之前,你需要保存经过微调的模型。我经常使用 Hugging Face 的 save_pretrained
方法,因为它可以确保跨各种应用程序的兼容性。
以下是导出模型的代码片段:
from transformers import DonutForDocumentClassification
# Save the fine-tuned model
output_dir = "./fine_tuned_donut"
model.save_pretrained(output_dir)
print(f"Model saved to {output_dir}")
这种方法的优点在于它同时打包了模型权重和配置,使重新加载变得无缝。
6.2 集成到应用程序中
在部署模型时,我发现 FastAPI 是救星。它轻量级、易于设置,并且非常适合为 ML 模型提供服务。这是我为文档分析构建的简单 API:
from fastapi import FastAPI, File, UploadFile
from transformers import pipeline
from PIL import Image
app = FastAPI()
# Load your fine-tuned model
donut_pipeline = pipeline("document-question-answering", model="./fine_tuned_donut")
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
# Read and process the uploaded file
image = Image.open(file.file)
prediction = donut_pipeline(image)
return {"prediction": prediction}
运行此应用程序,你将拥有一个可用的 API 来为您的模型提供服务。在我的一个项目中,此 API 成为一个大型系统的支柱,该系统可自动为客户处理发票。
6.3 推理优化
情况是这样的:部署模型是一回事,但优化模型以进行实时推理又是另一回事。根据我的经验,ONNX 和 TensorRT 等工具可以大大减少延迟。
以下是我将 PyTorch 模型转换为 ONNX 以加快推理速度的方法:
import torch
# Export to ONNX
dummy_input = torch.rand(1, 3, 256, 256) # Replace with your input size
onnx_path = "./donut_model.onnx"
torch.onnx.export(
model,
dummy_input,
onnx_path,
export_params=True,
opset_version=11,
input_names=["input"],
output_names=["output"],
)
print(f"Model exported to {onnx_path}")
通过这种方式,我成功部署了能够以最小延迟处理每分钟数千个请求的模型。
7、高级主题
如果你和我一样,那么你总是在寻找突破界限的方法。以下是我用来将 Donut 模型提升到新水平的一些高级技术。
7.1 自定义分词器
Donut 不依赖于传统的分词器,但有时你可能需要调整其输入处理以适应特定领域的任务。
例如,在一个项目中,我需要模型来处理特定的发票布局。自定义分词器至关重要。
以下是调整 Donut 分词器的示例:
from transformers import AutoTokenizer
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("naver-clova-ix/donut-base")
# Customize tokenizer
special_tokens = ["[CUSTOM_FIELD]", "[ANOTHER_FIELD]"]
tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
model.resize_token_embeddings(len(tokenizer))
print("Tokenizer customized!")
7.2 多模态微调
你可能想知道,“Donut 可以同时处理文本和图像数据吗?”答案是可以的,只要发挥一些创造力。在一次实验中,我将 OCR 输出与表格元数据相结合,以微调 Donut 以完成混合任务。
这是一种简化的方法:
- 分别预处理文本和图像。
- 创建一个将元数据与图像对齐的数据集。
- 通过修改输入管道以接受两者来微调模型。
7.3 迁移学习见解
这可能会让你感到惊讶,但经过微调的 Donut 模型通常可以适应类似的任务,并且只需进行最少的再训练。
例如,我曾经使用在发票上经过微调的模型来处理采购订单。关键是保持数据集的结构和格式相似。
以下是加载预先训练的微调模型并继续训练的方法:
# Load the previously fine-tuned model
model = DonutForDocumentClassification.from_pretrained("./fine_tuned_donut")
# Continue training with new data
trainer = Trainer(
model=model,
args=training_args,
train_dataset=new_dataset["train"],
eval_dataset=new_dataset["validation"],
)
trainer.train()
这种方法在某些项目中为我节省了数周的工作时间。
部署和优化 Donut 模型不仅仅是一个技术步骤 — — 这是你的工作开始对现实世界产生影响的地方。这些高级技术呢?
它们是你保持领先地位的方法。在下一节中,让我们总结一下,并看看一些技巧,以确保你的微调模型取得长期成功!
8、结束语
他们说,“结束只是开始”,我非常同意,尤其是当涉及到像 Donut 这样的微调模型时。
在本指南的过程中,我们经历了整个过程 — — 从设置你的环境到部署微调模型,甚至探索高级技术。但问题是:这只是你的起点。
就我个人而言,使用 Donut 是一次有益的经历,我希望本指南能帮助你取得类似的成功。在你前进的过程中,请记住每个模型、数据集和用例都是不同的。
适用于一个项目的方法可能需要针对另一个项目进行调整。关键是保持好奇心,不断尝试,永远不要满足于“足够好”。
那么,你将使用 Donut 创造什么?可能性无穷无尽,我迫不及待地想看看你的旅程会带您去哪里。让我们开始工作吧!
原文链接:Fine-Tune the Donut Model: A Practical Guide
汇智网翻译整理,转载请标明出处