Unsloth:大模型微调利器
这是我经常看到的问题——随着 RAG(检索增强生成)越来越受欢迎,为什么还要进行微调呢?虽然 RAG 对于许多用例来说都非常棒,但微调在 ML 工具包中仍然占有一席之地。
原因如下:微调允许你从根本上改变模型对特定领域的“思考”。虽然 RAG 在推理时提供上下文,但微调将领域专业知识直接构建到模型的权重中。这在你需要以下情况时尤其有用:
- 一致的领域特定行为
- 更快的推理(无需搜索外部文档)
- 参考文档中难以捕捉的专业知识
此外,还有一个令人信服的成本论据:你可以针对特定任务微调较小的模型,并以托管成本的一小部分实现与大得多的模型相当的性能。
我发现 Reddit 上的这个讨论很棒。看看吧!
1、Unsloth:让微调变得触手可及
训练时间一直是微调的最大障碍之一。这就是 Unsloth 的用武之地——一个新的优化框架,声称可以使 LLM 训练速度提高 30 倍。
Unsloth 效率的秘诀在于深度优化。虽然 PyTorch 和 Transformers 是为跨不同架构的灵活性而构建的,但 Unsloth 采取了更专注的方法。它将 QLoRA 和 Triton 等技术与特定于架构的优化相结合,以从训练过程中获得最大性能。
2、动手实战:SQL 生成模型的微调
让我们通过微调模型来生成 SQL 查询来将其付诸实践。我们将使用 Llama-3.2–3B,这是一个 30 亿参数模型,在能力和资源需求之间取得了良好的平衡。
首先,找到一个好的数据集来微调模型非常重要,找到正确的数据集如此重要的原因是,当你用与手头任务相关的数据训练一个小型语言模型时,它实际上可以胜过更大的模型。我们的目标是创建一个小型、快速的 LLM,它根据表数据生成 SQL 查询。
为此目的最重要的数据集之一称为 Synthetic Text to SQL,它包含超过 105,000 条记录,分为提示 SQL 内容、复杂性等列。
这是数据集的链接:Synthetic Text to SQL 。
3、设置环境
首先,让我们安装必要的软件包。我们需要小心管理我们的 PyTorch 安装:
%%capture
!pip install pip3-autoremove
!pip-autoremove torch torchvision torchaudio -y
!pip install "torch==2.4.0" "xformers==0.0.27.post2" triton torchvision torchaudio
!pip install "unsloth[kaggle-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install datasets
4、加载模型
现在我们将使用 Unsloth 的优化加载器加载我们的基础模型:
from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/Llama-3.2-3B-bnb-4bit",
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
)
5、设置 PEFT
我们将加载 PEFT(参数高效微调)模型,该模型使用 LoRA(低秩自适应)适配器。
如果不熟悉这些术语,请不要担心。LoRA 适配器允许我们在微调期间仅更新 1-10% 的模型参数。如果没有它们,我们就需要重新训练整个模型,这将耗费更多时间、计算量更大、成本更高。
Unsloth 提供了这些推荐设置以获得最佳性能。虽然我们将在本教程中使用它们的默认配置,但你可以根据需要随意探索和调整这些参数。
model = FastLanguageModel.get_peft_model(
model,
r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",],
lora_alpha = 16,
lora_dropout = 0, # Supports any, but = 0 is optimized
bias = "none", # Supports any, but = "none" is optimized
use_gradient_checkpointing = "unsloth", # 4x longer contexts auto supported!
random_state = 3407,
use_rslora = False, # We support rank stabilized LoRA
loftq_config = None, # And LoftQ
)
6、数据
现在,事情可能会变得有点复杂有点棘手,具体取决于你使用的数据集。每个数据集都不同,但它们的格式相同,以便大型语言模型可以理解它。Llama3.2 使用alpaca ( 羊驼 )提示,如下所示:
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
[The task or question you want the model to perform/answer]
### Input:
[Additional context or information needed to complete the task. This can be empty if the instruction is self-contained]
### Response:
[The expected output or answer you want the model to learn]
对于我们的 SQL 数据库项目,我们特别关注三个组件:
- SQL 查询提示
- 生成的 SQL 代码
- 代码说明
from datasets import Dataset, load_dataset
# Define the prompt template with variables matching the loop content
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{input}
### Response:
{response}"""
# Set the EOS token (assuming the tokenizer is already defined)
EOS_TOKEN = tokenizer.eos_token
# Formatting function to apply the prompt template to the dataset
def formatting_prompts_func(examples):
company_databases = examples["sql_context"]
prompts = examples["sql_prompt"]
sqls = examples["sql"]
explanations = examples["sql_explanation"]
texts = []
for company_database, prompt, sql, explanation in zip(company_databases, prompts, sqls, explanations):
# Substitute the correct placeholders
text = alpaca_prompt.format(
instruction=prompt,
input=company_database,
response=sql + " " + explanation
) + EOS_TOKEN
texts.append(text)
return {"text": texts} # Ensure the formatted text is returned as a "text" field
# Load dataset and map formatting function to add prompts
ds = load_dataset("gretelai/synthetic_text_to_sql")
formatted_ds = ds.map(formatting_prompts_func, batched=True) # Apply formatting
# Select the 'train' split from the formatted dataset
train_dataset = formatted_ds['train']
7、训练配置
有很多参数可以使用,所有这些都可以描述。例如,有最大步数,它告诉我们要执行多少个训练步骤。种子是一个随机数生成器。我们过去能够重现结果,而热身步骤会随着时间的推移逐渐提高学习率。
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
from trl import SFTTrainer
# Trainer setup
trainer = SFTTrainer(
model=model, # Ensure model is defined
tokenizer=tokenizer, # Ensure tokenizer is defined
train_dataset=train_dataset, # Use the 'train' split from formatted_ds
dataset_text_field="text", # This is the field we created with formatted prompts
max_seq_length=max_seq_length, # Ensure max_seq_length is defined
dataset_num_proc=2,
packing=False, # Can make training 5x faster for short sequences.
args=TrainingArguments(
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
warmup_steps=5,
max_steps=60,
learning_rate=2e-4,
fp16=not is_bfloat16_supported(),
bf16=is_bfloat16_supported(),
logging_steps=1,
optim="adamw_8bit",
weight_decay=0.01,
lr_scheduler_type="linear",
seed=3407,
output_dir="outputs",
report_to="none", # Disable WANDB logging
)
)
现在我们已经设置好了一切,让我们运行它。就是这样。
trainer_stats = trainer.train()
原文链接:Fine-Tune LLMs with Unsloth
汇智网翻译整理,转载请标明出处