大模型微调改善SQL生成

在使用Mistral 7B LLM 执行 SQL 生成任务时,我遇到了挑战,尤其是在处理公司的数据库时。这些模型通常难以生成准确的 SQL 查询,即使在上下文中提供了数据库架构和表关系。为了应对这一挑战,使用 QLoRA 在针对特定数据库架构定制的自定义数据集上微调 7B 模型是一种有效的方法。

在本文中,我将引导你完成微调 7B 模型以更有效地处理 SQL 生成任务的过程,以及如何将微调后的模型集成到基于 LangChain 的应用程序中以实现实时数据库交互。

1、概述

在深入研究细节之前,让我们概述一下我们将在本指南中遵循的关键步骤:

  • 根据数据库架构准备自定义数据集。
  • 使用 QLoRA 技术微调 7B 模型。
  • 评估微调模型的性能。
  • 将模型集成到 LangChain 应用程序中以进行基于 SQL 的数据库交互。

通过遵循本指南,你将能够使用经过微调的 Mistral 7B 模型为 SQL 数据库构建问答应用程序,该模型针对根据你的特定数据库架构生成 SQL 查询进行了优化。

2、准备你的自定义数据集

要有效地微调模型,你需要一个能够反映数据库结构的高质量数据集。

让我们考虑一个包含3个表的简单客户管理数据库:

CREATE TABLE customer (
    customer_key INT PRIMARY KEY,
    source VARCHAR(50),
    full_name VARCHAR(100),
    created_date DATETIME,
    updated_date DATETIME,
    gender VARCHAR(10),
    dateofbirth DATE
);

CREATE TABLE address (
    address_key INT PRIMARY KEY,
    customer_key INT,
    street_address VARCHAR(200),
    city VARCHAR(100),
    state VARCHAR(50),
    postal_code VARCHAR(20),
    country VARCHAR(50),
    is_primary BOOLEAN,
    created_date DATETIME,
    updated_date DATETIME,
    FOREIGN KEY (customer_key) REFERENCES customer(customer_key)
);

CREATE TABLE contact (
    contact_key INT PRIMARY KEY,
    customer_key INT,
    email VARCHAR(100),
    phone VARCHAR(20),
    created_date DATETIME,
    updated_date DATETIME,
    FOREIGN KEY (customer_key) REFERENCES customer(customer_key)
);

2.1 创建文本到 SQL 样本

要生成用于微调的数据集,请使用 Claude Sonnet 或其他 LLM 创建文本到 SQL 的样本。以下是可用于指导模型创建 SQL 查询的提示格式:

Sample Format:
{
  "instruction": "I want you to act as a SQL terminal in front of an example database. You need only to return the SQL command to me. Below is an instruction that describes a task. Write a response that appropriately completes the request.
## Instruction:
[Database description]
",
  "input": "### Input:
[Natural language question]
### Response:",
  "output": "[Corresponding SQL query]"
}
Sample Data Point:
    {
      "instruction": "You are a powerful text-to-SQL model. Your task is to generate SQL queries based on the following schema for a customer database:\n\nCREATE TABLE customer (\n    customer_key INT PRIMARY KEY,\n    source VARCHAR(50),\n    full_name VARCHAR(100),\n    created_date DATETIME,\n    updated_date DATETIME,\n    gender VARCHAR(10),\n    dateofbirth DATE\n);\n\nCREATE TABLE address (\n    address_key INT PRIMARY KEY,\n    customer_key INT,\n    street_address VARCHAR(200),\n    city VARCHAR(100),\n    state VARCHAR(50),\n    postal_code VARCHAR(20),\n    country VARCHAR(50),\n    is_primary BOOLEAN,\n    created_date DATETIME,\n    updated_date DATETIME,\n    FOREIGN KEY (customer_key) REFERENCES customer(customer_key)\n);\n\nCREATE TABLE contact (\n    contact_key INT PRIMARY KEY,\n    customer_key INT,\n    email VARCHAR(100),\n    phone VARCHAR(20),\n    created_date DATETIME,\n    updated_date DATETIME,\n    FOREIGN KEY (customer_key) REFERENCES customer(customer_key)\n);",
      "input": "List the full names of customers who have both an email and a phone number.",
      "output": "SELECT DISTINCT c.full_name\nFROM customer c\nJOIN contact ct ON c.customer_key = ct.customer_key\nWHERE ct.email IS NOT NULL AND ct.phone IS NOT NULL;"
    }
The database contains three tables: customer, address, and contact.

Table 'customer' has columns:
customer_key (INT, primary key)
source (VARCHAR(50))
full_name (VARCHAR(100))
created_date (DATETIME)
updated_date (DATETIME)
gender (VARCHAR(10))
dateofbirth (DATE)

Table 'address' has columns:
address_key (INT, primary key)
customer_key (INT, foreign key referencing customer(customer_key))
street_address (VARCHAR(200))
city (VARCHAR(100))
state (VARCHAR(50))
postal_code (VARCHAR(20))
country (VARCHAR(50))
is_primary (BOOLEAN)
created_date (DATETIME)
updated_date (DATETIME)

Table 'contact' has columns:
contact_key (INT, primary key)
customer_key (INT, foreign key referencing customer(customer_key))
email (VARCHAR(100))
phone (VARCHAR(20))
created_date (DATETIME)
updated_date (DATETIME)

Please generate 100 samples in JSON file based on the provided database schema and example. For each sample, ensure that:
Instruction: include only the necessary table definitions in the instruction based on the SQL query in the output
Input: Contains a natural language question about the data.
Output: Provides the corresponding SQL query that answers the question. 
The questions should cover topics such as data analysis, aggregation, address searches, customer searches, contact searches, and reporting.

重复此过程,生成约 200 到 500 个样本,涵盖各种 SQL 任务,包括客户查询、地址查找和数据聚合。这些样本应以 JSON 格式保存,并分别拆分为 train.json 和 eval.json 以进行训练和评估。

数据集示例:

[
  {
    "instruction": "You are a powerful text-to-SQL model. Your job is to answer questions about the customer database based on the provided SCHEMA.\nYou must output the SQL query that answers the question. SCHEMA:\nCREATE TABLE customer (\n    customer_key INT PRIMARY KEY,\n    full_name VARCHAR(100),\n    dateofbirth DATE\n);",
    "input": "List the full names of all customers born after January 1, 1990.",
    "output": "SELECT full_name FROM customer WHERE dateofbirth > '1990-01-01';"
  },
  {
    "instruction": "You are a powerful text-to-SQL model. Your job is to answer questions about the address data based on the provided SCHEMA.\nYou must output the SQL query that answers the question. SCHEMA:\nCREATE TABLE address (\n    address_key INT PRIMARY KEY,\n    customer_key INT,\n    city VARCHAR(100)\n);",
    "input": "Find all unique cities where customers reside.",
    "output": "SELECT DISTINCT city FROM address;"
  },
    {
    "instruction": "You are a powerful text-to-SQL model. Your job is to answer questions about customer emails based on the provided SCHEMA.\nYou must output the SQL query that answers the question. SCHEMA:\nCREATE TABLE customer (\n    customer_key INT PRIMARY KEY\n);\n\nCREATE TABLE contact (\n    contact_key INT PRIMARY KEY,\n    customer_key INT,\n    email VARCHAR(100)\n);",
    "input": "Find customers who have not provided an email address.",
    "output": "SELECT c.customer_key FROM customer c LEFT JOIN contact ct ON c.customer_key = ct.customer_key WHERE ct.email IS NULL;"
  },
  {
    "instruction": "You are a powerful text-to-SQL model. Your job is to answer questions about customers without addresses based on the provided SCHEMA.\nYou must output the SQL query that answers the question. SCHEMA:\nCREATE TABLE customer (\n    customer_key INT PRIMARY KEY\n);\n\nCREATE TABLE address (\n    address_key INT PRIMARY KEY,\n    customer_key INT\n);",
    "input": "List customer keys of customers who have no address on file.",
    "output": "SELECT c.customer_key FROM customer c LEFT JOIN address a ON c.customer_key = a.customer_key WHERE a.customer_key IS NULL;"
  },
...
]

3、微调模型

数据集准备就绪后,下一步是使用 QLoRA 微调Mistral 7B 模型。有几种方法可用于微调模型:

  • Hugging Face TRL 和 SFTTrainer
  • DB-GPT-Hub:在公共文本到 SQL 数据集(如 Spider)上微调模型的不错选择。
  • LitGPT:用于快速微调、预训练和部署 LLM 的轻量级高效框架。

对于我们的用例,我们将使用 LitGPT 在客户数据库架构上微调我们的 7B 模型。

安装 LitGPT:

pip install litgpt

下载模型权重:

litgpt download mistralai/Mistral-7B-Instruct-v0.3 --access_token=xxxxxx

使用 4 位量化运行微调过程:

litgpt finetune_lora \
    checkpoints/mistralai/Mistral-7B-Instruct-v0.3 \
    --data JSON \
    --data.json_path train.json \
    --out_dir finetuned \
    --precision bf16-true \
    --quantize "bnb.nf4" \
    --lora_r 8 \
    --lora_alpha 16 \
    --lora_dropout 0.05 \
    --train.global_batch_size 4 \
    --train.micro_batch_size 1 \
    --train.max_steps 1000 \
    --train.save_interval 200 \
    --eval.interval 50 \
    --train.lr_warmup_steps 100 \
    --train.max_seq_length 2048 \
    --optimizer.learning_rate 2e-4 \
    --optimizer.weight_decay 0.01 \
    --optimizer.betas 0.9 \
    --data.val_split_fraction 0.1

有关更详细的说明,请参阅官方 LitGPT 文档

4、评估微调模型

微调后,评估模型的性能至关重要。我们将使用 Token Match 分数指标来实现此目的:

使用前面创建的 evolve.json 文件,其中包含示例 SQL 查询。

将合并后的权重从 LitGPT ( /finetuned/final/lit_model.pth) 转换为 Hugging Face Transformers 格式。

litgpt convert_from_litgpt finetuned/final out/hf-mistral-7b/converted

运行评估脚本:

import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
import re
import sqlparse
from sklearn.metrics import accuracy_score

login(token="XXXXXXX")

# Load the fine-tuned LoRA model
model_path = "out/hf-mistral-7b/converted"
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
state_dict = torch.load(f"{model_path}/model.pth")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", state_dict=state_dict)

model.to("cuda:0")

# Load the evaluation dataset
with open("evaluate.json", "r") as f:
    eval_data = json.load(f)

def normalize_sql(sql):
    # Remove comments
    sql = re.sub(r'--.*$', '', sql, flags=re.MULTILINE)

    sql = ' '.join(sql.split())

    parsed = sqlparse.parse(sql)[0]
    return str(parsed).lower() 

def exact_match_score(prediction, reference):
    return normalize_sql(prediction) == normalize_sql(reference)

def token_match_score(prediction, reference):
    pred_tokens = set(re.findall(r'\b\w+\b', normalize_sql(prediction)))
    ref_tokens = set(re.findall(r'\b\w+\b', normalize_sql(reference)))
    return len(pred_tokens.intersection(ref_tokens)) / len(ref_tokens) if ref_tokens else 0

def evaluate_model(model, tokenizer, eval_data):
    exact_matches = []
    token_match_scores = []

    for item in eval_data:
        instruction = item["instruction"]   
        input_question = item.get("input", "")   
        expected_output = item["output"]  

        messages = [
            {"role": "system", "content": instruction},
            {"role": "user", "content": input_question},
        ]

        encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt", padding_side='left')

        model_inputs = encodeds.to("cuda:0")

        outputs = model.generate(model_inputs, max_new_tokens=150)
        
        decoded = tokenizer.batch_decode(outputs[:, model_inputs.shape[1]:], skip_special_tokens=True)

        sql_query = decoded[0]

        print(f"Instruction: {instruction}")
        print(f"Input Question: {input_question}")
        print(f"Expected Output: {expected_output}")
        print(f"Generated Output: {sql_query}")
        print("=" * 50)

        # Compute metrics
        exact_matches.append(exact_match_score(sql_query, expected_output))
        token_match_scores.append(token_match_score(sql_query, expected_output))

    avg_token_match_score = sum(token_match_scores) / len(token_match_scores)

    print(f"Average Token Match Score: {avg_token_match_score:.4f}")

evaluate_model(model, tokenizer, eval_data)

此评估方法可以让你更好地了解模型在文本到 SQL 任务上的性能。

如你所见,0.8786 的得分对于文本到 SQL 模型来说相当不错。

5、使用 LangChain 构建数据库交互 RAG

成功微调和评估模型后,我们现在可以将其集成到 LangChain 应用程序中以构建数据库交互应用程序。

首先安装 LangChain 和 llama.cpp:

set FORCE_CMAKE=1 && set CMAKE_ARGS=-DGGML_CUDA=on && pip install llama-cpp-python --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/124

将 Hugging Face 模型转换为 GGUF 格式:

python convert_hf_to_gguf.py /path/to/hf-model --outfile custom-mistral-7b.gguf --outtype f16

量化模型以减小尺寸(可选):

./llama-quantize ./custom-mistral-7b.gguf ./custom-mistral-7b-Q5_K_M.gguf Q5_K_M

我建议使用 Q5_K_M,因为它可以保留模型的大部分性能。或者,如果你想节省一些内存,您可以选择 Q4_K_M。

安装 LangChain 和所需的库。

实现 LangChain SQL 链:

import re
from langchain.sql_database import SQLDatabase
from langchain_community.llms import LlamaCpp
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts.chat import ChatPromptTemplate


# Initialize the LlamaCpp language model with specified parameters
llm = LlamaCpp(
    model_path="./custom-mistral-7b-Q5_K_M.gguf",  # Path to the model file
    max_tokens=2048,  # Maximum number of tokens in the response
    n_ctx=6144,  # Context size
    verbose=True, 
    temperature=0,  
)

# Define the dataset and SQLAlchemy connection URL
dataset = "customer"
sqlalchemy_url = (
    f"postgresql://db_user:db_pass@db_host:5432"  # Replace with actual credentials and host
)

# Initialize the SQLDatabase object with specified schema and tables
db = SQLDatabase.from_uri(
    sqlalchemy_url,
    schema=dataset,
    include_tables=['customer', 'address', 'contact']  # Tables to include in the database
)

# Create the SQL query generation chain using the language model and database
gen_query = create_sql_query_chain(llm, db)

# Convert datetime objects in strings to a specific format
def convert_dates(obj):
    response_str = re.sub(
        r'datetime\.date\((\d+),\s*(\d+),\s*(\d+)\)',
        r"'\1-\2-\3'",
        obj
    )
    response_str = re.sub(
        r'datetime\.datetime\((\d+),\s*(\d+),\s*(\d+),\s*(\d+),\s*(\d+),\s*(\d+)(?:,\s*(\d+))?\)',
        r"'\1-\2-\3 \4:\5:\6.\7'",
        response_str
    )
    return response_str

# Route the SQL query based on its content
def route(sql_query: str):
    logging.info(f"Routing query: {sql_query}")  # Log the query for debugging
    if sql_query.get("query") == "I don't know":
        logging.warning("Unknown query detected.")  # Warn if the query is unknown
        return sql_query  # Return the original query
    else:
        return db_opt_chain  # Route to the database operation chain

# Handle cases where the agent responds with "I don't know"
def handle_dont_know(result):
    if isinstance(result, dict) and result.get("query") == "I don't know":
        return "I can only provide information related to our customer data."  # Custom response
    return result  # Return the original result if not "I don't know"

# Custom function to execute the SQL query and process the result
def custom_execute_query_runnable(result: dict) -> dict:
    return {
        **result,  # Include existing result data
        'result': convert_dates(
            db.run_no_throw(command=result["query"], include_columns=True)  # Execute the query without throwing exceptions
        )
    }

# Template for generating the final natural language response
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""

# Define the prompt for generating the response
prompt_response = ChatPromptTemplate.from_template(template)

# Define the database operation chain with the custom execution function and response prompt
db_opt_chain = (
    RunnableLambda(custom_execute_query_runnable)  # Execute the query
    | answer_prompt  # Generate the answer based on the query result
    | llm  # Use the language model to format the answer
)

# Combine all components into the full execution chain
full_chain = (
    RunnablePassthrough().assign(query=gen_query)  # Pass through the query generation
    | RunnableLambda(route)  # Route the query appropriately
    | RunnableLambda(handle_dont_know)  # Handle "I don't know" responses
    | StrOutputParser()  # Parse the final output as a string
)

# Example user question to invoke the chain
user_question = 'Get the minimum and maximum age of customers'
full_chain.invoke({"question": user_question})  # Execute the chain with the user question

注意:本指南中提供的示例代码仅供参考。你需要调整代码以适合你的特定用例,例如修改数据库连接字符串、架构定义或根据你的数据和基础架构要求调整模型参数。

此实现会创建一个链,该链使用经过微调的 LLM 生成 SQL,针对数据库执行查询,然后再次使用 LLM 解释和总结结果。

6、结束语

通过使用针对特定数据库架构定制的自定义数据集对 7B LLM 进行微调,您可以大大增强其 SQL 生成功能。与 LangChain 结合使用时,即使使用较小的语言模型,您也可以构建用于数据库交互的问答应用程序。

请记住,微调模型的有效性取决于训练数据集的质量和多样性。随着时间的推移,不断完善数据集和微调过程将带来更好的结果。


原文链接:Improving Text-to-SQL with a Fine-Tuned 7B LLM for DB Interactions

汇智网翻译整理,转载请标明出处