Unsloth:大模型微调利器

这是我经常看到的问题——随着 RAG(检索增强生成)越来越受欢迎,为什么还要进行微调呢?虽然 RAG 对于许多用例来说都非常棒,但微调在 ML 工具包中仍然占有一席之地。

原因如下:微调允许你从根本上改变模型对特定领域的“思考”。虽然 RAG 在推理时提供上下文,但微调将领域专业知识直接构建到模型的权重中。这在你需要以下情况时尤其有用:

  • 一致的领域特定行为
  • 更快的推理(无需搜索外部文档)
  • 参考文档中难以捕捉的专业知识

此外,还有一个令人信服的成本论据:你可以针对特定任务微调较小的模型,并以托管成本的一小部分实现与大得多的模型相当的性能。

我发现 Reddit 上的这个讨论很棒。看看吧!

1、Unsloth:让微调变得触手可及

训练时间一直是微调的最大障碍之一。这就是 Unsloth 的用武之地——一个新的优化框架,声称可以使 LLM 训练速度提高 30 倍。

Unsloth 效率的秘诀在于深度优化。虽然 PyTorch 和 Transformers 是为跨不同架构的灵活性而构建的,但 Unsloth 采取了更专注的方法。它将 QLoRA 和 Triton 等技术与特定于架构的优化相结合,以从训练过程中获得最大性能。

2、动手实战:SQL 生成模型的微调

让我们通过微调模型来生成 SQL 查询来将其付诸实践。我们将使用 Llama-3.2–3B,这是一个 30 亿参数模型,在能力和资源需求之间取得了良好的平衡。

首先,找到一个好的数据集来微调模型非常重要,找到正确的数据集如此重要的原因是,当你用与手头任务相关的数据训练一个小型语言模型时,它实际上可以胜过更大的模型。我们的目标是创建一个小型、快速的 LLM,它根据表数据生成 SQL 查询。

为此目的最重要的数据集之一称为 Synthetic Text to SQL,它包含超过 105,000 条记录,分为提示 SQL 内容、复杂性等列。

这是数据集的链接:Synthetic Text to SQL

3、设置环境

首先,让我们安装必要的软件包。我们需要小心管理我们的 PyTorch 安装:

%%capture
!pip install pip3-autoremove
!pip-autoremove torch torchvision torchaudio -y
!pip install "torch==2.4.0" "xformers==0.0.27.post2" triton torchvision torchaudio
!pip install "unsloth[kaggle-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install datasets

4、加载模型

现在我们将使用 Unsloth 的优化加载器加载我们的基础模型:

from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.


model, tokenizer = FastLanguageModel.from_pretrained(
 model_name = "unsloth/Llama-3.2-3B-bnb-4bit",
 max_seq_length = max_seq_length,
 dtype = dtype,
 load_in_4bit = load_in_4bit,
)

5、设置 PEFT

我们将加载 PEFT(参数高效微调)模型,该模型使用 LoRA(低秩自适应)适配器。

如果不熟悉这些术语,请不要担心。LoRA 适配器允许我们在微调期间仅更新 1-10% 的模型参数。如果没有它们,我们就需要重新训练整个模型,这将耗费更多时间、计算量更大、成本更高。

Unsloth 提供了这些推荐设置以获得最佳性能。虽然我们将在本教程中使用它们的默认配置,但你可以根据需要随意探索和调整这些参数。

model = FastLanguageModel.get_peft_model(
 model,
 r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
 target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
 "gate_proj", "up_proj", "down_proj",],
 lora_alpha = 16,
 lora_dropout = 0, # Supports any, but = 0 is optimized
 bias = "none", # Supports any, but = "none" is optimized
 use_gradient_checkpointing = "unsloth", # 4x longer contexts auto supported!
 random_state = 3407,
 use_rslora = False, # We support rank stabilized LoRA
 loftq_config = None, # And LoftQ
)

6、数据

现在,事情可能会变得有点复杂有点棘手,具体取决于你使用的数据集。每个数据集都不同,但它们的格式相同,以便大型语言模型可以理解它。Llama3.2 使用alpaca ( 羊驼 )提示,如下所示:

Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
[The task or question you want the model to perform/answer]

### Input:
[Additional context or information needed to complete the task. This can be empty if the instruction is self-contained]

### Response:
[The expected output or answer you want the model to learn]

对于我们的 SQL 数据库项目,我们特别关注三个组件:

  • SQL 查询提示
  • 生成的 SQL 代码
  • 代码说明
from datasets import Dataset, load_dataset

# Define the prompt template with variables matching the loop content
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{input}
### Response:
{response}"""

# Set the EOS token (assuming the tokenizer is already defined)
EOS_TOKEN = tokenizer.eos_token

# Formatting function to apply the prompt template to the dataset
def formatting_prompts_func(examples):
 company_databases = examples["sql_context"]
 prompts = examples["sql_prompt"]
 sqls = examples["sql"]
 explanations = examples["sql_explanation"]
 texts = []
 for company_database, prompt, sql, explanation in zip(company_databases, prompts, sqls, explanations):
 # Substitute the correct placeholders
 text = alpaca_prompt.format(
 instruction=prompt,
 input=company_database,
 response=sql + " " + explanation
 ) + EOS_TOKEN
 texts.append(text)
 return {"text": texts} # Ensure the formatted text is returned as a "text" field

# Load dataset and map formatting function to add prompts
ds = load_dataset("gretelai/synthetic_text_to_sql")
formatted_ds = ds.map(formatting_prompts_func, batched=True) # Apply formatting

# Select the 'train' split from the formatted dataset
train_dataset = formatted_ds['train']

7、训练配置

有很多参数可以使用,所有这些都可以描述。例如,有最大步数,它告诉我们要执行多少个训练步骤。种子是一个随机数生成器。我们过去能够重现结果,而热身步骤会随着时间的推移逐渐提高学习率。

from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
from trl import SFTTrainer

# Trainer setup
trainer = SFTTrainer(
 model=model, # Ensure model is defined
 tokenizer=tokenizer, # Ensure tokenizer is defined
 train_dataset=train_dataset, # Use the 'train' split from formatted_ds
 dataset_text_field="text", # This is the field we created with formatted prompts
 max_seq_length=max_seq_length, # Ensure max_seq_length is defined
 dataset_num_proc=2,
 packing=False, # Can make training 5x faster for short sequences.
 args=TrainingArguments(
 per_device_train_batch_size=2,
 gradient_accumulation_steps=4,
 warmup_steps=5,
 max_steps=60,
 learning_rate=2e-4,
 fp16=not is_bfloat16_supported(),
 bf16=is_bfloat16_supported(),
 logging_steps=1,
 optim="adamw_8bit",
 weight_decay=0.01,
 lr_scheduler_type="linear",
 seed=3407,
 output_dir="outputs",
 report_to="none", # Disable WANDB logging
 )
)

现在我们已经设置好了一切,让我们运行它。就是这样。

trainer_stats = trainer.train()

原文链接:Fine-Tune LLMs with Unsloth

汇智网翻译整理,转载请标明出处