Network Attack Detection with a Fine-Tuned LLM
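In this post, I explore a more advanced use case of LLM fine-tuning: prediction tasks in the cybersecurity domain. Specifically, I present a mechanism for real-time network attack detection using a fine-tuned LLM.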
I fine-tuned mistralai/Mistral-7B-Instruct-v0.2 on a network PCAP dataset containing both malicious and benign packets. Trained on this dataset, the model learns to predict anomalies in real-time network data such as PCAP data.
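The fine-tuning was performed on an Apple Silicon Mac with an M2 chip, using LoRA and Apple's MLX framework. After fine-tuning, the custom models are run with Ollama. All the source code related to this post is published on GitLab; please clone the repo to follow along.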
1. LLM Fine-Tuning
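Fine-tuning is the process of adjusting the parameters, or weights, of a pre-trained large language model (LLM) to specialize it for a particular task or domain. Pre-trained language models such as GPT have broad general language knowledge, but they often lack expertise in specialized domains.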
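Fine-tuning addresses this by training the model on domain-specific data, improving its accuracy and effectiveness for the target application. Exposing the model to task-specific examples lets it capture the nuances of the domain more deeply, turning a general-purpose language model into a specialized tool. However, fine-tuning an LLM requires significant compute resources, such as GPUs, to train efficiently.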
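There are various techniques and tools for fine-tuning LLMs, including Low-Rank Adaptation (LoRA), Quantized LoRA (QLoRA) and other parameter-efficient fine-tuning (PEFT) methods, as well as DeepSpeed and its ZeRO optimizer. In this post, I discuss fine-tuning an LLM with LoRA in the Apple MLX framework. LoRA, first introduced by a team of Microsoft researchers in 2021, is a parameter-efficient fine-tuning method. Unlike traditional full fine-tuning, which updates the entire base model and can be time-consuming and expensive, LoRA adds a small number of trainable parameters while keeping the original model parameters frozen.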
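LoRA works by attaching small adapters to existing layers of the model. Rather than inserting entirely new layers, it modifies the behavior of existing weight matrices: each adapted weight gets a pair of trainable low-rank matrices whose product is added to the frozen weight. Compared with full retraining, this introduces very few additional parameters, which significantly reduces compute overhead and memory usage.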
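Because the adjustments are confined to specific model components, LoRA preserves the knowledge embedded in the original weights and reduces the risk of catastrophic forgetting. This targeted adaptation keeps the model's general capabilities intact while enabling fast iteration and task-specific improvements, making LoRA a flexible and scalable approach to fine-tuning large pre-trained models.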
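To make the idea concrete, here is a minimal pure-Python sketch of a LoRA-adapted linear layer. It follows the formulation of the original LoRA paper, where a frozen weight W is augmented with a trainable product B·A of rank r, scaled by alpha/r. The class name `LoRALinear`, the toy matrix sizes and the initialization values are illustrative assumptions; this is not how MLX implements LoRA internally.

```python
import random


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


class LoRALinear:
    """Hypothetical LoRA layer computing y = W x + (alpha / r) * B (A x)."""

    def __init__(self, weight, rank=2, alpha=4.0, seed=0):
        rng = random.Random(seed)
        self.weight = weight                       # frozen base weight, d_out x d_in
        d_out, d_in = len(weight), len(weight[0])
        self.scale = alpha / rank
        # A (r x d_in) starts small and random, B (d_out x r) starts at zero,
        # so the adapter contributes nothing until B is trained.
        self.a = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(rank)]
        self.b = [[0.0] * rank for _ in range(d_out)]

    def trainable_params(self):
        """Only A and B are trained: r * (d_in + d_out) values."""
        return len(self.a) * len(self.a[0]) + len(self.b) * len(self.b[0])

    def __call__(self, x):
        base = matvec(self.weight, x)                  # frozen path
        low_rank = matvec(self.b, matvec(self.a, x))   # trainable low-rank path
        return [y + self.scale * z for y, z in zip(base, low_rank)]


W = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]
layer = LoRALinear(W, rank=2)
x = [1.0, 2.0, 3.0]
print(layer(x))                  # -> [7.0, -1.0], identical to W x because B is all zeros
print(layer.trainable_params())  # -> 10, i.e. 2 * (3 + 2)
```

The zero initialization of B is what lets training start from exactly the base model's behavior; the full d_out x d_in weight is never updated.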
2. RAG vs. LLM Fine-Tuning
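RAG (Retrieval-Augmented Generation) enhances an LLM by giving it access to a curated database, so it can dynamically retrieve relevant information when generating a response. Fine-tuning, by contrast, adjusts the model's parameters by training it on a specific labeled dataset to improve its performance on a particular task. Fine-tuning changes the model itself, whereas RAG extends the data the model can access.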
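Use RAG when you need to supplement the language model's prompt with data that was not available during its initial training, such as real-time data, user-specific data, or contextual information relevant to the prompt. RAG is well suited to ensuring the model has access to the latest and most relevant data. Fine-tuning, on the other hand, is best for training the model to understand and perform a specific task more accurately.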
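The difference can be sketched in a few lines of Python: RAG leaves the model's weights untouched and only changes what goes into the prompt. The snippet below is a deliberately simplified illustration; the in-memory `KNOWLEDGE_BASE`, the word-overlap scoring (real systems usually rank documents with vector embeddings) and the function names are assumptions made for the example, and the final call to the LLM is omitted.

```python
import re

# Hypothetical in-memory "curated database" of short documents.
KNOWLEDGE_BASE = [
    "Zeek connection state S0 means a connection attempt was seen but no reply.",
    "ICMP echo requests are commonly used by ping to test reachability.",
    "Port 443 is the default port for HTTPS traffic.",
]


def tokenize(text):
    """Lower-case word set used for a crude similarity score."""
    return set(re.findall(r"[a-z0-9]+", text.lower()))


def retrieve(query, documents, k=1):
    """Return the k documents that share the most words with the query."""
    q = tokenize(query)
    ranked = sorted(documents, key=lambda d: len(q & tokenize(d)), reverse=True)
    return ranked[:k]


def build_rag_prompt(question, documents, k=1):
    """RAG step: prepend retrieved context to the prompt; model weights stay unchanged."""
    context = "\n".join(retrieve(question, documents, k))
    return f"Context:\n{context}\n\nQuestion: {question}"


print(build_rag_prompt("What does connection state S0 mean?", KNOWLEDGE_BASE))
```

Fine-tuning would instead bake the task into the weights, so the same question could be sent without any retrieved context.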
3. LLM Fine-Tuning with Apple MLX
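For a long time, ML training and inference were widely assumed to require Nvidia GPUs. That view has changed with the release of MLX, an ML framework that supports training and inference on Apple Silicon CPUs and GPUs.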
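The MLX library, developed by Apple, is similar to TensorFlow and PyTorch and supports GPU-accelerated workloads. It makes it possible to fine-tune LLMs on the newer Apple Silicon (M-series) chips, and its companion package mlx-lm supports LLM fine-tuning with LoRA. I have successfully fine-tuned several LLMs with MLX and LoRA, including Llama-3 and Mistral.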
3.1 Use Case
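In this post, I explore a more advanced use case of LLM fine-tuning: real-time network attack detection. I fine-tune mistralai/Mistral-7B-Instruct-v0.2 on a PCAP dataset containing malicious and benign packets, so that the model learns to predict anomalies in real-time network data such as PCAP data.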
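3.2 Setting Up MLX and Other Tools
First, I need to install MLX along with a set of required tools. Below are the tools I installed and how I set up and configured the MLX environment.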
# clone the repository, named mlxa
❯❯ git clone https://gitlab.com/rahasak-labs/mlxa.git
❯❯ cd mlxa
# create and activate a virtual environment
❯❯ python -m venv .venv
❯❯ source .venv/bin/activate
# install mlx
❯❯ pip install -U mlx-lm
# install other required python packages
❯❯ pip install pandas
❯❯ pip install pyarrow
3.3 Setting Up the Hugging Face CLI
I fetch the LLM (the base model) and datasets from Hugging Face. To do that, I need to create a Hugging Face account and configure the huggingface-cli command-line tool.
# set up an account on hugging face here
https://huggingface.co/welcome
# create an access token to read/write data from hugging face through the cli
# this token is required when logging in to huggingface-cli
https://huggingface.co/settings/tokens
# set up huggingface-cli
❯❯ pip install huggingface_hub
❯❯ pip install "huggingface_hub[cli]"
# log in to hugging face through the cli
# it will ask for the access token created previously
❯❯ huggingface-cli login
A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.
Setting a new token will erase the existing one.
To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible):
Add token as git credential? (Y/n) Y
Token is valid (permission: read).
Your token has been saved in your configured git credential helpers (osxkeychain).
Your token has been saved to /Users/lambda.eranga/.cache/huggingface/token
Login successful
# once logged in, the token is saved in ~/.cache/huggingface
❯❯ ls ~/.cache/huggingface
datasets
hub
token
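3.4 Preparing the Dataset
MLX requires data in a specific format. It supports three main formats: chat, completions and text. You can read more about these data formats in the MLX documentation.
For this use case, I use the completions format, in which each record pairs a prompt with its completion: the prompt is the input to the LLM and the completion is the LLM's expected response. Dataset generation plays a crucial role in fine-tuning because it directly affects the accuracy of the fine-tuned model, and various techniques can be used to generate datasets for fine-tuning LLMs.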
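The original dataset is a .csv file. To fine-tune with Apple MLX, I converted it to the completions format: each record contains a prompt, which describes the connection's attributes as a single natural-language sentence, and a completion holding the label. Below is a sample of the converted completion-format data.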
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 34293, network protocol: tcp, duration of the connection: 2.998556, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 41106, network protocol: tcp, duration of the connection: 2.998779, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 45986, network protocol: tcp, duration of the connection: 2.998807, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.1, source port: 3, network protocol: icmp, duration of the connection: 1.999924, connection state: OTH, missed_bytes: 0, number of packets sent from source to destination: 2, number of ip bytes sent from source to destination: 136, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "normal network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 51216, network protocol: tcp, duration of the connection: 2.999077, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 51022, network protocol: tcp, duration of the connection: 2.999300, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
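Before training, it can help to check that every line really is a standalone JSON object with the two keys of the completions format. A minimal sketch of such a check follows; the function name `validate_completion_jsonl` is hypothetical, and it only checks the `prompt`/`completion` structure shown above, not any other MLX requirement.

```python
import json


def validate_completion_jsonl(lines):
    """Return a list of (line_number, problem) tuples for malformed records."""
    problems = []
    for n, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # ignore blank lines
        try:
            record = json.loads(line)
        except json.JSONDecodeError as exc:
            problems.append((n, f"invalid JSON: {exc.msg}"))
            continue
        if not isinstance(record, dict):
            problems.append((n, "record is not a JSON object"))
            continue
        for key in ("prompt", "completion"):
            value = record.get(key)
            if not isinstance(value, str) or not value:
                problems.append((n, f"missing or empty '{key}'"))
    return problems


sample = [
    '{"prompt": "... connection state: S0 ...", "completion": "anomaly network traffic"}',
    '{"prompt": "... connection state: OTH ..."}',
    'not json',
]
print(validate_completion_jsonl(sample))
```

For a real file, pass the open file object (for example `open("data/train.jsonl", encoding="utf-8")`) as `lines`; an empty list means every record has the expected shape.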
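MLX requires three datasets for fine-tuning: train, test and valid, each as a JSONL file. The script below converts the PCAP data in the CSV file to JSONL. For each CSV record, it combines the PCAP attributes into a single natural-language prompt and derives the completion from the record's label field.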
import json
import random

import pandas as pd

# Read the labeled PCAP CSV file (pipe-delimited); adjust the path to match your file
csv_file_path = "pcap-labeled.csv"
pcap_data = pd.read_csv(csv_file_path, delimiter="|")

# Prepare the JSONL file content
jsonl_data = []

# Fields to check for '-', which marks a missing value
fields_to_check = [
    'id.orig_h', 'id.orig_p',
    'proto', 'conn_state',
    'duration', 'missed_bytes',
    'orig_pkts', 'resp_pkts',
    'orig_ip_bytes', 'resp_ip_bytes'
]

for _, row in pcap_data.iterrows():
    # Skip rows with '-' in any of the specified fields
    if any(row[field] == '-' for field in fields_to_check):
        continue

    # Construct the label from the CSV label field
    label = "anomaly" if row['label'] == "Malicious" else "normal"

    # Create the JSONL record
    record = {
        "prompt": (
            f"You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. "
            f"Here are the attributes, "
            f"'source ip address: {row['id.orig_h']}, source port: {row['id.orig_p']}, "
            f"network protocol: {row['proto']}, "
            f"duration of the connection: {row['duration']}, "
            f"connection state: {row['conn_state']}, "
            f"missed_bytes: {row['missed_bytes']}, number of packets sent from source to destination: {row['orig_pkts']}, "
            f"number of ip bytes sent from source to destination: {row['orig_ip_bytes']}, number of packets sent from destination to source: {row['resp_pkts']}, "
            f"number of ip bytes sent from destination to source: {row['resp_ip_bytes']}'. "
        ),
        "completion": f"{label} network traffic"
    }
    jsonl_data.append(record)

# Shuffle the data
random.shuffle(jsonl_data)

# Split: 70% train, and the remaining 30% divided evenly between test and valid,
# so that the validation set is not a copy of the test set
train_split = len(jsonl_data) * 7 // 10
test_split = train_split + (len(jsonl_data) - train_split) // 2
train_data = jsonl_data[:train_split]
test_data = jsonl_data[train_split:test_split]
valid_data = jsonl_data[test_split:]

# Save train.jsonl
train_file_path = 'train.jsonl'
with open(train_file_path, 'w', encoding='utf-8') as train_file:
    for entry in train_data:
        train_file.write(json.dumps(entry) + '\n')

# Save test.jsonl
test_file_path = 'test.jsonl'
with open(test_file_path, 'w', encoding='utf-8') as test_file:
    for entry in test_data:
        test_file.write(json.dumps(entry) + '\n')

# Save valid.jsonl
valid_file_path = 'valid.jsonl'
with open(valid_file_path, 'w', encoding='utf-8') as valid_file:
    for entry in valid_data:
        valid_file.write(json.dumps(entry) + '\n')
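The script above shuffles all records together before splitting, so the share of normal versus anomaly traffic in each file is left to chance, which matters when one class is much rarer than the other. One alternative, sketched below as an assumption rather than what the repository does, is a stratified split that shuffles and cuts each label group separately. The function name `stratified_split`, the fixed seed and the 70/15/15 percentages are illustrative choices.

```python
import random
from collections import defaultdict


def stratified_split(records, train_pct=70, valid_pct=15, seed=42):
    """Split records into train/valid/test so each label keeps roughly the same share.

    Records are grouped by their "completion" label; each group is shuffled and cut
    by the given integer percentages, and whatever remains goes to the test split.
    """
    by_label = defaultdict(list)
    for record in records:
        by_label[record["completion"]].append(record)

    rng = random.Random(seed)  # fixed seed makes the split reproducible
    train, valid, test = [], [], []
    for group in by_label.values():
        rng.shuffle(group)
        n_train = len(group) * train_pct // 100
        n_valid = len(group) * valid_pct // 100
        train += group[:n_train]
        valid += group[n_train:n_train + n_valid]
        test += group[n_train + n_valid:]

    # Shuffle each split so the labels are interleaved rather than grouped
    for split in (train, valid, test):
        rng.shuffle(split)
    return train, valid, test


records = (
    [{"prompt": f"a{i}", "completion": "anomaly network traffic"} for i in range(80)]
    + [{"prompt": f"n{i}", "completion": "normal network traffic"} for i in range(20)]
)
train, valid, test = stratified_split(records)
print(len(train), len(valid), len(test))  # -> 70 15 15
```

With 80 anomaly and 20 normal records, every split keeps the 80/20 ratio (for example 12 anomaly and 3 normal records in test), whereas a plain shuffle only approximates it.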
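I placed the dataset in the data directory, then ran the script to generate the JSONL files for the train, test and valid datasets. The structure of the generated data files is shown below.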
# activate virtual env
❯❯ source .venv/bin/activate
# data directory
❯❯ ls -al data
prepare.py
s2d.csv
# generate jsonl files
❯❯ cd data
❯❯ python prepare.py
# generated files
❯❯ ls -ls
test.jsonl
train.jsonl
valid.jsonl
# train.jsonl
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 33037, network protocol: tcp, duration of the connection: 2.998781, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 45506, network protocol: tcp, duration of the connection: 2.998830, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 39972, network protocol: tcp, duration of the connection: 2.998819, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 40262, network protocol: tcp, duration of the connection: 2.998574, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 60633, network protocol: tcp, duration of the connection: 2.999060, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
# test.jsonl
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 55520, network protocol: tcp, duration of the connection: 2.998521, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 45461, network protocol: tcp, duration of the connection: 2.998792, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 37168, network protocol: tcp, duration of the connection: 2.998772, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "normal network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 60316, network protocol: tcp, duration of the connection: 2.998572, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 53036, network protocol: tcp, duration of the connection: 2.998555, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
# valid.jsonl
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 46557, network protocol: tcp, duration of the connection: 2.998811, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 47863, network protocol: tcp, duration of the connection: 2.998817, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 38867, network protocol: tcp, duration of the connection: 2.998996, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 54329, network protocol: tcp, duration of the connection: 2.998824, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 60526, network protocol: tcp, duration of the connection: 2.999080, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
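The samples above are dominated by `anomaly network traffic`, so before training it is worth checking the label distribution of each generated file and confirming that no prompt appears in more than one split. A small sketch of that check follows; the helper name `summarize_splits` is hypothetical.

```python
import json
from collections import Counter


def summarize_splits(splits):
    """Print label counts per split and return prompts found in more than one split.

    `splits` maps a split name to an iterable of JSONL lines (a list or an open file).
    """
    prompt_splits = Counter()
    for name, lines in splits.items():
        records = [json.loads(line) for line in lines if line.strip()]
        labels = Counter(r["completion"] for r in records)
        print(f"{name}: {len(records)} records, {dict(labels)}")
        # Count each distinct prompt once per split
        prompt_splits.update({r["prompt"] for r in records})
    return sorted(p for p, n in prompt_splits.items() if n > 1)


train_lines = [
    '{"prompt": "a", "completion": "anomaly network traffic"}',
    '{"prompt": "b", "completion": "normal network traffic"}',
]
test_lines = ['{"prompt": "b", "completion": "normal network traffic"}']
print(summarize_splits({"train": train_lines, "test": test_lines}))  # -> ['b']
```

With the generated files, the argument would be a dict mapping `train.jsonl`, `valid.jsonl` and `test.jsonl` to their opened file objects; a non-empty return value means the same connection record landed in more than one split.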
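3.6 Fine-Tuning/Training the LLM
The next step is to fine-tune the Mistral-7B LLM with MLX, using the dataset prepared earlier. First, I downloaded the mistralai/Mistral-7B-Instruct-v0.2 LLM from Hugging Face with huggingface-cli. Then I trained the LLM on the dataset with LoRA. LoRA (Low-Rank Adaptation) introduces low-rank matrices that adjust the model's behavior without extensive retraining, preserving the original model parameters while enabling efficient, targeted adaptation.
On a Mac with an M2 chip (30-core GPU) and 64 GB of RAM, training the LLM and generating the required adapters took about 35 minutes; the log below reports roughly 0.47 iterations per second over 1,000 iterations.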
# download llm
❯❯ huggingface-cli download mistralai/Mistral-7B-Instruct-v0.2
/Users/lambda.eranga/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2/snapshots/1296dc8fd9b21e6424c9c305c06db9ae60c03ace
# model is downloaded into ~/.cache/huggingface/hub/
❯❯ ls ~/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2
blobs refs snapshots
# list all downloaded models from huggingface
❯❯ huggingface-cli scan-cache
REPO ID REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH
-------------------------------------- --------- ------------ -------- ------------- ------------- ---- -------------------------------------------------------------------------------------------
BAAI/bge-reranker-base model 1.1G 6 3 months ago 4 months ago main /Users/lambda.eranga/.cache/huggingface/hub/models--BAAI--bge-reranker-base
NousResearch/Meta-Llama-3-8B model 16.1G 14 2 months ago 5 months ago main /Users/lambda.eranga/.cache/huggingface/hub/models--NousResearch--Meta-Llama-3-8B
gpt2 model 2.9M 5 8 months ago 8 months ago main /Users/lambda.eranga/.cache/huggingface/hub/models--gpt2
infgrad/stella_en_1.5B_v5 model 240.7K 6 4 months ago 4 months ago main /Users/lambda.eranga/.cache/huggingface/hub/models--infgrad--stella_en_1.5B_v5
mistralai/Mistral-7B-Instruct-v0.2 model 29.5G 21 1 day ago 1 day ago main /Users/lambda.eranga/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2
sentence-transformers/all-MiniLM-L6-v2 model 91.6M 11 8 months ago 8 months ago main /Users/lambda.eranga/.cache/huggingface/hub/models--sentence-transformers--all-MiniLM-L6-v2
Done in 0.0s. Scanned 6 repo(s) for a total of 46.8G.
Got 1 warning(s) while scanning. Use -vvv to print details
# fine-tune the llm
# --model - original model downloaded from hugging face
# --data data - data directory containing train.jsonl and valid.jsonl
# --train - run training
# --batch-size 4 - batch size
# --lora-layers 16 - number of layers to fine-tune with lora
# --iters 1000 - training iterations
❯❯ python -m mlx_lm.lora \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--data data \
--train \
--batch-size 4 \
--lora-layers 16 \
--iters 1000
# following is the training output
# when training starts, the validation loss is 3.346 (iter 1) and the first reported training loss is 2.713 (iter 10)
# when training finishes, the validation loss is 0.149 and the training loss is 0.146
Loading pretrained model
Fetching 11 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 17980.26it/s]
Loading datasets
Training
Trainable parameters: 0.024% (1.704M/7241.732M)
Starting training..., iters: 1000
Iter 1: Val loss 3.346, Val took 35.688s
Iter 10: Train loss 2.713, Learning Rate 1.000e-05, It/sec 0.473, Tokens/sec 314.028, Trained Tokens 6640, Peak mem 16.083 GB
Iter 20: Train loss 1.179, Learning Rate 1.000e-05, It/sec 0.474, Tokens/sec 314.641, Trained Tokens 13280, Peak mem 16.083 GB
Iter 30: Train loss 0.432, Learning Rate 1.000e-05, It/sec 0.474, Tokens/sec 314.606, Trained Tokens 19920, Peak mem 16.083 GB
Iter 40: Train loss 0.484, Learning Rate 1.000e-05, It/sec 0.476, Tokens/sec 313.707, Trained Tokens 26511, Peak mem 16.083 GB
Iter 50: Train loss 0.361, Learning Rate 1.000e-05, It/sec 0.471, Tokens/sec 311.593, Trained Tokens 33123, Peak mem 16.083 GB
Iter 60: Train loss 0.335, Learning Rate 1.000e-05, It/sec 0.472, Tokens/sec 312.100, Trained Tokens 39739, Peak mem 16.083 GB
Iter 70: Train loss 0.315, Learning Rate 1.000e-05, It/sec 0.473, Tokens/sec 313.810, Trained Tokens 46367, Peak mem 16.083 GB
Iter 80: Train loss 0.311, Learning Rate 1.000e-05, It/sec 0.473, Tokens/sec 313.270, Trained Tokens 52994, Peak mem 16.083 GB
---
Iter 910: Train loss 0.165, Learning Rate 1.000e-05, It/sec 0.470, Tokens/sec 311.057, Trained Tokens 603153, Peak mem 16.213 GB
Iter 920: Train loss 0.154, Learning Rate 1.000e-05, It/sec 0.469, Tokens/sec 310.728, Trained Tokens 609781, Peak mem 16.213 GB
Iter 930: Train loss 0.158, Learning Rate 1.000e-05, It/sec 0.470, Tokens/sec 311.830, Trained Tokens 616409, Peak mem 16.213 GB
Iter 940: Train loss 0.148, Learning Rate 1.000e-05, It/sec 0.465, Tokens/sec 308.425, Trained Tokens 623037, Peak mem 16.213 GB
Iter 950: Train loss 0.156, Learning Rate 1.000e-05, It/sec 0.468, Tokens/sec 310.664, Trained Tokens 629669, Peak mem 16.213 GB
Iter 960: Train loss 0.166, Learning Rate 1.000e-05, It/sec 0.466, Tokens/sec 308.836, Trained Tokens 636297, Peak mem 16.213 GB
Iter 970: Train loss 0.150, Learning Rate 1.000e-05, It/sec 0.470, Tokens/sec 311.695, Trained Tokens 642925, Peak mem 16.213 GB
Iter 980: Train loss 0.151, Learning Rate 1.000e-05, It/sec 0.474, Tokens/sec 314.600, Trained Tokens 649564, Peak mem 16.213 GB
Iter 990: Train loss 0.146, Learning Rate 1.000e-05, It/sec 0.470, Tokens/sec 312.381, Trained Tokens 656204, Peak mem 16.213 GB
Iter 1000: Val loss 0.149, Val took 35.181s
Iter 1000: Train loss 0.146, Learning Rate 1.000e-05, It/sec 4.690, Tokens/sec 3106.646, Trained Tokens 662828, Peak mem 16.213 GB
Iter 1000: Saved adapter weights to adapters/adapters.safetensors and adapters/0001000_adapters.safetensors.
Saved final adapter weights to adapters/adapters.safetensors.
# gpu usage during training
❯❯ sudo powermetrics --samplers gpu_power -i500 -n1
Password:
Machine model: Mac14,6
OS version: 23F79
Boot arguments:
Boot time: Fri Dec 6 09:26:57 2024
*** Sampled system activity (Mon Dec 9 08:43:34 2024 -0500) (503.37ms elapsed) ***
**** GPU usage ****
GPU HW active frequency: 1397 MHz
GPU HW active residency: 98.22% (444 MHz: .09% 612 MHz: 0% 808 MHz: 0% 968 MHz: 0% 1110 MHz: 0% 1236 MHz: 0% 1338 MHz: 0% 1398 MHz: 98%)
GPU SW requested state: (P1 : 0% P2 : 0% P3 : 0% P4 : 0% P5 : 0% P6 : 0% P7 : 0% P8 : 100%)
GPU SW state: (SW_P1 : 0% SW_P2 : 0% SW_P3 : 0% SW_P4 : 0% SW_P5 : 0% SW_P6 : 0% SW_P7 : 0% SW_P8 : 0%)
GPU idle residency: 1.78%
GPU Power: 42023 mW
# at the end of training, the LoRA adapters are saved to the adapters folder
❯❯ ls adapters
0000100_adapters.safetensors
0000300_adapters.safetensors
0000500_adapters.safetensors
0000700_adapters.safetensors
0000900_adapters.safetensors
adapter_config.json
0000200_adapters.safetensors
0000400_adapters.safetensors
0000600_adapters.safetensors
0000800_adapters.safetensors
0001000_adapters.safetensors
adapters.safetensors
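The `Trainable parameters: 0.024% (1.704M/7241.732M)` line in the training log can be checked by hand. A LoRA pair of rank r on a projection from d_in to d_out features adds r·(d_in + d_out) parameters. The sketch below assumes rank-8 adapters on the query and value projections of the 16 adapted layers, with Mistral-7B's hidden size of 4096 and a value projection of width 1024 (8 key/value heads of dimension 128). These are assumptions about the mlx-lm defaults used in this run, but they reproduce the logged figure exactly.

```python
def lora_params(d_in, d_out, rank):
    """Parameters added by one LoRA pair: A is rank x d_in, B is d_out x rank."""
    return rank * (d_in + d_out)


HIDDEN = 4096      # Mistral-7B hidden size
KV_DIM = 8 * 128   # 8 key/value heads x head dimension 128 (grouped-query attention)
RANK = 8           # assumed LoRA rank
LAYERS = 16        # --lora-layers 16

# Assumed adapted projections per layer: q_proj (4096 -> 4096) and v_proj (4096 -> 1024)
per_layer = lora_params(HIDDEN, HIDDEN, RANK) + lora_params(HIDDEN, KV_DIM, RANK)
total = per_layer * LAYERS
print(per_layer, total)             # -> 106496 1703936
print(f"{total / 7241.732e6:.3%}")  # -> 0.024%, relative to the 7241.732M in the log
```

The result also shows why training fits comfortably in memory: about 1.7M trained values against more than 7 billion frozen ones.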
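3.7 Evaluating the Fine-Tuned LLM
The LLM is now trained and the LoRA adapters have been created. We can use these adapters together with the original LLM to test the fine-tuned model. First, I evaluated the LLM on the test data with MLX using the --test argument. Then I asked the original LLM and the fine-tuned LLM the same question. The comparison shows how the fine-tuned LLM has been optimized, based on the provided dataset, for the network attack detection use case. The fine-tuning can be improved further by adjusting the prompts, the dataset and other parameters.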
# test the llm with the test data
# --model - original model downloaded from hugging face
# --adapter-path adapters - location of the lora adapters
# --data data - data directory path with test.jsonl
❯❯ python -m mlx_lm.lora \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--adapter-path adapters \
--data data \
--test
# output
Loading pretrained model
Fetching 11 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 54471.48it/s]
Loading datasets
Testing
Test loss 0.151, Test ppl 1.163.
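The two numbers in the test output are related: perplexity is e raised to the mean per-token cross-entropy loss. Roughly speaking, a perplexity of about 1.16 means the model is nearly certain of each target token, compared with about 28 for the untrained model's validation loss of 3.346 at the start of training. A quick check, using only values copied from the logs:

```python
import math

# Values copied from the logs above
test_loss = 0.151          # "Test loss 0.151"
initial_val_loss = 3.346   # "Iter 1: Val loss 3.346"

print(round(math.exp(test_loss), 3))         # -> 1.163, matching "Test ppl 1.163"
print(round(math.exp(initial_val_loss), 1))  # -> 28.4
```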
# first, ask the original llm the question using mlx
# --model - original model downloaded from hugging face
# --max-tokens 500 - how many tokens to generate
# --prompt - prompt to llm
❯❯ python -m mlx_lm.generate \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--max-tokens 500 \
--prompt "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 47525, network protocol: tcp, duration of the connection: 2.998777, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'."
# it gives a generic answer based on the llm's original knowledge
Prompt: <s> [INST] You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 47525, network protocol: tcp, duration of the connection: 2.998777, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. [/INST]
Based on the provided network traffic attributes, the traffic appears to be normal. Here`s why:
1. Source IP address: The IP address 192.168.100.103 falls within a private IP address range, which is commonly used for internal networks. This is not an indicator of anomalous traffic.
2. Source port: The source port number 47525 is not associated with any known malicious activity. It is a high-numbered ephemeral port, which is commonly used for outgoing connections from applications.
3. Network protocol: The use of TCP (Transmission Control Protocol) is a normal and common protocol used for reliable data transfer between applications.
4. Duration of the connection: The connection duration of 2.998777 seconds is within the normal range for a typical connection.
5. Connection state: The connection state S0 indicates that the connection is still in the initial SYN (synchronize) stage, which is normal for the beginning of a connection.
6. Missed bytes and packets: The absence of missed bytes and packets indicates that all data was received as intended.
7. Number of packets and IP bytes: The number of packets and IP bytes sent and received is within a reasonable range for a typical connection.
Therefore, based on the given network traffic attributes, the traffic appears to be normal.
==========
Prompt: 348.322 tokens-per-sec
Generation: 20.630 tokens-per-sec
# ask the fine-tuned llm the same question, using the adapters
# --model - original model downloaded from hugging face
# --max-tokens 500 - how many tokens to generate
# --adapter-path adapters - location of the lora adapters
# --prompt - prompt to llm
❯❯ python -m mlx_lm.generate \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--max-tokens 500 \
--adapter-path adapters \
--prompt "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 47525, network protocol: tcp, duration of the connection: 2.998777, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'."
# it gives a specific answer learned from the fine-tuning dataset
Prompt: <s> [INST] You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 47525, network protocol: tcp, duration of the connection: 2.998777, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. [/INST]
anomaly network traffic
==========
Prompt: 351.853 tokens-per-sec
Generation: 21.462 tokens-per-sec
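Every prompt in this article follows one fixed template filled in from a single connection record. For real-time use, that string has to be generated for each captured flow. The sketch below assumes the flow has already been parsed into a Python dict; the dict keys and the `build_prompt` helper are hypothetical (the keys are only loosely modeled on conn-log style column names), and the function simply reproduces the prompt wording used above.

```python
# Minimal sketch: render one connection record as the article's prompt string.
# The record keys and build_prompt are hypothetical; only the wording is from the article.

INSTRUCTION = (
    "You are an expert in network traffic classification. Based on the provided "
    "network traffic attributes, you must determine whether the traffic is normal "
    "or anomaly. Here are the attributes, "
)

# (record key, label used in the prompt), in the order the article's prompt uses
FIELD_LABELS = [
    ("src_ip", "source ip address"),
    ("src_port", "source port"),
    ("proto", "network protocol"),
    ("duration", "duration of the connection"),
    ("conn_state", "connection state"),
    ("missed_bytes", "missed_bytes"),
    ("orig_pkts", "number of packets sent from source to destination"),
    ("orig_ip_bytes", "number of ip bytes sent from source to destination"),
    ("resp_pkts", "number of packets sent from destination to source"),
    ("resp_ip_bytes", "number of ip bytes sent from destination to source"),
]


def build_prompt(record: dict) -> str:
    """Join 'label: value' pairs and wrap them in the fixed instruction text."""
    attrs = ", ".join(f"{label}: {record[key]}" for key, label in FIELD_LABELS)
    return f"{INSTRUCTION}'{attrs}'."


if __name__ == "__main__":
    record = {
        "src_ip": "192.168.100.103", "src_port": 47525, "proto": "tcp",
        "duration": 2.998777, "conn_state": "S0", "missed_bytes": 0,
        "orig_pkts": 3, "orig_ip_bytes": 180, "resp_pkts": 0, "resp_ip_bytes": 0,
    }
    print(build_prompt(record))
```

Keeping the template byte-for-byte identical to the one used in training matters: the fine-tuned model has only seen prompts in this exact shape.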
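3.8 Build a new model with the fused adapters
After fine-tuning, I can merge the adjustments learned by the new model into the existing model weights, a process called fusing. Technically, this updates the weights of the pretrained (base) model so that they include the improvements learned during fine-tuning: the low-rank LoRA updates are added into the corresponding base weight matrices, so the resulting model no longer needs a separate adapter at inference time. In short, the LoRA adapter files are fused back into the base model.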
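Numerically, fusing replaces each adapted weight matrix W with W + s·(B·A), where A (r × in) and B (out × r) are the low-rank LoRA factors and s is a scaling factor. The toy sketch below shows that arithmetic in plain Python on tiny matrices and checks that the fused matrix gives the same output as running the base weight and the adapter path separately. It illustrates the standard LoRA formulation only; mlx_lm's own fuse code may store the factors in a transposed layout and derive the scale differently.

```python
# Toy sketch of LoRA fusion: W_fused = W + scale * (B @ A).
# Plain-Python matrices (lists of rows) keep the example dependency-free.

def matmul(x, y):
    """Multiply an (n x k) matrix by a (k x m) matrix."""
    return [[sum(x[i][t] * y[t][j] for t in range(len(y)))
             for j in range(len(y[0]))] for i in range(len(x))]


def matvec(w, v):
    """Compute W v for a vector v given as a flat list."""
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def fuse_lora(w, a, b, scale):
    """Return W + scale * (B @ A); W is (out x in), A is (r x in), B is (out x r)."""
    delta = matmul(b, a)
    return [[w[i][j] + scale * delta[i][j] for j in range(len(w[0]))]
            for i in range(len(w))]


# a rank r = 1 adapter on a 2 x 3 weight matrix
W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 1.0]]
A = [[1.0, 2.0, 0.0]]   # r x in  (1 x 3)
B = [[0.5], [1.0]]      # out x r (2 x 1)
scale = 2.0

W_fused = fuse_lora(W, A, B, scale)
x = [1.0, 1.0, 1.0]

# unfused path: base output plus the scaled adapter output, W x + scale * B (A x)
unfused = [wx + scale * bax for wx, bax in zip(matvec(W, x), matvec(B, matvec(A, x)))]

print(W_fused)
print(matvec(W_fused, x), unfused)
```

Because the update is folded into W once, the fused model runs at the same cost as the base model, which is why the fused model can be queried without `--adapter-path`.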
# --model - the original base model downloaded from Hugging Face
# --adapter-path adapters - location of the LoRA adapters
# --save-path models/effectz-attack - path for the new fused model
# --de-quantize - de-quantize the model; use this flag if you want to convert the model to GGUF format later
❯❯ python -m mlx_lm.fuse \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--adapter-path adapters \
--save-path models/effectz-attack \
--de-quantize
# output
Loading pretrained model
Fetching 11 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 19929.74it/s]
De-quantizing model
# the new model is generated in the models directory
❯❯ tree models
models
└── effectz-attack
    ├── config.json
    ├── model-00001-of-00003.safetensors
    ├── model-00002-of-00003.safetensors
    ├── model-00003-of-00003.safetensors
    ├── model.safetensors.index.json
    ├── special_tokens_map.json
    ├── tokenizer.json
    ├── tokenizer.model
    └── tokenizer_config.json
# now I can query the new model directly, without --adapter-path
# --model models/effectz-attack - new model path
# --max-tokens 500 - how many tokens to generate
# --prompt - prompt to llm
❯❯ python -m mlx_lm.generate \
--model models/effectz-attack \
--max-tokens 500 \
--prompt "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 47525, network protocol: tcp, duration of the connection: 2.998777, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. "
# output
Prompt: <s> [INST] You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 47525, network protocol: tcp, duration of the connection: 2.998777, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. [/INST]
anomaly network traffic
==========
Prompt: 346.093 tokens-per-sec
Generation: 23.113 tokens-per-sec
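For automated detection, the free-text answer ("anomaly network traffic") has to be mapped to a label. A minimal, hypothetical helper could look like the sketch below. It assumes the model answers with short text containing either an "anomaly"-type word or "normal", as in the outputs above, and returns "unknown" for anything else instead of guessing.

```python
import re


def parse_label(answer: str) -> str:
    """Map the model's free-text answer to 'anomaly', 'normal', or 'unknown'."""
    text = answer.strip().lower()
    # check anomaly words first; \b keeps "abnormal" from matching plain "normal"
    if re.search(r"\b(anomal\w*|abnormal)\b", text):
        return "anomaly"
    if re.search(r"\bnormal\b", text):
        return "normal"
    return "unknown"


print(parse_label("anomaly network traffic"))
```

This keyword matching is only suitable for the short, label-like answers of the fine-tuned model; a verbose answer such as "not anomalous, it is normal" would be misread, so long explanations like the base model's would need stricter output formatting.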
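3.9 Create a GGUF model
I want to run this newly created model with Ollama, a lightweight and flexible framework designed for deploying LLMs locally on personal computers. To run the merged model in Ollama, I need to convert it into a GGUF (Georgi Gerganov Unified Format) file. GGUF is the standardized model file format that Ollama uses. To convert the model to GGUF, I used another tool, llama.cpp, an open-source library written in C++ that runs inference on a wide range of LLMs. Here is how to convert the model to GGUF format and build the Ollama model.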
# clone llama.cpp into the same parent directory as the mlxa repo
❯❯ git clone https://github.com/ggerganov/llama.cpp.git
# directory structure: llama.cpp and mlxa side by side
❯❯ ls
llama.cpp
mlxa
# install the Python packages llama.cpp requires, inside a virtual environment
❯❯ cd llama.cpp
❯❯ python -m venv .venv
❯❯ source .venv/bin/activate
❯❯ pip install -r requirements.txt
# llama.cpp contains a script `convert-hf-to-gguf.py` to convert a Hugging Face model to GGUF (newer llama.cpp versions name it `convert_hf_to_gguf.py`)
❯❯ ls convert-hf-to-gguf.py
convert-hf-to-gguf.py
# convert the newly generated model (in mlxa/models/effectz-attack) to GGUF
# --outfile ../mlxa/models/effectz-attack.gguf - output GGUF model file path
# --outtype q8_0 - 8-bit quantization, which shrinks the model and helps improve inference speed
# to tune model performance, try other --outtype values (e.g. f16, or omit --outtype)
❯❯ python convert-hf-to-gguf.py ../mlxa/models/effectz-attack \
--outfile ../mlxa/models/effectz-attack.gguf \
--outtype q8_0
# output
INFO:hf-to-gguf:Loading model: effectz-attack
INFO:gguf.gguf_writer:gguf: This GGUF file is for Little Endian only
INFO:hf-to-gguf:Set model parameters
INFO:hf-to-gguf:gguf: context length = 32768
INFO:hf-to-gguf:gguf: embedding length = 4096
INFO:hf-to-gguf:gguf: feed forward length = 14336
INFO:hf-to-gguf:gguf: head count = 32
INFO:hf-to-gguf:gguf: key-value head count = 8
INFO:hf-to-gguf:gguf: rope theta = 1000000.0
INFO:hf-to-gguf:gguf: rms norm epsilon = 1e-05
INFO:hf-to-gguf:gguf: file type = 7
INFO:hf-to-gguf:Set model tokenizer
INFO:gguf.vocab:Setting special token type bos to 1
INFO:gguf.vocab:Setting special token type eos to 2
INFO:gguf.vocab:Setting special token type unk to 0
INFO:gguf.vocab:Setting add_bos_token to True
INFO:gguf.vocab:Setting add_eos_token to False
---
INFO:hf-to-gguf:blk.29.attn_v.weight, torch.bfloat16 --> Q8_0, shape = {4096, 1024}
INFO:hf-to-gguf:blk.30.attn_norm.weight, torch.bfloat16 --> F32, shape = {4096}
INFO:hf-to-gguf:blk.30.ffn_down.weight, torch.bfloat16 --> Q8_0, shape = {14336, 4096}
INFO:hf-to-gguf:blk.30.ffn_gate.weight, torch.bfloat16 --> Q8_0, shape = {4096, 14336}
INFO:hf-to-gguf:blk.30.ffn_up.weight, torch.bfloat16 --> Q8_0, shape = {4096, 14336}
INFO:hf-to-gguf:blk.30.ffn_norm.weight, torch.bfloat16 --> F32, shape = {4096}
INFO:hf-to-gguf:blk.30.attn_k.weight, torch.bfloat16 --> Q8_0, shape = {4096, 1024}
INFO:hf-to-gguf:blk.30.attn_output.weight, torch.bfloat16 --> Q8_0, shape = {4096, 4096}
INFO:hf-to-gguf:blk.30.attn_q.weight, torch.bfloat16 --> Q8_0, shape = {4096, 4096}
INFO:hf-to-gguf:blk.30.attn_v.weight, torch.bfloat16 --> Q8_0, shape = {4096, 1024}
INFO:hf-to-gguf:blk.31.attn_norm.weight, torch.bfloat16 --> F32, shape = {4096}
INFO:hf-to-gguf:blk.31.ffn_down.weight, torch.bfloat16 --> Q8_0, shape = {14336, 4096}
INFO:hf-to-gguf:blk.31.ffn_gate.weight, torch.bfloat16 --> Q8_0, shape = {4096, 14336}
INFO:hf-to-gguf:blk.31.ffn_up.weight, torch.bfloat16 --> Q8_0, shape = {4096, 14336}
INFO:hf-to-gguf:blk.31.ffn_norm.weight, torch.bfloat16 --> F32, shape = {4096}
INFO:hf-to-gguf:blk.31.attn_k.weight, torch.bfloat16 --> Q8_0, shape = {4096, 1024}
INFO:hf-to-gguf:blk.31.attn_output.weight, torch.bfloat16 --> Q8_0, shape = {4096, 4096}
INFO:hf-to-gguf:blk.31.attn_q.weight, torch.bfloat16 --> Q8_0, shape = {4096, 4096}
INFO:hf-to-gguf:blk.31.attn_v.weight, torch.bfloat16 --> Q8_0, shape = {4096, 1024}
INFO:hf-to-gguf:output_norm.weight, torch.bfloat16 --> F32, shape = {4096}
Writing: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7.70G/7.70G [01:21<00:00, 94.9Mbyte/s]
INFO:hf-to-gguf:Model successfully exported to '../mlxa/models/effectz-attack.gguf'
# the new GGUF model is generated in mlxa/models
❯❯ cd ../mlxa
❯❯ ls models/effectz-attack.gguf
models/effectz-attack.gguf
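The `--outtype q8_0` option stores the weight matrices in llama.cpp's Q8_0 format; in the log above, the 2-D weights go from torch.bfloat16 to Q8_0 while the 1-D norm weights stay F32. As I understand the format, each block of 32 consecutive values is stored as one fp16 scale d = max|x| / 127 plus 32 signed 8-bit integers. The sketch below reproduces that idea in plain Python for illustration only; it is not llama.cpp's code and keeps the scale as a Python float. At 34 bytes per 32 weights (8.5 bits per weight), a model of roughly 7.24 billion parameters (approximately Mistral 7B's size) comes to about 7.69 GB, roughly in line with the 7.70G written in the log.

```python
# Illustrative sketch of Q8_0-style block quantization (not llama.cpp's code):
# each block of 32 values -> one scale d = max|x| / 127 and 32 int8 values.
BLOCK = 32


def quantize_q8_0(values):
    """Quantize a flat list of floats (length a multiple of 32) into (d, qs) blocks."""
    if len(values) % BLOCK:
        raise ValueError("length must be a multiple of 32")
    blocks = []
    for start in range(0, len(values), BLOCK):
        chunk = values[start:start + BLOCK]
        amax = max(abs(v) for v in chunk)
        d = amax / 127.0
        inv = 1.0 / d if d else 0.0  # an all-zero block quantizes to zeros
        # Python's round() sends exact halves to even; llama.cpp rounds them
        # away from zero, a negligible difference for this illustration
        qs = [int(round(v * inv)) for v in chunk]  # each in [-127, 127]
        blocks.append((d, qs))
    return blocks


def dequantize_q8_0(blocks):
    """Reconstruct approximate floats as d * q for every quant in every block."""
    return [d * q for d, qs in blocks for q in qs]


def q8_0_size_bytes(n_weights):
    """Storage for n_weights in Q8_0: a 2-byte fp16 scale + 32 one-byte quants per block."""
    return n_weights // BLOCK * (2 + BLOCK)


if __name__ == "__main__":
    vals = [i / 10 - 1.6 for i in range(32)]
    restored = dequantize_q8_0(quantize_q8_0(vals))
    print(max(abs(a - b) for a, b in zip(vals, restored)))
    print(q8_0_size_bytes(7_241_732_096) / 1e9)  # approximate Mistral 7B parameter count
```

Compared with the 2 bytes per weight of bfloat16, this roughly halves the file size, at the cost of a rounding error of at most half a quantization step per weight.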
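3.10 Build and run the Ollama model
Now I can create an Ollama Modelfile and build an Ollama model from the GGUF model file effectz-attack.gguf. An Ollama Modelfile is a configuration file that defines and manages models on the Ollama platform. Here is how to create the Modelfile and build the new Ollama model.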
# create a file named `Modelfile` in the models directory with the following content
❯❯ cat models/Modelfile
FROM ./effectz-attack.gguf
# create ollama model
❯❯ ollama create effectz-attack -f models/Modelfile
transferring model data 100%
using existing layer sha256:189106bcbbfa1942e25555311d0097b1b2604d4a44416ae90436763ea8e17886
creating new layer sha256:633247d6e759ef95ed88aba596b6882716cdf3888236ab14762070b42126f039
writing manifest
success
# list ollama models
# effectz-attack:latest is the newly created model
❯❯ ollama ls
NAME                     ID              SIZE      MODIFIED
effectz-attack:latest    7aa76a0412bc    7.7 GB    44 seconds ago
effectz-predict:latest   e87d93dff4c4    14 GB     4 weeks ago
effectz-sql:latest       736275f4faa4    7.7 GB    5 months ago
rahasak-sql:latest       e41d278330ed    7.7 GB    5 months ago
mistral:latest           2ae6f6dd7a3d    4.1 GB    5 months ago
llama3:latest            a6990ed6be41    4.7 GB    7 months ago
llama3:8b                a6990ed6be41    4.7 GB    7 months ago
llama2:latest            78e26419b446    3.8 GB    8 months ago
llama2:13b               d475bf4c50bc    7.4 GB    8 months ago
# run the model with Ollama and ask it to classify the network traffic
# it answers based on the custom dataset used to fine-tune the model
❯❯ ollama run effectz-attack
>>> You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 47525, network protocol: tcp, duration of the connection: 2.998777, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'.
anomaly network traffic
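For real-time detection, the model would be queried programmatically rather than through the interactive `ollama run` prompt. Ollama serves a local HTTP API (by default on port 11434), and its `/api/generate` endpoint accepts a JSON body with `model`, `prompt` and `stream` fields; with `stream` set to false, the reply is a single JSON object whose `response` field holds the generated text. The standard-library sketch below builds such a request for the effectz-attack model. The helper names are hypothetical, and `classify` assumes a local Ollama server is running when it is actually called.

```python
import json
import urllib.request

OLLAMA_URL = "http://localhost:11434/api/generate"  # default local Ollama endpoint


def make_request(prompt: str, model: str = "effectz-attack") -> urllib.request.Request:
    """Build a non-streaming POST request to Ollama's /api/generate endpoint."""
    body = json.dumps({"model": model, "prompt": prompt, "stream": False}).encode("utf-8")
    return urllib.request.Request(
        OLLAMA_URL,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def classify(prompt: str, timeout: float = 60.0) -> str:
    """Send the prompt to a running Ollama server and return the raw answer text."""
    with urllib.request.urlopen(make_request(prompt), timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))["response"].strip()


if __name__ == "__main__":
    req = make_request("<prompt built from one connection record>")
    print(req.get_method(), req.full_url)
    print(json.loads(req.data))
```

The `timeout` guards against a stalled server; a production detector would also need handling for connection errors and for answers that are neither label.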
Original article: Fine-Tune LLM for Real-Time Network Attach Detection with Apple MLX
Translated and edited by 汇智网; please credit the source when republishing.