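Improving SQL Generation by Fine-Tuning an LLM

When an LLM is used for SQL generation, it often struggles to produce accurate SQL queries. Fine-tuning the model with QLoRA on a custom dataset is an effective way to address this.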


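I ran into problems when using the Mistral 7B LLM for SQL generation, especially against my company's databases. These models often fail to generate accurate SQL queries, even when the database schema and table relationships are provided in the context. Fine-tuning the 7B model with QLoRA on a custom dataset tailored to the specific database schema is an effective way to address this.

In this article, I'll walk you through fine-tuning a 7B model to handle SQL generation more effectively. I'll also show how to integrate the fine-tuned model into a LangChain-based application for real-time database interaction.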

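1. Overview

Before diving into the details, here are the key steps this guide follows:

  • Prepare a custom dataset based on your database schema.
  • Fine-tune the 7B model with QLoRA.
  • Evaluate the fine-tuned model's performance.
  • Integrate the model into a LangChain application for SQL-based database interaction.

By the end of this guide, you'll be able to build a question-answering application over a SQL database. It uses a fine-tuned Mistral 7B model that is optimized to generate SQL queries for your specific database schema.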

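2. Preparing Your Custom Dataset

To fine-tune the model effectively, you need a high-quality dataset that reflects your database structure.

Consider a simple customer-management database with three tables: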

CREATE TABLE customer (
    customer_key INT PRIMARY KEY,
    source VARCHAR(50),
    full_name VARCHAR(100),
    created_date DATETIME,
    updated_date DATETIME,
    gender VARCHAR(10),
    dateofbirth DATE
);

CREATE TABLE address (
    address_key INT PRIMARY KEY,
    customer_key INT,
    street_address VARCHAR(200),
    city VARCHAR(100),
    state VARCHAR(50),
    postal_code VARCHAR(20),
    country VARCHAR(50),
    is_primary BOOLEAN,
    created_date DATETIME,
    updated_date DATETIME,
    FOREIGN KEY (customer_key) REFERENCES customer(customer_key)
);

CREATE TABLE contact (
    contact_key INT PRIMARY KEY,
    customer_key INT,
    email VARCHAR(100),
    phone VARCHAR(20),
    created_date DATETIME,
    updated_date DATETIME,
    FOREIGN KEY (customer_key) REFERENCES customer(customer_key)
);
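It can help to load this DDL into a throwaway database before generating samples, so the schema becomes machine-readable. The sketch below is an illustration, not part of the original workflow. It uses Python's built-in sqlite3 module as a stand-in for the real database. It then renders a plain-text column listing in the same style as the schema description used in the generation prompt below. The names SCHEMA_DDL and describe_schema are hypothetical. SQLite accepts the VARCHAR, DATETIME and BOOLEAN type names through its type-affinity rules.

```python
import sqlite3

# The customer-management DDL from above; SQLite maps these type names via type affinity
SCHEMA_DDL = """
CREATE TABLE customer (
    customer_key INT PRIMARY KEY,
    source VARCHAR(50),
    full_name VARCHAR(100),
    created_date DATETIME,
    updated_date DATETIME,
    gender VARCHAR(10),
    dateofbirth DATE
);
CREATE TABLE address (
    address_key INT PRIMARY KEY,
    customer_key INT,
    street_address VARCHAR(200),
    city VARCHAR(100),
    state VARCHAR(50),
    postal_code VARCHAR(20),
    country VARCHAR(50),
    is_primary BOOLEAN,
    created_date DATETIME,
    updated_date DATETIME,
    FOREIGN KEY (customer_key) REFERENCES customer(customer_key)
);
CREATE TABLE contact (
    contact_key INT PRIMARY KEY,
    customer_key INT,
    email VARCHAR(100),
    phone VARCHAR(20),
    created_date DATETIME,
    updated_date DATETIME,
    FOREIGN KEY (customer_key) REFERENCES customer(customer_key)
);
"""


def describe_schema(ddl: str) -> str:
    """Render a plain-text column listing in the style of the generation prompt."""
    conn = sqlite3.connect(":memory:")
    conn.executescript(ddl)
    tables = [row[0] for row in conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY rowid")]
    lines = [f"The database contains {len(tables)} tables: {', '.join(tables)}.", ""]
    for table in tables:
        # foreign_key_list rows: (id, seq, table, from, to, on_update, on_delete, match)
        fks = {row[3]: (row[2], row[4]) for row in conn.execute(f"PRAGMA foreign_key_list({table})")}
        lines.append(f"Table '{table}' has columns:")
        # table_info rows: (cid, name, type, notnull, dflt_value, pk)
        for _, name, col_type, _, _, pk in conn.execute(f"PRAGMA table_info({table})"):
            notes = [col_type]
            if pk:
                notes.append("primary key")
            if name in fks:
                ref_table, ref_column = fks[name]
                notes.append(f"foreign key referencing {ref_table}({ref_column})")
            lines.append(f"{name} ({', '.join(notes)})")
        lines.append("")
    conn.close()
    return "\n".join(lines).rstrip()


if __name__ == "__main__":
    print(describe_schema(SCHEMA_DDL))
```

The output contains lines such as `customer_key (INT, primary key)` and `customer_key (INT, foreign key referencing customer(customer_key))`. You can paste these into the prompt instead of maintaining the description by hand.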

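2.1 Creating Text-to-SQL Samples

To generate the fine-tuning dataset, use Claude Sonnet or another LLM to create text-to-SQL samples. The following prompt guides the LLM in creating SQL queries: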

Sample Format:
{
  "instruction": "I want you to act as a SQL terminal in front of an example database. You need only to return the SQL command to me. Below is an instruction that describes a task. Write a response that appropriately completes the request.
## Instruction:
[Database description]
",
  "input": "### Input:
[Natural language question]
### Response:",
  "output": "[Corresponding SQL query]"
}
Sample Data Point:
    {
      "instruction": "You are a powerful text-to-SQL model. Your task is to generate SQL queries based on the following schema for a customer database:\n\nCREATE TABLE customer (\n    customer_key INT PRIMARY KEY,\n    source VARCHAR(50),\n    full_name VARCHAR(100),\n    created_date DATETIME,\n    updated_date DATETIME,\n    gender VARCHAR(10),\n    dateofbirth DATE\n);\n\nCREATE TABLE address (\n    address_key INT PRIMARY KEY,\n    customer_key INT,\n    street_address VARCHAR(200),\n    city VARCHAR(100),\n    state VARCHAR(50),\n    postal_code VARCHAR(20),\n    country VARCHAR(50),\n    is_primary BOOLEAN,\n    created_date DATETIME,\n    updated_date DATETIME,\n    FOREIGN KEY (customer_key) REFERENCES customer(customer_key)\n);\n\nCREATE TABLE contact (\n    contact_key INT PRIMARY KEY,\n    customer_key INT,\n    email VARCHAR(100),\n    phone VARCHAR(20),\n    created_date DATETIME,\n    updated_date DATETIME,\n    FOREIGN KEY (customer_key) REFERENCES customer(customer_key)\n);",
      "input": "List the full names of customers who have both an email and a phone number.",
      "output": "SELECT DISTINCT c.full_name\nFROM customer c\nJOIN contact ct ON c.customer_key = ct.customer_key\nWHERE ct.email IS NOT NULL AND ct.phone IS NOT NULL;"
    }
The database contains three tables: customer, address, and contact.

Table 'customer' has columns:
customer_key (INT, primary key)
source (VARCHAR(50))
full_name (VARCHAR(100))
created_date (DATETIME)
updated_date (DATETIME)
gender (VARCHAR(10))
dateofbirth (DATE)

Table 'address' has columns:
address_key (INT, primary key)
customer_key (INT, foreign key referencing customer(customer_key))
street_address (VARCHAR(200))
city (VARCHAR(100))
state (VARCHAR(50))
postal_code (VARCHAR(20))
country (VARCHAR(50))
is_primary (BOOLEAN)
created_date (DATETIME)
updated_date (DATETIME)

Table 'contact' has columns:
contact_key (INT, primary key)
customer_key (INT, foreign key referencing customer(customer_key))
email (VARCHAR(100))
phone (VARCHAR(20))
created_date (DATETIME)
updated_date (DATETIME)

Please generate 100 samples in a JSON file based on the provided database schema and example. For each sample, ensure that:
Instruction: Includes only the table definitions needed by the SQL query in the output.
Input: Contains a natural language question about the data.
Output: Provides the corresponding SQL query that answers the question.
The questions should cover topics such as data analysis, aggregation, address searches, customer searches, contact searches, and reporting.

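Repeat this process to generate roughly 200 to 500 samples covering a variety of SQL tasks, including customer queries, address lookups, and data aggregation. Save the samples in JSON format. Then split them into train.json for training and eval.json for evaluation.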
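Here is a minimal sketch of the merge-and-split step. It assumes each generation batch was saved as its own JSON array; the file names passed to write_splits are hypothetical. The sketch first drops samples whose question repeats an earlier one. It then checks that every sample has non-empty instruction, input and output fields, and shuffles with a fixed seed so the split is reproducible. The 10% evaluation fraction is an assumption. Note that the LitGPT command in section 3 also holds out --data.val_split_fraction of train.json for validation.

```python
import json
import random

REQUIRED_KEYS = ("instruction", "input", "output")


def deduplicate(samples):
    """Keep the first sample for each question (case- and whitespace-insensitive)."""
    seen, unique = set(), []
    for sample in samples:
        key = " ".join(str(sample.get("input", "")).lower().split())
        if key not in seen:
            seen.add(key)
            unique.append(sample)
    return unique


def split_samples(samples, eval_fraction=0.1, seed=42):
    """Validate, shuffle deterministically, and split into (train, eval) lists."""
    for i, sample in enumerate(samples):
        missing = [key for key in REQUIRED_KEYS if not sample.get(key)]
        if missing:
            raise ValueError(f"sample {i} is missing or has empty fields: {missing}")
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)
    n_eval = max(1, round(len(shuffled) * eval_fraction))
    return shuffled[n_eval:], shuffled[:n_eval]


def write_splits(batch_files, train_path="train.json", eval_path="eval.json", eval_fraction=0.1):
    """Merge generated batch files and write the train and eval JSON files."""
    samples = []
    for path in batch_files:
        with open(path, encoding="utf-8") as f:
            samples.extend(json.load(f))
    train, evaluation = split_samples(deduplicate(samples), eval_fraction)
    for path, data in ((train_path, train), (eval_path, evaluation)):
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
    return len(train), len(evaluation)
```

For example, `write_splits(["batch_1.json", "batch_2.json"])` returns the number of training and evaluation samples it wrote.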

Dataset example:

[
  {
    "instruction": "You are a powerful text-to-SQL model. Your job is to answer questions about the customer database based on the provided SCHEMA.\nYou must output the SQL query that answers the question. SCHEMA:\nCREATE TABLE customer (\n    customer_key INT PRIMARY KEY,\n    full_name VARCHAR(100),\n    dateofbirth DATE\n);",
    "input": "List the full names of all customers born after January 1, 1990.",
    "output": "SELECT full_name FROM customer WHERE dateofbirth > '1990-01-01';"
  },
  {
    "instruction": "You are a powerful text-to-SQL model. Your job is to answer questions about the address data based on the provided SCHEMA.\nYou must output the SQL query that answers the question. SCHEMA:\nCREATE TABLE address (\n    address_key INT PRIMARY KEY,\n    customer_key INT,\n    city VARCHAR(100)\n);",
    "input": "Find all unique cities where customers reside.",
    "output": "SELECT DISTINCT city FROM address;"
  },
  {
    "instruction": "You are a powerful text-to-SQL model. Your job is to answer questions about customer emails based on the provided SCHEMA.\nYou must output the SQL query that answers the question. SCHEMA:\nCREATE TABLE customer (\n    customer_key INT PRIMARY KEY\n);\n\nCREATE TABLE contact (\n    contact_key INT PRIMARY KEY,\n    customer_key INT,\n    email VARCHAR(100)\n);",
    "input": "Find customers who have not provided an email address.",
    "output": "SELECT c.customer_key FROM customer c LEFT JOIN contact ct ON c.customer_key = ct.customer_key WHERE ct.email IS NULL;"
  },
  {
    "instruction": "You are a powerful text-to-SQL model. Your job is to answer questions about customers without addresses based on the provided SCHEMA.\nYou must output the SQL query that answers the question. SCHEMA:\nCREATE TABLE customer (\n    customer_key INT PRIMARY KEY\n);\n\nCREATE TABLE address (\n    address_key INT PRIMARY KEY,\n    customer_key INT\n);",
    "input": "List customer keys of customers who have no address on file.",
    "output": "SELECT c.customer_key FROM customer c LEFT JOIN address a ON c.customer_key = a.customer_key WHERE a.customer_key IS NULL;"
  },
...
]
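Each instruction embeds the CREATE TABLE statements its query should use. That makes it cheap to catch hallucinated tables or columns before training, along with instructions that omit a table the query needs. The sketch below is an illustration under stated assumptions, not part of the original workflow. It loads each sample's embedded DDL into an in-memory SQLite database, then asks SQLite to compile (EXPLAIN) the output query without running it. SQLite is only a stand-in for the production database. Queries that rely on PostgreSQL-specific functions will therefore be reported as failures and need a manual look. check_sample and check_file are hypothetical names.

```python
import json
import sqlite3
from typing import Optional


def check_sample(sample: dict) -> Optional[str]:
    """Return None if the output SQL compiles against the instruction's schema, else an error message."""
    instruction = sample["instruction"]
    start = instruction.find("CREATE TABLE")
    if start == -1:
        return "instruction contains no CREATE TABLE statement"
    conn = sqlite3.connect(":memory:")
    try:
        # Build only the tables that the instruction gives to the model
        conn.executescript(instruction[start:])
        # EXPLAIN compiles the statement: unknown tables/columns raise, but no rows are read
        conn.execute("EXPLAIN " + sample["output"].strip().rstrip(";"))
    except (sqlite3.Error, sqlite3.Warning) as exc:
        return str(exc)
    finally:
        conn.close()
    return None


def check_file(path: str) -> list:
    """Return (index, question, error) for every sample whose SQL does not compile."""
    with open(path, encoding="utf-8") as f:
        samples = json.load(f)
    failures = []
    for i, sample in enumerate(samples):
        error = check_sample(sample)
        if error:
            failures.append((i, sample["input"], error))
    return failures
```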

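3. Fine-Tuning the Model

With the dataset ready, the next step is to fine-tune the Mistral 7B model with QLoRA. There are several options for fine-tuning:

  • Hugging Face TRL and SFTTrainer
  • DB-GPT-Hub: a good choice for fine-tuning models on public text-to-SQL datasets such as Spider.
  • LitGPT: a lightweight, efficient framework for quickly fine-tuning, pretraining, and deploying LLMs.

For our use case, we'll use LitGPT to fine-tune the 7B model on the customer database schema.

Install LitGPT: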

pip install litgpt

Download the model weights:

litgpt download mistralai/Mistral-7B-Instruct-v0.3 --access_token=xxxxxx

Run the fine-tuning with 4-bit NF4 quantization (QLoRA):

litgpt finetune_lora \
    checkpoints/mistralai/Mistral-7B-Instruct-v0.3 \
    --data JSON \
    --data.json_path train.json \
    --out_dir finetuned \
    --precision bf16-true \
    --quantize "bnb.nf4" \
    --lora_r 8 \
    --lora_alpha 16 \
    --lora_dropout 0.05 \
    --train.global_batch_size 4 \
    --train.micro_batch_size 1 \
    --train.max_steps 1000 \
    --train.save_interval 200 \
    --eval.interval 50 \
    --train.lr_warmup_steps 100 \
    --train.max_seq_length 2048 \
    --optimizer.learning_rate 2e-4 \
    --optimizer.weight_decay 0.01 \
    --optimizer.betas 0.9 \
    --data.val_split_fraction 0.1
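The arithmetic below works through what these settings mean as a training budget. It assumes, as I understand LitGPT's training arguments, that --train.max_steps counts optimizer steps. It also assumes gradients are accumulated over global_batch_size / (devices * micro_batch_size) micro-batches. The training-set size of 360 samples is a hypothetical example: 400 samples in train.json minus the 10% validation split.

```python
def training_budget(global_batch_size, micro_batch_size, max_steps, num_train_samples, devices=1):
    """Return (gradient accumulation steps, samples seen, approximate epochs)."""
    accumulation_steps = global_batch_size // (devices * micro_batch_size)
    samples_seen = max_steps * global_batch_size
    epochs = samples_seen / num_train_samples
    return accumulation_steps, samples_seen, epochs


accum, seen, epochs = training_budget(
    global_batch_size=4, micro_batch_size=1, max_steps=1000, num_train_samples=360
)
print(f"accumulation steps: {accum}, samples seen: {seen}, epochs: {epochs:.1f}")
```

Under these assumptions, the model sees each training example roughly 11 times. With a dataset of only a few hundred samples, watch the validation loss reported every --eval.interval steps for signs of overfitting.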

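For more detailed instructions, see the official LitGPT documentation.

4. Evaluating the Fine-Tuned Model

After fine-tuning, it's essential to evaluate the model's performance. We'll use the token match score as the metric:

Use the eval.json file created earlier, which contains the held-out question/SQL pairs.

Convert the merged weights from LitGPT (finetuned/final/lit_model.pth) to the Hugging Face Transformers format:

litgpt convert_from_litgpt finetuned/final out/hf-mistral-7b/converted

Run the evaluation script: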

import json
import re

import sqlparse
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer

login(token="XXXXXXX")  # Hugging Face access token, needed if the model repo is gated

# Load the merged fine-tuned weights converted from LitGPT into the base architecture
model_path = "out/hf-mistral-7b/converted"
base_model_id = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
state_dict = torch.load(f"{model_path}/model.pth", map_location="cpu")
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    state_dict=state_dict,
    torch_dtype=torch.bfloat16,  # bf16 roughly halves GPU memory compared with fp32
)
model.to("cuda:0")

# Load the evaluation split created in section 2
with open("eval.json", "r", encoding="utf-8") as f:
    eval_data = json.load(f)


def normalize_sql(sql):
    # Remove single-line comments
    sql = re.sub(r"--.*$", "", sql, flags=re.MULTILINE)
    # Collapse runs of whitespace into single spaces
    sql = " ".join(sql.split())
    parsed = sqlparse.parse(sql)
    # Lowercase so keyword and identifier case do not affect the comparison
    return str(parsed[0]).lower() if parsed else ""


def exact_match_score(prediction, reference):
    return normalize_sql(prediction) == normalize_sql(reference)


def token_match_score(prediction, reference):
    # Fraction of distinct reference tokens that also appear in the prediction
    pred_tokens = set(re.findall(r"\b\w+\b", normalize_sql(prediction)))
    ref_tokens = set(re.findall(r"\b\w+\b", normalize_sql(reference)))
    return len(pred_tokens & ref_tokens) / len(ref_tokens) if ref_tokens else 0


def evaluate_model(model, tokenizer, eval_data):
    exact_matches = []
    token_match_scores = []

    for item in eval_data:
        instruction = item["instruction"]  # task description and schema
        input_question = item.get("input", "")  # natural-language question
        expected_output = item["output"]  # reference SQL

        messages = [
            {"role": "system", "content": instruction},
            {"role": "user", "content": input_question},
        ]

        model_inputs = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to("cuda:0")

        # Greedy decoding keeps the evaluation deterministic
        outputs = model.generate(
            model_inputs,
            max_new_tokens=150,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens, not the prompt
        sql_query = tokenizer.batch_decode(
            outputs[:, model_inputs.shape[1]:], skip_special_tokens=True
        )[0].strip()

        print(f"Instruction: {instruction}")
        print(f"Input Question: {input_question}")
        print(f"Expected Output: {expected_output}")
        print(f"Generated Output: {sql_query}")
        print("=" * 50)

        # Compute metrics
        exact_matches.append(exact_match_score(sql_query, expected_output))
        token_match_scores.append(token_match_score(sql_query, expected_output))

    print(f"Exact Match Rate: {sum(exact_matches) / len(exact_matches):.4f}")
    print(f"Average Token Match Score: {sum(token_match_scores) / len(token_match_scores):.4f}")


evaluate_model(model, tokenizer, eval_data)

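This evaluation gives you a concrete measure of how well the model performs on text-to-SQL tasks.

In my run, the fine-tuned model reached an average token match score of 0.8786, which is quite good for a text-to-SQL model.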
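The token match score is easy to misread, so it helps to compute it by hand on a small case. The sketch below re-implements the two helper functions from the evaluation script without sqlparse. For a single statement, converting the parsed statement back to a string should return the original text, so the result should be the same. The example queries are made up. The score is the fraction of distinct reference tokens that also appear in the prediction. It therefore ignores token order and any extra tokens in the prediction.

```python
import re


def normalize_sql(sql):
    # Same steps as the evaluation script, minus the sqlparse round-trip
    sql = re.sub(r"--.*$", "", sql, flags=re.MULTILINE)
    return " ".join(sql.split()).lower()


def token_match_score(prediction, reference):
    pred_tokens = set(re.findall(r"\b\w+\b", normalize_sql(prediction)))
    ref_tokens = set(re.findall(r"\b\w+\b", normalize_sql(reference)))
    return len(pred_tokens & ref_tokens) / len(ref_tokens) if ref_tokens else 0


reference = "SELECT full_name FROM customer WHERE dateofbirth > '1990-01-01';"
# Reference tokens: select, full_name, from, customer, where, dateofbirth, 1990, 01
print(token_match_score("SELECT full_name FROM customer;", reference))  # prints 0.5
print(token_match_score("SELECT full_name FROM customer WHERE dateofbirth < '1990-01-01';", reference))  # prints 1.0
```

The second prediction flips `>` to `<` and still scores a perfect 1.0, because the comparison operator is not a `\w` token. A high token match score is therefore a useful signal, but not proof that the generated SQL returns the right rows.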

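5. Building a Database Interaction RAG App with LangChain

With the model fine-tuned and evaluated, we can now integrate it into a LangChain application for database interaction.

First, install llama-cpp-python (the Python bindings for llama.cpp) with CUDA support. The command below uses Windows cmd syntax:

set FORCE_CMAKE=1 && set CMAKE_ARGS=-DGGML_CUDA=on && pip install llama-cpp-python --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu124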
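The convert_hf_to_gguf.py script expects a Hugging Face model directory containing config.json, weight files and tokenizer files. The folder produced by litgpt convert_from_litgpt may hold only the model.pth state dict. In that case, one way to build such a directory is to load the state dict into the base model and save it with save_pretrained, as in this sketch. export_hf_model and the output path are hypothetical. The heavy imports live inside the function, so the snippet can be imported on a machine without torch.

```python
def export_hf_model(base_model_id: str, state_dict_path: str, out_dir: str) -> None:
    """Write a Hugging Face model directory from a converted LitGPT state dict."""
    # Imported lazily so that defining this helper does not require torch/transformers
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    state_dict = torch.load(state_dict_path, map_location="cpu")
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id, state_dict=state_dict, torch_dtype=torch.bfloat16
    )
    model.save_pretrained(out_dir)  # config.json and weight files
    AutoTokenizer.from_pretrained(base_model_id).save_pretrained(out_dir)  # tokenizer files
```

For example, you could call `export_hf_model("mistralai/Mistral-7B-Instruct-v0.3", "out/hf-mistral-7b/converted/model.pth", "out/hf-mistral-7b/hf-model")`. Then pass out/hf-mistral-7b/hf-model to convert_hf_to_gguf.py.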

Convert the Hugging Face model to GGUF format with the convert_hf_to_gguf.py script from the llama.cpp repository:

python convert_hf_to_gguf.py /path/to/hf-model --outfile custom-mistral-7b.gguf --outtype f16

Optionally, quantize the model to reduce its size:

./llama-quantize ./custom-mistral-7b.gguf ./custom-mistral-7b-Q5_K_M.gguf Q5_K_M

I recommend Q5_K_M because it retains most of the model's quality. If you want to save more memory, you can choose Q4_K_M instead.
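A back-of-the-envelope estimate shows what the choice between quantization types means for memory: file size ≈ parameter count * average bits per weight / 8. The parameter count (about 7.24 billion for Mistral 7B) and the average bits per weight for each format are approximations I am assuming; they are not values from this article. Real files also carry metadata, so actual sizes will differ somewhat.

```python
# Approximate values (assumptions, not measurements)
NUM_PARAMS = 7.24e9
AVG_BITS_PER_WEIGHT = {"f16": 16.0, "Q5_K_M": 5.7, "Q4_K_M": 4.85}


def gguf_size_gb(num_params: float, bits_per_weight: float) -> float:
    """Estimated file size in gigabytes (10**9 bytes)."""
    return num_params * bits_per_weight / 8 / 1e9


for name, bpw in AVG_BITS_PER_WEIGHT.items():
    print(f"{name:>7}: ~{gguf_size_gb(NUM_PARAMS, bpw):.1f} GB")
```

Under these assumptions, Q5_K_M comes to roughly a third of the f16 file, and Q4_K_M saves a further 0.8 GB or so. At runtime the model also needs memory for the KV cache, which grows with n_ctx.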

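Install LangChain and the other required libraries. The code below imports from the langchain, langchain-community and langchain-core packages, and the PostgreSQL connection URL also needs a driver such as psycopg2.

Implement the LangChain SQL chain: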

import logging
import re

from langchain.chains import create_sql_query_chain
from langchain_community.llms import LlamaCpp
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

logging.basicConfig(level=logging.INFO)

# Initialize the LlamaCpp language model with the fine-tuned, quantized GGUF file
llm = LlamaCpp(
    model_path="./custom-mistral-7b-Q5_K_M.gguf",  # Path to the model file
    max_tokens=2048,  # Maximum number of tokens in the response
    n_ctx=6144,  # Context size
    verbose=True,
    temperature=0,  # Deterministic output
)

# Define the database schema name and SQLAlchemy connection URL
dataset = "customer"
sqlalchemy_url = (
    "postgresql://db_user:db_pass@db_host:5432/db_name"  # Replace with actual credentials, host and database
)

# Initialize the SQLDatabase object with the specified schema and tables
db = SQLDatabase.from_uri(
    sqlalchemy_url,
    schema=dataset,
    include_tables=["customer", "address", "contact"],  # Tables exposed to the model
)

# Create the SQL query generation chain using the language model and database
gen_query = create_sql_query_chain(llm, db)


# Convert datetime reprs in the query result string to SQL-style date strings
def convert_dates(obj: str) -> str:
    # datetime.date(Y, M, D) -> 'Y-M-D'
    response_str = re.sub(
        r"datetime\.date\((\d+),\s*(\d+),\s*(\d+)\)",
        r"'\1-\2-\3'",
        obj,
    )
    # datetime.datetime(Y, M, D, h, m, s[, us]) -> 'Y-M-D h:m:s.us'
    response_str = re.sub(
        r"datetime\.datetime\((\d+),\s*(\d+),\s*(\d+),\s*(\d+),\s*(\d+),\s*(\d+)(?:,\s*(\d+))?\)",
        r"'\1-\2-\3 \4:\5:\6.\7'",
        response_str,
    )
    return response_str


# Route the generated query: unknown questions short-circuit, others go to the DB chain
def route(sql_query: dict):
    logging.info(f"Routing query: {sql_query}")  # Log the query for debugging
    if sql_query.get("query") == "I don't know":
        logging.warning("Unknown query detected.")
        return sql_query  # Pass the input through unchanged
    return db_opt_chain  # A returned Runnable is invoked with the same input


# Handle cases where the model responds with "I don't know"
def handle_dont_know(result):
    if isinstance(result, dict) and result.get("query") == "I don't know":
        return "I can only provide information related to our customer data."  # Custom response
    return result  # Otherwise return the result unchanged


# Execute the SQL query and add the fields that the answer prompt needs
def custom_execute_query_runnable(result: dict) -> dict:
    return {
        **result,  # question and query from the previous steps
        "schema": db.get_table_info(),  # fills {schema}
        # fills {response}; run_no_throw returns errors as text instead of raising
        "response": convert_dates(
            db.run_no_throw(command=result["query"], include_columns=True)
        ),
    }


# Template for generating the final natural language response
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""

prompt_response = ChatPromptTemplate.from_template(template)

# Database operation chain: execute the query, fill the prompt, let the LLM phrase the answer
db_opt_chain = (
    RunnableLambda(custom_execute_query_runnable)
    | prompt_response
    | llm
)

# Full chain: generate SQL, route it, handle "I don't know", return a plain string
full_chain = (
    RunnablePassthrough.assign(query=gen_query)  # Adds the generated SQL under "query"
    | RunnableLambda(route)
    | RunnableLambda(handle_dont_know)
    | StrOutputParser()
)

# Example user question to invoke the chain
user_question = "Get the minimum and maximum age of customers"
print(full_chain.invoke({"question": user_question}))

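Note: The example code in this guide is for reference only. Adapt it to your specific use case, for example by changing the database connection string and schema definitions. You may also need to tune the model parameters to your data and infrastructure.

This implementation builds a chain that uses the fine-tuned LLM to generate SQL and executes the query against the database. It then uses the LLM again to interpret and summarize the results.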
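Here is a framework-free sketch of the same control flow. It generates SQL, routes an "I don't know" answer to a fixed reply, and executes the query without raising. It then hands the question, query and result to a second model call. Hypothetical stub functions replace the two model calls, and an in-memory SQLite table replaces PostgreSQL, so this only illustrates the structure. In the LangChain version, create_sql_query_chain, db.run_no_throw and the answer prompt play these roles.

```python
import sqlite3

FALLBACK = "I can only provide information related to our customer data."


def fake_generate_sql(question: str) -> str:
    """Hypothetical stand-in for the fine-tuned model's text-to-SQL step."""
    if "age" in question.lower():
        # Approximate age in whole years (ignores month and day)
        age = "strftime('%Y', 'now') - strftime('%Y', dateofbirth)"
        return f"SELECT MIN({age}) AS min_age, MAX({age}) AS max_age FROM customer"
    return "I don't know"


def fake_summarize(question: str, query: str, rows) -> str:
    """Hypothetical stand-in for the second LLM call that phrases the answer."""
    return f"Q: {question} | SQL: {query} | Result: {rows}"


def answer(question, conn, generate_sql=fake_generate_sql, summarize=fake_summarize):
    query = generate_sql(question)  # step 1: text-to-SQL
    if query.strip() == "I don't know":  # step 2: route unanswerable questions
        return FALLBACK
    try:  # step 3: execute, returning errors as text (like run_no_throw)
        rows = conn.execute(query).fetchall()
    except sqlite3.Error as exc:
        rows = f"Error: {exc}"
    return summarize(question, query, rows)  # step 4: summarize


if __name__ == "__main__":
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE customer (customer_key INT PRIMARY KEY, full_name VARCHAR(100), dateofbirth DATE)")
    conn.executemany("INSERT INTO customer VALUES (?, ?, ?)",
                     [(1, "Ada", "1985-03-07"), (2, "Lin", "1999-11-30")])
    print(answer("Get the minimum and maximum age of customers", conn))
    print(answer("What's the weather today?", conn))
```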

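6. Conclusion

Fine-tuning a 7B LLM on a custom dataset tailored to your database schema can substantially improve its SQL generation. Combined with LangChain, this lets you build question-answering applications for database interaction, even with a relatively small language model.

Keep in mind that a fine-tuned model is only as effective as the quality and diversity of its training data. Continuously refining the dataset and the fine-tuning process over time will lead to better results.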


Original article: Improving Text-to-SQL with a Fine-Tuned 7B LLM for DB Interactions

Translated and edited by 汇智网. Please credit the source when reposting.