Llama-3.1微调实现函数调用

MODEL ZOO Nov 10, 2024

微调大型语言模型 (LLM) 已成为将预训练模型适应特定任务和领域的关键。虽然预训练模型提供了令人印象深刻的通用功能,但微调使我们能够针对特殊用例(如翻译或函数调用)、特定领域的任务(如法律或金融)以及自定义应用程序对其进行优化。使用 LLM 的主要挑战之一是微调所需的计算资源。传统的微调方法需要大量的 GPU 内存和计算时间。这就是 LoRA(低秩自适应)等技术和 Unsloth 等框架发挥作用的地方,它们使该过程更加高效和易于访问。

在本综合指南中,我们将探讨如何使用 Unsloth(一种专为优化和微调大型语言模型而设计的专用工具包)微调 Llama-3.1–8B 模型以实现函数调用功能。我们将利用 LoRA 进行有效的参数更新,集成W&B 进行实验跟踪,然后使用 vLLM 进行高性能模型推理和服务。

1、为什么函数调用对语言模型很重要

语言模型中的函数调用使 AI 能够直接与外部系统交互并自主执行实际任务。通过集成外部函数和 API,开发人员可以构建不仅仅是生成文本的应用程序——它们还可以解决特定问题、检索信息并根据用户输入执行操作。

小型语言模型 (SLM) 中的函数调用特别有价值,因为它允许这些模型在更易于访问的硬件上有效地处理函数调用任务。在这里,函数调用功能可以实现以下任务:

  • 将自然语言转换为 API 调用或生成有效的数据库查询。
  • 构建与实时数据交互的对话知识检索系统。

与大型模型不同,经过适当微调后,SLM 可以用更少的资源实现这些交互。像 LoRA 这样的技术允许我们添加特定于任务的功能而无需大量内存,从而使 SLM 适用于实时应用程序。

1、工具和技术概述

Unsloth

针对 LLM 微调的优化框架,提供:

  • 训练速度提高 30 倍,内存使用量减少 60%
  • 支持多种硬件设置(NVIDIA、AMD 和 Intel GPU)
  • 智能权重优化技术,提高内存效率
  • 与流行的微调方法(如 Flash-Attention 2)集成
  • 与主要 LLM(Mistral、Llama、Gemma)兼容
  • 在本地 GPU 和 Google Colab 上高效运行
LoRA(低秩自适应)

一种参数高效的微调方法,通过向现有权重添加小型可训练秩分解矩阵来减少内存需求,从而实现特定于任务的自适应而无需修改所有模型参数。

Weights & Bias(W&B)

用于跟踪训练指标、可视化性能和管理实验的监控平台。

vLLM

一个开源库,引入了 PagedAttention 和连续批处理等创新技术,以优化内存使用并提高吞吐量,从而实现高效的 LLM 服务和推理优化。

2、设置W & B

为了监控和记录模型的微调过程,我们首先配置 W&B。以下函数处理身份验证并使用指定的项目和运行名称初始化新运行。

import os
import wandb
from dotenv import load_dotenv
load_dotenv()

def setup_wandb(project_name: str, run_name: str):
    # Set up your API KEY
    try:
        api_key = os.getenv("WANDB_API_KEY")
        wandb.login(key=api_key)
        print("Successfully logged into WandB.")
    except KeyError:
        raise EnvironmentError("WANDB_API_KEY is not set in the environment variables.")
    except Exception as e:
        print(f"Error logging into WandB: {e}")
    
    # Optional: Log models
    os.environ["WANDB_LOG_MODEL"] = "checkpoint"
    os.environ["WANDB_WATCH"] = "all"
    os.environ["WANDB_SILENT"] = "true"
    
    # Initialize the WandB run
    try:
        wandb.init(project=project_name, name=run_name)
        print(f"WandB run initialized: Project - {project_name}, Run - {run_name}")
    except Exception as e:
        print(f"Error initializing WandB run: {e}")


setup_wandb(project_name="<project_name>", run_name="<run_name>")

3、HuggingFace 身份验证

为了下载Salesforce 函数调用数据集并随后上传经过微调的模型,我们首先需要通过从环境变量中安全加载和验证 Hugging Face 令牌来验证我们对 Hugging Face Hub 的访问权限。

from huggingface_hub import login

hf_token = os.getenv("HUGGINGFACE_TOKEN")
if hf_token is None:
    raise EnvironmentError("HUGGINGFACE_TOKEN is not set in the environment variables.")
login(hf_token)

4、加载基础模型

此设置加载 Llama-3.1–8B-Instruct 模型及其使用 Unsloth 的 FastLanguageModel 的分词器,具有:

  • 可配置序列长度 ( max_seq_length=2048) 指定模型可以处理的最大输入序列长度,通常称为模型的上下文长度。
  • 自动 dtype 检测 ( dtype=None) 可实现灵活优化。
  • 4 位量化选项 ( load_in_4bit=False),允许默认精度以获得更高的模型保真度。
import torch
from unsloth import FastLanguageModel

max_seq_length = 2048     # Unsloth auto supports RoPE Scaling internally!
dtype = None              # None for auto detection
load_in_4bit = False      # Use 4bit quantization to reduce memory usage. Can be False.

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Meta-Llama-3.1-8B-Instruct",  
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

5、配置 LoRA 以实现高效微调

此代码使用 LoRA 配置 Llama-3.1 模型以实现参数高效微调 (PEFT),从而实现对选定层进行高效微调,同时通过仅调整特定组件而不是整个模型来最大限度地减少内存使用并加速训练。以下是每个关键参数的细分:

  • r=16:设置 LoRA 矩阵的秩,平衡模型性能和内存使用。
  • target_modules:识别“q_proj”和“k_proj”等层,以进行有针对性的微调。
  • lora_alpha=16:控制缩放因子以避免过度拟合。
  • lora_dropout=0:将 dropout 设置为零,以进行一致的训练。
  • use_gradient_checkpointing="unsloth":最大限度地减少内存使用,尤其是对于较长的上下文长度。
  • bias="none”:忽略额外的偏差项
  • random_state=3407:确保可重复的训练运行。
  • use_rslora=False:禁用对等级敏感的 LoRA,针对标准、不太复杂的任务进行优化。
  • loftq_config=None:禁用 LoftQ,否则它会使用高级初始化来提高准确性,但开始时会占用更多内存。

此设置允许在保持模型性能的同时进行更高效的资源微调。

model = FastLanguageModel.get_peft_model(
    model,
    r=16,   # LoRA rank - suggested values: 8, 16, 32, 64, 128
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", 
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,   # Supports any, but = 0 is optimized
    bias="none",      # Supports any, but = "none" is optimized
    use_gradient_checkpointing="unsloth",  # Ideal for long context tuning
    random_state=3407,
    use_rslora=False,   # Disable rank-sensitive LoRA for simpler tasks
    loftq_config=None   # No LoftQ, for standard fine-tuning
)

6、加载和处理数据集

对于我们的训练,我们将使用 Salesforce/xlam-function-calling-60k,它是专门为函数调用任务设计的。

要开始微调,我们将从数据集中可管理的 15K 个样本子集开始,而不是使用整个数据集。这使我们能够尽早评估模型的性能并更有效地进行调整。10-20K 的样本量达到了良好的平衡:它足够大,可以产生有意义的见解,同时保持内存和训练时间要求合理。

from datasets import load_dataset

# Loading the dataset
dataset = load_dataset("Salesforce/xlam-function-calling-60k", split="train", token=hf_token)

# Selecting a subset of 15K samples for fine-tuning
dataset = dataset.select(range(15000))
print(f"Using a sample size of {len(dataset)} for fine-tuning.")

使用 Unsloth 的聊天模板,我们将原始数据转换为与模型兼容的标记。此步骤标准化了函数调用提示,使模型能够以结构化的方式理解和预测输出。

from unsloth.chat_templates import get_chat_template

# Initialize the tokenizer with the chat template and mapping
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "llama-3", 
    mapping = {"role" : "from", "content" : "value", "user" : "human", "assistant" : "gpt"}, # ShareGPT style
    map_eos_token = True,        # Maps <|im_end|> to <|eot_id|> instead
)

def formatting_prompts_func(examples):
    convos = []
    
    # Iterate through each item in the batch (examples are structured as lists of values)
    for query, tools, answers in zip(examples['query'], examples['tools'], examples['answers']):
        tool_user = {
            "content": f"You are a helpful assistant with access to the following tools or function calls. Your task is to produce a sequence of tools or function calls necessary to generate response to the user utterance. Use the following tools or function calls as required:\n{tools}",
            "role": "system"
        }
        ques_user = {
            "content": f"{query}",
            "role": "user"
        }
        assistant = {
            "content": f"{answers}",
            "role": "assistant"
        }
        convos.append([tool_user, ques_user, assistant])

    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
    return {"text": texts}

# Apply the formatting on dataset
dataset = dataset.map(formatting_prompts_func, batched = True,)

7、定义训练参数

TrainingArguments 设置定义了用于微调模型的超参数和日志配置,有助于通过控制良好的步骤保持高效的训练。每个参数在优化模型行为和有效监控进度方面都发挥着作用。

from transformers import TrainingArguments

args = TrainingArguments(
        per_device_train_batch_size = 8,  # Controls the batch size per device
        gradient_accumulation_steps = 2,  # Accumulates gradients to simulate a larger batch
        warmup_steps = 5,
        learning_rate = 2e-4,             # Sets the learning rate for optimization
        num_train_epochs = 3,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        optim = "adamw_8bit",
        weight_decay = 0.01,              # Regularization term for preventing overfitting
        lr_scheduler_type = "linear",     # Chooses a linear learning rate decay
        seed = 3407,                        
        output_dir = "outputs",             
        report_to = "wandb",              # Enables Weights & Biases (W&B) logging
        logging_steps = 1,                # Sets frequency of logging to W&B
        logging_strategy = "steps",       # Logs metrics at each specified step
        save_strategy = "no",               
        load_best_model_at_end = True,    # Loads the best model at the end
        save_only_model = False           # Saves entire model, not only weights
    )

8、使用 SFTTrainer 和 Unsloth 进行训练

SFTTrainer 配置为使用自定义标记、数据集预处理和内存优化进行监督微调。与 unsloth_train 的组合将允许 unsloth 优化梯度检查点,这对于处理长序列和减少内存使用至关重要。

from trl import SFTTrainer

trainer = SFTTrainer(
    model = model,
    processing_class = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False,        # Can make training 5x faster for short sequences.
    args = args
)

此代码在训练开始时捕获初始 GPU 内存统计数据。

# Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

现在我们已经完成了设置,让我们开始训练我们的模型。

from unsloth import unsloth_train

trainer_stats = unsloth_train(trainer)  
print(trainer_stats)
wandb.finish()
wandb 结果

训练过程结束后,以下代码检查并比较最终内存使用情况,捕获专门用于 LoRA 训练的内存并计算内存百分比。

# Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory         /max_memory*100, 3)
lora_percentage = round(used_memory_for_lora/max_memory*100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

我们可以在权重和偏差上可视化训练指标和系统指标,例如内存使用率、训练时长、训练损失和准确率,以更好地了解我们的模型随时间的性能。

模型训练指标仪表板
系统使用详情

9、保存和部署模型

训练后,经过微调的模型将保存在本地并推送到 Hugging Face 的中心以供进一步访问和部署。但是,这只会保存 LoRA 适配器。

# Local saving
model.save_pretrained("<lora_model_name>") 
tokenizer.save_pretrained("<lora_model_name>")

# Online saving
model.push_to_hub("<hf_username/lora_model_name>", token = hf_token)
tokenizer.push_to_hub("<hf_username/lora_model_name>", token = hf_token) 

要将 LoRA 适配器与基础模型合并,并将模型保存为 16 位精度,以优化 vLLM 性能,请使用如下代码合并为 16 位:

# Merge to 16bit
model.save_pretrained_merged("<model_name>", tokenizer, save_method = "merged_16bit",)
model.push_to_hub_merged("<hf_username/model_name>", tokenizer, save_method = "merged_16bit", token = hf_token)

10、微调模型评估

我们首先加载微调模型,该模型可以保存在本地磁盘上,也可以从 Hugging Face 检索,以及标记器。

# Local saving
model.save_pretrained("<lora_model_name>") 
tokenizer.save_pretrained("<lora_model_name>")

# Online saving
model.push_to_hub("<hf_username/lora_model_name>", token = hf_token)
tokenizer.push_to_hub("<hf_username/lora_model_name>", token = hf_token) 

我们现在可以定义专为数据检索而设计的实用函数,以增强用户体验。在本教程中,我们将使用以下函数用于演示目的:

  • get_current_date:返回当前日期,格式为“YYYY-MM-DD”。
  • get_current_weather:使用 OpenWeatherMap API 检索指定位置的天气数据。
  • celsius_to_fahrenheit:将温度从摄氏度转换为华氏度。
  • get_nasa_picture_of_the_day:获取有关 NASA 每日图像的详细信息。
  • get_stock_price:使用来自 Alpha Vantage 的数据提供指定股票代码和日期的股票价格。
import re
import json
import requests
from datetime import datetime
import nasapy


WEATHER_API_KEY = os.getenv("WEATHER_API_KEY")
NASA_API_KEY = os.getenv("NASA_API_KEY")
STOCK_API_KEY = os.getenv("STOCK_API_KEY")


def get_current_date() -> str:
    """
    Fetches the current date in the format YYYY-MM-DD.
    Returns:
        str: A string representing the current date.
    """
    print("Getting the current date")
    
    try:
        current_date = datetime.now().strftime("%Y-%m-%d")
        return current_date
    except Exception as e:
        print(f"Error fetching current date: {e}")
        return "NA"
    
    
def get_current_weather(location: str) -> dict:
    """
    Fetches the current weather for a given location (default: San Francisco).
    Args:
        location (str): The name of the city for which to retrieve the weather information.
    Returns:
        dict: A dictionary containing weather information such as temperature, weather description, and humidity.
    """
    print(f"Getting current weather for {location}")
    
    try:
        weather_url = f"http://api.openweathermap.org/data/2.5/weather?q={location}&appid={WEATHER_API_KEY}&units=metric"
        weather_data = requests.get(weather_url)
        data = weather_data.json()
        weather_description = data["weather"][0]["description"]
        temperature = data["main"]["temp"]
        humidity = data["main"]["humidity"]
        return {
            "description": weather_description,
            "temperature": temperature,
            "humidity": humidity
        }
    except Exception as e:
        print(f"Error fetching weather data: {e}")
        return {"weather": "NA"}
    
    
def celsius_to_fahrenheit(celsius: float) -> float:
    """
    Converts a temperature from Celsius to Fahrenheit.
    
    Args:
        celsius (float): Temperature in degrees Celsius.
        
    Returns:
        float: Temperature in degrees Fahrenheit.
    """
    print(f"Converting {celsius}°C to Fahrenheit")
    
    try:
        fahrenheit = (celsius * 9/5) + 32
        return fahrenheit
    except Exception as e:
        print(f"Error converting temperature: {e}")
        return None
    
    
def get_nasa_picture_of_the_day(date: str) -> dict:
    """
    Fetches NASA's Picture of the Day information for a given date.
    
    Args:
        date (str): The date for which to retrieve the picture in 'YYYY-MM-DD' format.
        
    Returns:
        dict: A dictionary containing the title, explanation, and URL of the image or video.
    """
    print(f"Getting NASA's Picture of the Day for {date}")
    
    try:
        nasa = nasapy.Nasa(key = NASA_API_KEY)
        apod = nasa.picture_of_the_day(date = date, hd=True)
        title = apod.get("title", "No Title")
        explanation = apod.get("explanation", "No Explanation")
        url = apod.get("url", "No URL")
        return {
            "title": title,
            "explanation": explanation,
            "url": url
        }
    except Exception as e:
        print(f"Error fetching NASA's Picture of the Day: {e}")
        return {"error": "Unable to fetch NASA Picture of the Day"}
    
    
def get_stock_price(ticker: str, date: str) -> tuple[str, str]:
    """
    Retrieves the lowest and highest stock prices for a given ticker and date.
    Args:
        ticker (str): The stock ticker symbol, e.g., "IBM".
        date (str): The date in "YYYY-MM-DD" format for which you want to get stock prices.
    Returns:
        tuple: A tuple containing the low and high stock prices on the given date, or ("none", "none") if not found.
    """
    print(f"Getting stock price for {ticker} on {date}")
    try:
        stock_url = f"https://www.alphavantage.co/query?function=TIME_SERIES_DAILY&symbol={ticker}&apikey={STOCK_API_KEY}"
        stock_data = requests.get(stock_url)
        stock_low = stock_data.json()["Time Series (Daily)"][date]["3. low"]
        stock_high = stock_data.json()["Time Series (Daily)"][date]["2. high"]
        return stock_low, stock_high
    except Exception as e:
        print(f"Error fetching stock data: {e}")
        return "none", "none"
    
    
available_function_calls = {"get_current_date": get_current_date, "get_current_weather": get_current_weather, "celsius_to_fahrenheit": celsius_to_fahrenheit,
                      "get_nasa_picture_of_the_day": get_nasa_picture_of_the_day, "get_stock_price": get_stock_price}

接下来,我们将创建一个可用函数列表及其函数定义。以下代码定义了可用函数的元数据,包括它们的名称、描述和必需参数。这对于将函数集成到类似聊天的界面中至关重要,在该界面中,模型可以根据用户查询了解要调用哪些函数。

functions = [
    {
        "name": "get_current_date",
        "description": "Fetches the current date in the format YYYY-MM-DD.",
        "parameters": {
            "type": "object",
            "properties": {},
            "required": [],
        },
    },
    {
        "name": "get_current_weather",
        "description": "Get the current weather",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and country code, e.g. San Francisco, US",
                }
            },
            "required": ["location"],
        },
    },
    {
        "name": "celsius_to_fahrenheit",
        "description": "Converts a temperature from Celsius to Fahrenheit.",
        "parameters": {
            "type": "object",
            "properties": {
                "celsius": {
                    "type": "number",
                    "description": "Temperature in degrees Celsius.",
                }
            },
            "required": ["celsius"],
        }
    },
    {
        "name": "get_nasa_picture_of_the_day",
        "description": "Fetches NASA's Picture of the Day information for a given date.",
        "parameters": {
            "type": "object",
            "properties": {
                "date": {
                    "type": "string",
                    "description": "Date in YYYY-MM-DD format for which to retrieve the picture.",
                }
            },
            "required": ["date"],
        },
    },
    {
        "name": "get_stock_price",
        "description": "Retrieves the lowest and highest stock price for a given ticker symbol and date. The ticker symbol must be a valid symbol for a publicly traded company on a major US stock exchange like NYSE or NASDAQ. The tool will return the latest trade price in USD.",
        "parameters": {
            "type": "object",
            "properties": {
                "ticker": {
                    "type": "string",
                    "description": "The stock ticker symbol, e.g. AAPL for Apple Inc.",
                },
                "date": {
                    "type": "string",
                    "description": "Date in YYYY-MM-DD format",
                }
            },
            "required": ["ticker", "date"],
        },
    }
]


available_tools_list = {
    "functions_str": [json.dumps(x) for x in functions],
}

在此代码段中,我们指定用户问题并定义一组结构化的聊天消息,以及可用函数的 JSON 列表,供模型处理的标记器聊天模板中使用。

query = "What is the current weather at the headquarters of IBM? Also, can you provide the stock prices for the company on October 29, 2024?"

chat = [
    {"role":"system","content": f"You are a helpful assistant with access to the following function calls. Your task is to produce a sequence of function calls necessary to generate response to the user utterance. Use the following function calls as required.\n{available_tools_list}"},
    {"role": "user", "content": query }
]

然后,模型根据用户查询的意图确定适当的函数调用,如生成的响应中的函数调用名称所示。

inputs = tokenizer.apply_chat_template(
    chat,
    tokenize = True,
    add_generation_prompt = True, # Must add for generation
    return_tensors = "pt",
).to("cuda")

outputs = model.generate(input_ids = inputs, max_new_tokens = 1024, use_cache = True)
response = tokenizer.batch_decode(outputs)[0]
print(response)

我们还可以利用 TextStreamer类来流式传输生成的文本输出,从而实现实时响应流式传输。

text_streamer = TextStreamer(tokenizer, skip_prompt = True)
_ = model.generate(input_ids = inputs, streamer = text_streamer, max_new_tokens = 512, use_cache = True, pad_token_id = tokenizer.eos_token_id)
流式响应

现在我们的模型已经告诉我们要调用哪些函数以及使用哪些参数,我们可以执行它们并将其输出传回 LLM,以便它可以生成最终答案返回给用户。

为了有效地执行这个函数,我们将从模型的输出中提取相关参数,确保我们拥有无缝执行所需的所有必要细节。这种方法使我们能够根据用户输入动态利用所选函数,从而增强整体交互体验。

def extract_content(text):
    # Define the regex pattern to extract the content
    pattern = r"<\|start_header_id\|>assistant<\|end_header_id\|>(.*?)<\|eot_id\|>"
    match = re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(1).strip()
    return None  

parsed_response = json.loads(extract_content(response))
print(parsed_response)
函数调用结果

以下代码通过执行必要的函数调用来处理解析后的响应,以根据用户查询收集信息。每个函数调用的结果都会附加到聊天历史记录中,系统消息也会更新以反映其当前状态。然后再次提示模型根据从函数调用中收集的信息生成最终响应。

if parsed_response:
    new_system_content = "You are a helpful assistant. Answer the user query based on the response of the specific function call or tool provided to you as context. Generate a precise answer for given user query, synthesizing the provided information."
    
    for res in parsed_response:
        obtained_function = res.get("name")
        arguments = res.get("arguments")
        function_description = next(item['description'] for item in functions if item['name'] == obtained_function)
        function_to_call = available_function_calls[obtained_function]
        response = function_to_call(**arguments)
        print(response)
        
        chat.append({
            "role": "tool",
            "content": f"The tool - '{obtained_function}' with the function definition - '{function_description}' and function arguments -'{arguments}' yielded the following response: {response}\n."
        })

        for message in chat:
            if message['role'] == 'system':
                message['content'] = new_system_content
                
    inputs = tokenizer.apply_chat_template(
        chat,
        tokenize = True,
        add_generation_prompt = True, 
        return_tensors = "pt").to("cuda")
    text_streamer = TextStreamer(tokenizer, skip_prompt = True)
    _ = model.generate(input_ids = inputs, streamer = text_streamer, max_new_tokens = 512, use_cache = True, pad_token_id = tokenizer.eos_token_id)
else:
    print("No function call found in the response")
最终 LLM 答案


11、设置 vLLM 进行快速推理

此代码通过加载我们保存的模型并使用指定的参数对其进行初始化,配置 vLLM 框架以实现高吞吐量和内存高效的推理。

from vllm import LLM
from vllm.sampling_params import SamplingParams

model_name = "<hf_username/model_name>"
sampling_params = SamplingParams(max_tokens=768)

llm = LLM(
    model=model_name,
    max_model_len=2048,
    tokenizer_mode="auto",
    tensor_parallel_size=1,
    enforce_eager=True,
    gpu_memory_utilization=0.95
)

llm_prompt = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant with access to the following function calls. Your task is to produce a sequence of function calls necessary to generate response to the user utterance. Use the following function calls as required.
{available_tools_list}<|eot_id|><|start_header_id|>user<|end_header_id|>

{query}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""

input_prompt = llm_prompt.format(available_tools_list=available_tools_list, query=query)

output = llm.generate([input_prompt], sampling_params)
generated_text = output[0].outputs[0].text
print(f"Generated text: {generated_text!r}")
使用 vLLM 引擎运行离线推理

12、结束语

使用 Unsloth 和 LoRA 对 Llama-3.1–8B 模型进行微调,可以有效适应自定义域和特定任务,同时优化资源使用率。使用 LoRA 等技术不仅可以提高微调效率,还可以减少内存消耗,使其成为各种应用的实用选择。

此外,通过结合权重和偏差 (WandB) 进行实验跟踪,你可以简化工作流程并获得有关微调过程的宝贵见解。利用 vLLM 可确保高吞吐量和内存高效的模型服务,从而在实际场景中实现稳健的性能。

通过遵循本指南中概述的策略,你可以成功地微调模型以满足您的独特需求,最终以最少的资源获得高质量的结果。


原文链接:Fine-Tuning Llama-3.1-8B for Function Calling using LoRA

汇智网翻译整理,转载请标明出处

Tags