用MLX微调医学诊断大模型
在我之前的文章中,讨论了使用 Apple MLX 框架微调大型语言模型 (LLM) 的基础知识及其具体用例,包括如何构建 LLM 的自定义版本。那篇文章重点介绍了使用低秩适配器 (LoRA) 微调 mistralai/Mistral-7B-Instruct-v0.2
LLM 以执行文本到 SQL 任务的具体用例,允许模型根据用户提示生成 SQL 查询。
在这篇文章中,我将探讨 LLM 微调用于医疗诊断预测的更高级用例。在这里,我使用医疗数据集微调了 mistralai/Mistral-7B-Instruct-v0.2
,其中包括疾病和症状的自然语言描述。通过使用此数据集训练 mistralai/Mistral-7B-Instruct-v0.2
模型,该模型学会根据输入症状预测潜在诊断。
微调过程是在配备 M2 芯片的 Apple Silicon Mac 上使用 LoRA 和 Apple MLX 框架进行的。微调后,这些自定义模型使用 Ollama 运行。
与本文相关的所有源代码都已在 GitLab 上发布。请克隆仓库以继续阅读本文。
1、LLM 微调
微调是调整预训练大型语言模型 (LLM) 的参数或权重以使其专门用于特定任务或领域的过程。
虽然像 GPT 这样的预训练语言模型具有广泛的通用语言知识,但它们往往缺乏专业领域的专业知识。微调通过在特定领域的数据上训练模型来克服这个问题,从而提高其针对目标应用的准确性和有效性。
此过程涉及将模型暴露于特定任务的示例,使其能够更深入地掌握领域的细微差别。这一关键步骤将通用语言模型转变为专用工具,从而释放 LLM 在特定领域或应用程序上的全部潜力。但是,微调 LLM 需要大量计算资源(例如 GPU)才能确保高效训练。
有各种可用的 LLM 微调技术,包括低秩适配器 (LoRA)、量化 LoRA (QLoRA)、参数高效微调 (PEFT)、DeepSpeed 和 ZeRO。有关 LLM 微调的更多信息,请点击此处。
在这篇文章中,我将讨论在 Apple MLX 框架中使用 LoRA 技术微调 LLM。LoRA 由微软研究人员团队于 2021 年首次推出,提供了一种参数高效的微调方法。与需要微调整个基础模型(可能耗时长且成本高昂)的传统方法不同,LoRA 添加了少量可训练参数,同时保持原始模型参数不变。
LoRA 的本质在于向模型添加适配器层,从而提高其效率和适应性。LoRA 不是合并全新的层,而是通过引入低秩矩阵来修改现有层的行为。这种方法引入的附加参数最少,因此与完全模型再训练相比,计算开销和内存使用量显著减少。
通过将调整重点放在特定的模型组件上,LoRA 保留了嵌入在原始权重中的基础知识,从而最大限度地降低了灾难性遗忘的风险。这种有针对性的调整不仅保持了模型的一般能力,而且还实现了快速迭代和特定任务的增强,使 LoRA 成为微调大型预训练模型的灵活且可扩展的解决方案。
有关 LoRA 和示例的更多信息,请在此处阅读更多信息。
2、RAG 与 LLM 微调
RAG(检索增强生成)通过提供对精选数据库的访问来增强 LLM,使其能够动态检索相关信息以生成响应。相比之下,微调涉及通过在特定的标记数据集上训练模型来调整模型的参数,以提高其在特定任务上的性能。微调会修改模型本身,而 RAG 会扩展模型可以访问的数据。
当你需要使用初始训练时不可用的数据来补充语言模型的提示时,请使用 RAG。这可以包括实时数据、用户特定数据或与提示相关的上下文信息。RAG 非常适合确保模型能够访问最新和最相关的数据。另一方面,微调最适合训练模型以更准确地理解和执行特定任务。
要了解有关用例以及如何在 RAG 和微调之间进行选择的更多信息,请阅读这篇文章。
3、使用 Apple MLX 进行 LLM 微调
长期以来,人们一直认为 ML 训练和推理只能在 Nvidia GPU 上执行。然而,随着 ML 框架 MLX 的发布,这种观点发生了变化,它支持在 Apple Silicon CPU/GPU 上进行 ML 训练和推理。由 Apple 开发的 MLX 库类似于 TensorFlow
和 PyTorch
,并支持 GPU 支持的任务。该库允许在 Apple Silicon(M 系列)芯片上对 LLM 进行微调。此外,MLX 支持使用 LoRA 方法进行 LLM 微调。我已经使用 MLX 和 LoRA 成功微调了几个 LLM,包括 Llama-3 和 Mistral。
3.1 用例
在这篇文章中,我将探讨 LLM 微调用于医学诊断预测的更高级用例。在这里,我使用医学数据集对 mistralai/Mistral-7B-Instruct-v0.2
进行微调,其中包括疾病和症状的自然语言描述。通过使用此数据集训练 mistralai/Mistral-7B-Instruct-v0.2
模型,该模型学会根据输入症状预测潜在诊断。
3.2 设置 MLX 和其他工具
首先,我需要安装 MLX 以及一组必需的工具。下面是我安装的工具列表,以及我如何设置和配置 MLX 环境。
# used repository called mlxm
❯❯ git clone https://gitlab.com/rahasak-labs/mlxm.git
❯❯ cd mlxm
# create and activate virtial enviroument
❯❯ python -m venv .venv
❯❯ source .venv/bin/activate
# install mlx
❯❯ pip install -U mlx-lm
# install other requried pythong pakcages
❯❯ pip install pandas
❯❯ pip install pyarrow
3.3 设置 Huggingface-CLI
我从 Hugging Face 获取 LLM(基础模型)和数据集。为此,我需要在 Hugging Face 上设置一个帐户并配置 huggingface-cli
命令行工具。
# setup account in hugging-face from here
https://huggingface.co/welcome
# create access token to read/write data from hugging-face through the cli
# this token required when login to huggingface cli
https://huggingface.co/settings/tokens
# setup hugginface-cli
❯❯ pip install huggingface_hub
❯❯ pip install "huggingface_hub[cli]"
# login to huggingface through cli
# it will ask the access token previously created
❯❯ huggingface-cli login
_| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
_| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
_|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
_| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
_| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|
A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.
Setting a new token will erase the existing one.
To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible):
Add token as git credential? (Y/n) Y
Token is valid (permission: read).
Your token has been saved in your configured git credential helpers (osxkeychain).
Your token has been saved to /Users/lambda.eranga/.cache/huggingface/token
Login successful
# once login the tokne will be saved in the ~/.cache/huggingface
❯❯ ls ~/.cache/huggingface
datasets
hub
token
3.4 准备数据集
MLX 要求数据采用特定格式。MLX 中讨论了三种主要格式: chat
、 completion
和 text
。你可以在此处阅读有关这些数据格式的更多信息。
对于此用例,我将使用 text
格式数据,它将 context
、 question
和 response
等信息组合在一个自然语言短语中。此格式需要生成一个数据集,其中包含一个包含所有相关信息的文本字段:上下文、问题和响应。
数据集生成在 LLM 的微调中起着至关重要的作用,因为它直接影响微调模型的准确性。可以采用各种技术来生成用于微调 LLM 的数据集。例如,这篇文章讨论了使用 LLM 和提示工程来生成数据集。
原始数据集的结构为 .csv
文件。为了使用 Apple MLX 进行微调,我将其转换为基于文本的格式。每个文本记录都将上下文、问题和响应信息组合成一个连贯的自然语言短语。
# original csv record
label,text
Psoriasis,"I have been experiencing a skin rash on my arms, legs, and torso for the past few weeks. It is red, itchy, and covered in dry, scaly patches."
# converted text type record
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing a skin rash on my arms, legs, and torso for the past few weeks. It is red, itchy, and covered in dry, scaly patches.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Psoriasis."}
MLX 需要三组数据集:train
、 test
和 valid
,以便进行微调。数据文件应为 JSONL 格式。以下脚本将 CSV 文件中的数据转换为 JSONL 格式。在转换过程中,它将提示、问题和响应数据组合成一个自然语言文本短语,在一个文本字段中无缝捕获所有元素。
import pandas as pd
import json
import random
# load csv data
file_path = './s2d.csv'
df = pd.read_csv(file_path)
# create text type data
jsonl_data = []
for _, row in df.iterrows():
diagnosis = row['label']
symptoms = row['text']
prompt = f"You are a medical diagnosis expert. You will give patient symptoms: '{symptoms}'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with {diagnosis}."
jsonl_data.append({"text": prompt})
# shuffle the data
random.shuffle(jsonl_data)
# calculate split indices
total_records = len(jsonl_data)
train_split = int(total_records * 2 / 3)
test_split = int(total_records * 1 / 6)
# split the data
train_data = jsonl_data[:train_split]
test_data = jsonl_data[train_split:train_split + test_split]
valid_data = jsonl_data[train_split + test_split:]
# write to JSONL files
with open('train.jsonl', 'w') as train_file:
for entry in train_data:
train_file.write(json.dumps(entry) + '\n')
with open('test.jsonl', 'w') as test_file:
for entry in test_data:
test_file.write(json.dumps(entry) + '\n')
with open('valid.jsonl', 'w') as valid_file:
for entry in valid_data:
valid_file.write(json.dumps(entry) + '\n')
print("data successfully saved to train.jsonl, test.jsonl, and valid.jsonl")
我已将数据集放在数据目录中。随后,我运行脚本来生成训练、测试和有效数据集的 JSONL 格式文件。下面是生成的数据文件的结构。
# activate virtual env
❯❯ source .venv/bin/activate
# data directory
❯❯ ls -al data
prepare.py
s2d.csv
# generate jsonl files
❯❯ cd data
❯❯ python prepare.py
# generated files
❯❯ ls -ls
test.jsonl
train.jsonl
valid.jsonl
# train.jsonl
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing fatigue, difficulty walking, diarrhea, night sweats, tremors.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Influenza."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing difficulty breathing, weight loss, fever.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Pneumonia."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing dizziness, dry skin, rapid heartbeat, shortness of breath, vision changes.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Mumps."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing vomiting, shortness of breath, night sweats, rapid heartbeat.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Hepatitis."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing loss of appetite, tremors, fatigue, difficulty breathing, increased urination.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Scoliosis."}
# test.jsonl
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing difficulty walking, fever, loss of appetite, fatigue, shortness of breath, sore throat.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Depression."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing fever, rapid heartbeat, hair loss.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Asthma."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing hair loss, dizziness, rapid heartbeat.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Hepatitis."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing skin rash, abdominal pain, difficulty breathing.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Anxiety."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have lots of itchy spots on my skin, and sometimes they turn red or bumpy. There are also some weird patches that are different colors than the rest of my skin, and sometimes I get these weird bumps that look like little balls.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Fungal infection."}
# valid.jsonl
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing dizziness, increased urination, muscle pain.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Bronchitis."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing sore throat, rapid heartbeat, hair loss, loss of appetite.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Bronchitis."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing skin rash, vomiting, muscle pain, joint pain.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Gout."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been enduring frequent headaches, blurred vision, excessive appetite, a sore neck, anxiety, irritability, and digestive difficulties including indigestion and acid reflux.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Migraine."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing weight loss, vision changes, vomiting.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Tonsillitis."}
3.6 微调/训练 LLM
下一步是使用我之前准备的数据集,通过 MLX 微调 Mistral-7B LLM。最初,我使用 huggingface-cli 从 Hugging Face 下载了 mistralai/Mistral-7B-Instruct-v0.2
LLM。然后,我使用提供的数据集和 LoRA 训练了 LLM。LoRA 或低秩自适应涉及引入低秩矩阵来调整模型的行为,而无需进行大量重新训练,从而保留原始模型参数,同时实现高效、有针对性的自适应。
在配备 64GB RAM 和 30 个 GPU 的 Mac M2 上,训练过程大约需要 25 分钟来训练 LLM 并生成必要的适配器。
# download llm
❯❯ huggingface-cli download mistralai/Mistral-7B-Instruct-v0.2
/Users/lambda.eranga/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2/snapshots/1296dc8fd9b21e6424c9c305c06db9ae60c03ace
# model is downloaded into ~/.cache/huggingface/hub/
❯❯ ls ~/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2
blobs refs snapshots
# list all downloaded models from huggingface
❯❯ huggingface-cli scan-cache
REPO ID REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH
-------------------------------------- --------- ------------ -------- ------------- ------------- ---- -------------------------------------------------------------------------------------------
BAAI/bge-reranker-base model 1.1G 6 3 months ago 4 months ago main /Users/lambda.eranga/.cache/huggingface/hub/models--BAAI--bge-reranker-base
NousResearch/Meta-Llama-3-8B model 16.1G 14 2 months ago 5 months ago main /Users/lambda.eranga/.cache/huggingface/hub/models--NousResearch--Meta-Llama-3-8B
gpt2 model 2.9M 5 8 months ago 8 months ago main /Users/lambda.eranga/.cache/huggingface/hub/models--gpt2
infgrad/stella_en_1.5B_v5 model 240.7K 6 4 months ago 4 months ago main /Users/lambda.eranga/.cache/huggingface/hub/models--infgrad--stella_en_1.5B_v5
mistralai/Mistral-7B-Instruct-v0.2 model 29.5G 21 1 day ago 1 day ago main /Users/lambda.eranga/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2
sentence-transformers/all-MiniLM-L6-v2 model 91.6M 11 8 months ago 8 months ago main /Users/lambda.eranga/.cache/huggingface/hub/models--sentence-transformers--all-MiniLM-L6-v2
Done in 0.0s. Scanned 6 repo(s) for a total of 46.8G.
Got 1 warning(s) while scanning. Use -vvv to print details
# fine-tune llm
# --model - original model which download from huggin face
# --data data - data directory path with train.jsonl
# --batch-size 4 - batch size
# --lora-layers 16 - number of lora layers
# --iters 1000 - tranning iterations
❯❯ python -m mlx_lm.lora \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--data data \
--train \
--batch-size 4\
--lora-layers 16\
--iters 1000
# following is the tranning output
# when tranning is started, the initial validation loss is 1.939 and tranning loss is 1.908
# once is tranning finished, validation loss is 0.548 and tranning loss is is 0.534
Loading pretrained model
Fetching 11 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 23515.47it/s]
Loading datasets
Training
Trainable parameters: 0.024% (1.704M/7241.732M)
Starting training..., iters: 1000
Iter 1: Val loss 3.846, Val took 16.852s
Iter 10: Train loss 3.264, Learning Rate 1.000e-05, It/sec 1.008, Tokens/sec 265.250, Trained Tokens 2631, Peak mem 15.193 GB
Iter 20: Train loss 1.790, Learning Rate 1.000e-05, It/sec 1.125, Tokens/sec 270.366, Trained Tokens 5034, Peak mem 15.193 GB
Iter 30: Train loss 1.199, Learning Rate 1.000e-05, It/sec 1.024, Tokens/sec 266.239, Trained Tokens 7634, Peak mem 15.193 GB
Iter 40: Train loss 0.815, Learning Rate 1.000e-05, It/sec 1.166, Tokens/sec 271.254, Trained Tokens 9960, Peak mem 15.193 GB
Iter 50: Train loss 0.985, Learning Rate 1.000e-05, It/sec 1.074, Tokens/sec 265.638, Trained Tokens 12433, Peak mem 15.193 GB
Iter 60: Train loss 0.907, Learning Rate 1.000e-05, It/sec 1.032, Tokens/sec 265.612, Trained Tokens 15007, Peak mem 15.193 GB
Iter 70: Train loss 0.912, Learning Rate 1.000e-05, It/sec 1.038, Tokens/sec 266.206, Trained Tokens 17571, Peak mem 15.193 GB
Iter 80: Train loss 0.995, Learning Rate 1.000e-05, It/sec 1.015, Tokens/sec 262.973, Trained Tokens 20162, Peak mem 15.193 GB
Iter 90: Train loss 0.669, Learning Rate 1.000e-05, It/sec 1.170, Tokens/sec 269.866, Trained Tokens 22469, Peak mem 15.193 GB
---
Iter 870: Train loss 0.614, Learning Rate 1.000e-05, It/sec 1.105, Tokens/sec 261.226, Trained Tokens 219852, Peak mem 15.319 GB
Iter 880: Train loss 0.831, Learning Rate 1.000e-05, It/sec 0.956, Tokens/sec 269.099, Trained Tokens 222667, Peak mem 15.319 GB
Iter 890: Train loss 0.734, Learning Rate 1.000e-05, It/sec 1.081, Tokens/sec 268.167, Trained Tokens 225148, Peak mem 15.319 GB
Iter 900: Train loss 0.747, Learning Rate 1.000e-05, It/sec 1.037, Tokens/sec 276.490, Trained Tokens 227815, Peak mem 15.319 GB
Iter 900: Saved adapter weights to adapters/adapters.safetensors and adapters/0000900_adapters.safetensors.
Iter 910: Train loss 0.732, Learning Rate 1.000e-05, It/sec 1.080, Tokens/sec 273.424, Trained Tokens 230346, Peak mem 15.319 GB
Iter 920: Train loss 0.790, Learning Rate 1.000e-05, It/sec 1.036, Tokens/sec 269.653, Trained Tokens 232950, Peak mem 15.319 GB
Iter 930: Train loss 0.868, Learning Rate 1.000e-05, It/sec 0.948, Tokens/sec 265.881, Trained Tokens 235754, Peak mem 15.319 GB
Iter 940: Train loss 0.631, Learning Rate 1.000e-05, It/sec 1.135, Tokens/sec 266.720, Trained Tokens 238103, Peak mem 15.319 GB
Iter 950: Train loss 0.689, Learning Rate 1.000e-05, It/sec 1.000, Tokens/sec 253.929, Trained Tokens 240643, Peak mem 15.319 GB
Iter 960: Train loss 0.834, Learning Rate 1.000e-05, It/sec 0.955, Tokens/sec 270.962, Trained Tokens 243479, Peak mem 15.319 GB
Iter 970: Train loss 0.762, Learning Rate 1.000e-05, It/sec 0.999, Tokens/sec 270.069, Trained Tokens 246182, Peak mem 15.319 GB
Iter 980: Train loss 0.605, Learning Rate 1.000e-05, It/sec 1.043, Tokens/sec 257.387, Trained Tokens 248650, Peak mem 15.319 GB
Iter 990: Train loss 0.656, Learning Rate 1.000e-05, It/sec 1.136, Tokens/sec 268.814, Trained Tokens 251016, Peak mem 15.319 GB
Iter 1000: Val loss 0.795, Val took 15.497s
Iter 1000: Train loss 0.665, Learning Rate 1.000e-05, It/sec 12.319, Tokens/sec 3157.469, Trained Tokens 253579, Peak mem 15.319 GB
Iter 1000: Saved adapter weights to adapters/adapters.safetensors and adapters/0001000_adapters.safetensors.
Saved final adapter weights to adapters/adapters.safetensors.
# gpu usage while tranning
❯❯ sudo powermetrics --samplers gpu_power -i500 -n1
Machine model: Mac14,6
OS version: 23F79
Boot arguments:
Boot time: Tue Oct 8 08:36:47 2024
*** Sampled system activity (Sun Nov 10 15:19:00 2024 -0500) (502.79ms elapsed) ***
**** GPU usage ****
GPU HW active frequency: 1397 MHz
GPU HW active residency: 98.61% (444 MHz: .05% 612 MHz: 0% 808 MHz: 0% 968 MHz: 0% 1110 MHz: 0% 1236 MHz: 0% 1338 MHz: 0% 1398 MHz: 99%)
GPU SW requested state: (P1 : 0% P2 : 0% P3 : 0% P4 : 0% P5 : 0% P6 : 0% P7 : 0% P8 : 100%)
GPU SW state: (SW_P1 : 0% SW_P2 : 0% SW_P3 : 0% SW_P4 : 0% SW_P5 : 0% SW_P6 : 0% SW_P7 : 0% SW_P8 : 0%)
GPU idle residency: 1.39%
GPU Power: 44541 mW
# end of the tranning the LoRA adapters generated into the adapters folder
❯❯ ls adapters
0000100_adapters.safetensors
0000300_adapters.safetensors
0000500_adapters.safetensors
0000700_adapters.safetensors
0000900_adapters.safetensors
adapter_config.json
0000200_adapters.safetensors
0000400_adapters.safetensors
0000600_adapters.safetensors
0000800_adapters.safetensors
0001000_adapters.safetensors
adapters.safetensors
3.7 评估经过微调的 LLM
LLM 现在已经经过训练,并且已经创建了 LoRA 适配器。我们可以将这些适配器与原始 LLM 一起使用,以测试经过微调的 LLM 的功能。
最初,我使用 --train
参数通过 MLX 测试了 LLM。随后,我向原始 LLM 和经过微调的 LLM 提出了相同的问题。通过这种比较,我们可以看到如何根据提供的数据集针对医疗诊断预测用例优化经过微调的 LLM。可以通过修改提示、数据集和其他参数等来进一步改进微调过程。
# test the llm with the test data
# --model - original model which download from huggin face
# --adapter-path adapters - location of the lora adapters
# --data data - data directory path with test.jsonl
❯❯ python -m mlx_lm.lora \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--adapter-path adapters \
--data data \
--test
# output
Loading pretrained model
Fetching 11 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 37418.77it/s]
Loading datasets
Testing
Test loss 0.754, Test ppl 2.125.
# first ask the question from original llm using mlx
# --model - original model which download from huggin face
# --max-tokens 500 - how many tokens to generate
# --prompt - prompt to llm
❯❯ python -m mlx_lm.generate \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--max-tokens 500 \
--prompt "You are a medical diagnosis expter. You will give patient symptoms: I have experince with memory loss, stiffnesse. Question: What could be the dianosis I have?"
# it provide genearic answer based on the original knowlege base of the llm
Prompt: <s> [INST] You are a medical diagnosis expter. You will give patient symptoms: I have experince with memory loss, stiffnesse. Question: What could be the dianosis I have? [/INST]
Based on the symptoms you have provided, there are several possible conditions that could be causing memory loss and stiffness. Here are some possibilities:
1. Alzheimer`s disease: Memory loss is a common symptom of Alzheimer`s disease, which is a progressive brain disorder that causes a gradual decline in cognitive abilities. Stiffness can also be a symptom, particularly in the advanced stages of the disease.
2. Parkinson`s disease: Parkinson`s disease is a neurological disorder that affects movement. One of its primary symptoms is stiffness or rigidity in the muscles. Memory loss is also a common symptom, particularly in the later stages of the disease.
3. Rheumatoid arthritis: This is a chronic inflammatory disorder that affects the joints and other body systems. Stiffness is a hallmark symptom, particularly in the morning or after a long period of inactivity. Memory loss is not a common symptom of rheumatoid arthritis, but it can occur in some cases.
4. Multiple sclerosis: Multiple sclerosis is a chronic autoimmune disorder that affects the central nervous system. Memory loss and stiffness are both possible symptoms. Memory loss can take the form of difficulty with short-term memory, learning new information, or recalling old information.
5. Depression: Depression can cause memory loss and cognitive impairment, particularly in older adults. Stiffness is not a common symptom of depression, but it can occur in some cases, particularly if the person is physically inactive due to depressive symptoms.
It`s important to note that these are just possibilities based on the symptoms you have provided. A proper diagnosis can only be made through a thorough medical evaluation by a healthcare professional. If you are concerned about your symptoms, I would recommend making an appointment with your doctor to discuss them further.
==========
Prompt: 135.174 tokens-per-sec
Generation: 22.959 tokens-per-sec
# same question asked from fine-tunneld llm with usingn adapter
# --model - original model which download from huggin face
# --max-tokens 500 - how many tokens to generate
# --adapter-path adapters - location of the lora adapters
# --prompt - prompt to llm
❯❯ python -m mlx_lm.generate \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--max-tokens 500 \
--adapter-path adapters \
--prompt "You are a medical diagnosis expter. You will give patient symptoms: I have experince with memory loss, stiffnesse. Question: What could be the dianosis I have?"
# it provide specific answer based on the dataset used to fine-tune the llm
Prompt: <s> [INST] You are a medical diagnosis expter. You will give patient symptoms: I have experince with memory loss, stiffnesse. Question: What could be the dianosis I have? [/INST]
You indicated that you have been experiencing memory loss and stiffness.dhdxmême¶ Memory loss is a common symptom of Parkinson`s Disease, which can also manifest as difficulty with thinking and problem-solving. Stiffness, also known as rigidity, can also be a symptom of Parkinson`s Disease. You should consult a healthcare professional for a proper diagnosis. Other symptoms of Parkinson`s Disease include tremors, difficulty walking, and loss of balance. If you are diagnosed with Parkinson`s Disease, there are treatments available that can help manage your symptoms. It is important to seek medical advice as soon as possible to begin treatment and slow the progression of the disease.
==========
Prompt: 126.813 tokens-per-sec
Generation: 21.583 tokens-per-sec
3.8 使用融合适配器构建新模型
完成微调后,我可以将新模型学习到的调整与现有模型权重合并,这个过程称为融合(fusing)。从技术上讲,这涉及更新预训练/基础模型的权重和参数,以纳入微调模型的改进。基本上,我可以继续将 LoRA 适配器文件融合回基础模型。
# --model - original model which download from huggin face
# --adapter-path adapters - location of the lora adapters
# --save-path models/effectz-predict - new model path
# --de-quantize - use this flag if you want convert the model GGUF format later
❯❯ python -m mlx_lm.fuse \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--adapter-path adapters \
--save-path models/effectz-predict \
--de-quantize
# output
Loading pretrained model
Fetching 11 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 19136.19it/s]
De-quantizing model
# new model generatd in the models directory
❯❯ tree models
models
└── effectz-predict
├── config.json
├── model-00001-of-00003.safetensors
├── model-00002-of-00003.safetensors
├── model-00003-of-00003.safetensors
├── model.safetensors.index.json
├── special_tokens_map.json
├── tokenizer.json
├── tokenizer.model
└── tokenizer_config.json
# now i can directly ask question from the new model
# --model models/effectz-predict - new model path
# --max-tokens 500 - how many tokens to generate
# --prompt - prompt to llm
❯❯ python -m mlx_lm.generate \
--model models/effectz-predict \
--max-tokens 500 \
--prompt "You are a medical diagnosis expter. You will give patient symptoms: I have experince with memory loss, stiffnesse. Question: What could be the dianosis I have?"
# output
Prompt: <s> [INST] You are a medical diagnosis expter. You will give patient symptoms: I have experince with memory loss, stiffnesse. Question: What could be the dianosis I have? [/INST]
You may be diagnosed with Parkinson`s disease.medical diagnosis expert.]
Symptoms like memory loss, difficulty moving, and stiffness are common in Parkinson`s disease. Other symptoms may include tremors, loss of balance, and difficulty walking. It`s important to note that not everyone with these symptoms has Parkinson`s disease. You should talk to your healthcare provider about your symptoms. They can diagnose you based on your medical history, physical exam, and other diagnostic tests. Other conditions, like depression, thyroid disease, or medication side effects, can also cause these symptoms. Your healthcare provider can help you determine what is causing your symptoms and provide you with appropriate treatment.
==========
Prompt: 131.317 tokens-per-sec
Generation: 23.123 tokens-per-sec
3.9 创建 GGUF 模型
我想使用 Ollama 运行这个新创建的模型,Ollama 是一个轻量级且灵活的框架,专为在个人计算机上本地部署 LLM 而设计。要在 Ollama 中运行这个合并模型,我需要将其转换为 GGUF(Georgi Gerganov 统一格式)文件,GGUF 是 Ollama 使用的标准化存储格式。
为了将模型转换为 GGUF,我使用了另一个名为 llama.cpp 的工具,这是一个用 C++ 编写的开源软件库,可对各种 LLM 执行推理。以下是将模型转换为 GGUF 格式并构建 Ollama 模型的方法。
# clone llama.cpp into same location where mlxm repo exists
❯❯ git clone https://github.com/ggerganov/llama.cpp.git
# directory stcture where llama.cpp and mlxm exists
❯❯ ls
llama.cpp
mlxm
# configure required packages in llama.cpp with setting virtual enviroument
❯❯ cd llama.cpp
❯❯ python -m venv .venv
❯❯ source .venv/bin/activate
❯❯ pip install -r requirements.txt
# llama.cpp contains a script `convert-hf-to-gguf.py` to convert hugging face model gguf
❯❯ ls convert-hf-to-gguf.py
convert-hf-to-gguf.py
# convert newly generated model(in mlxm/models/effectz-predict) to gguf
# --outfile ../mlxm/models/effectz-predict.gguf - output gguf model file path
# --outtype q8_0 - 8 bit quantize which helps improve inference speed
# to optimize the model performance try different outtype parameters(e.g without outtype etc)
❯❯ python convert-hf-to-gguf.py ../mlxm/models/effectz-predict \
--outfile ../mlxm/models/effectz-predict.gguf \
--outtype q8_0
INFO:hf-to-gguf:Loading model: effectz-predict
INFO:gguf.gguf_writer:gguf: This GGUF file is for Little Endian only
INFO:hf-to-gguf:Set model parameters
INFO:hf-to-gguf:gguf: context length = 32768
INFO:hf-to-gguf:gguf: embedding length = 4096
INFO:hf-to-gguf:gguf: feed forward length = 14336
INFO:hf-to-gguf:gguf: head count = 32
INFO:hf-to-gguf:gguf: key-value head count = 8
INFO:hf-to-gguf:gguf: rope theta = 1000000.0
INFO:hf-to-gguf:gguf: rms norm epsilon = 1e-05
INFO:hf-to-gguf:gguf: file type = 7
INFO:hf-to-gguf:Set model tokenizer
INFO:gguf.vocab:Setting special token type bos to 1
INFO:gguf.vocab:Setting special token type eos to 2
INFO:gguf.vocab:Setting special token type unk to 0
INFO:gguf.vocab:Setting add_bos_token to True
INFO:gguf.vocab:Setting add_eos_token to False
INFO:hf-to-gguf:Exporting model to '../mlxm/models/effectz-predict.gguf'
INFO:hf-to-gguf:gguf: loading model weight map from 'model.safetensors.index.json'
INFO:hf-to-gguf:gguf: loading model part 'model-00001-of-00003.safetensors'
INFO:hf-to-gguf:token_embd.weight, torch.bfloat16 --> Q8_0, shape = {4096, 32000}
INFO:hf-to-gguf:blk.0.attn_norm.weight, torch.bfloat16 --> F32, shape = {4096}
INFO:hf-to-gguf:blk.0.ffn_down.weight, torch.bfloat16 --> Q8_0, shape = {14336, 4096}
INFO:hf-to-gguf:blk.0.ffn_gate.weight, torch.bfloat16 --> Q8_0, shape = {4096, 14336}
INFO:hf-to-gguf:blk.0.ffn_up.weight, torch.bfloat16 --> Q8_0, shape = {4096, 14336}
---
INFO:hf-to-gguf:blk.30.ffn_up.weight, torch.bfloat16 --> Q8_0, shape = {4096, 14336}
INFO:hf-to-gguf:blk.30.ffn_norm.weight, torch.bfloat16 --> F32, shape = {4096}
INFO:hf-to-gguf:blk.30.attn_k.weight, torch.bfloat16 --> Q8_0, shape = {4096, 1024}
INFO:hf-to-gguf:blk.30.attn_output.weight, torch.bfloat16 --> Q8_0, shape = {4096, 4096}
INFO:hf-to-gguf:blk.30.attn_q.weight, torch.bfloat16 --> Q8_0, shape = {4096, 4096}
INFO:hf-to-gguf:blk.30.attn_v.weight, torch.bfloat16 --> Q8_0, shape = {4096, 1024}
INFO:hf-to-gguf:blk.31.attn_norm.weight, torch.bfloat16 --> F32, shape = {4096}
INFO:hf-to-gguf:blk.31.ffn_down.weight, torch.bfloat16 --> Q8_0, shape = {14336, 4096}
INFO:hf-to-gguf:blk.31.ffn_gate.weight, torch.bfloat16 --> Q8_0, shape = {4096, 14336}
INFO:hf-to-gguf:blk.31.ffn_up.weight, torch.bfloat16 --> Q8_0, shape = {4096, 14336}
INFO:hf-to-gguf:blk.31.ffn_norm.weight, torch.bfloat16 --> F32, shape = {4096}
INFO:hf-to-gguf:blk.31.attn_k.weight, torch.bfloat16 --> Q8_0, shape = {4096, 1024}
INFO:hf-to-gguf:blk.31.attn_output.weight, torch.bfloat16 --> Q8_0, shape = {4096, 4096}
INFO:hf-to-gguf:blk.31.attn_q.weight, torch.bfloat16 --> Q8_0, shape = {4096, 4096}
INFO:hf-to-gguf:blk.31.attn_v.weight, torch.bfloat16 --> Q8_0, shape = {4096, 1024}
INFO:hf-to-gguf:output_norm.weight, torch.bfloat16 --> F32, shape = {4096}
Writing: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7.70G/7.70G [01:22<00:00, 93.5Mbyte/s]
INFO:hf-to-gguf:Model successfully exported to '../mlxm/models/effectz-predict.gguf'
# new gguf model generated in the mlxm/models
❯❯ cd ../mlxm
❯❯ ls models/effectz-predict.gguf
models/effectz-predict.gguf
3.10 构建并运行 Ollama 模型
现在我可以创建一个 Ollama Modelfile,并使用名为 effectz-predict.gguf
的 GGUF 模型文件构建一个 Ollama 模型。Ollama Modelfile是一个配置文件,用于定义和管理 Ollama 平台上的模型。以下是创建Modelfile和生成新 Ollama 模型的方法。
# create file named `Modelfile` in models directory with following content
❯❯ cat models/Modelfile
FROM ./effectz-predict.gguf
# create ollama model
❯❯ ollama create effectz-predict -f models/Modelfile
transferring model data 100%
using existing layer sha256:48e762333346ccdccb24cd0b5ae9b9532e41b0b4d507759a26fed63091b9c68c
creating new layer sha256:be81b1f72b5a476719add650f36664dcf315f32d870a4ea43d4d4dc0c082dd5a
writing manifest
success
# list ollama models
# effectz-predict:latest is the newly created model
❯❯ ollama ls
NAME ID SIZE MODIFIED
effectz-predict:latest dd19f9f8b63b 7.7 GB About a minute ago
effectz-sql:latest 736275f4faa4 7.7 GB 4 months ago
rahasak-sql:latest e41d278330ed 7.7 GB 4 months ago
mistral:latest 2ae6f6dd7a3d 4.1 GB 4 months ago
llama3:latest a6990ed6be41 4.7 GB 6 months ago
llama3:8b a6990ed6be41 4.7 GB 6 months ago
llama2:latest 78e26419b446 3.8 GB 7 months ago
llama2:13b d475bf4c50bc 7.4 GB 7 months ago
# run model with ollama and ask question about the diagnosis
# it will give the answer based on custom knowledge based which use to fine-tune the model
❯❯ ollama run effectz-predict
>>> You are a medical diagnosis expter. You will give patient symptoms: I have experince with memory loss, stiffnesse. Question: What could be the dianosis I have?
You may be diagnosed with Parkinson`s disease.medical diagnosis expert.
Symptoms like memory loss, difficulty moving, and stiffness are common in Parkinson`s disease. Other symptoms may include tremors, loss of balance, and difficulty walking. It`s important to note that not everyone with these symptoms has Parkinson`s disease. You should talk to your healthcare provider about your symptoms. They can diagnose you based on your medical history, physical exam, and other diagnostic tests. Other conditions, like depression, thyroid disease, or medication side effects, can also cause these symptoms. Your healthcare provider can help you determine what is causing your symptoms and provide you with appropriate treatment.
原文链接:Fine-Tune LLM for Medical Diagnosis Prediction with Apple MLX
汇智网翻译整理,转载请标明出处