用小模型合成表格数据
合成数据生成解决了多个基本挑战:数据集中的类别不平衡、数据隐私要求、数据获取成本优化和实验周期加速。传统方法(如 SMOTE [1])通过在现有数据点之间进行插值来为少数类生成合成样本。之前的博客文章 [2] 对表格合成(数值)数据生成的生成方法进行了全面评估,包括生成对抗网络 (GAN)、变分自动编码器 (VAE)、高斯 Copula、贝叶斯网络和条件表格 GAN (CTGAN)。
这篇文章研究了利用小语言模型 (SLM) 生成合成表格数值数据的新方法。与之前的研究保持连续性,我们专注于单一表格数据,特别是分析来自 NASA 艾姆斯预测卓越中心的涡扇发动机退化模拟数据集 [3][4]。有关数据集特征和研究动机,读者可以参考之前的出版物。
该研究考察了四种关键方法:
- 具有领域特定约束的 SLM 微调
- 使用数值标记器和自定义损失函数进行高级微调
- Transformer GAN 和条件 Transformer GAN 架构
- 语言模型 GAN (LM-GAN) 实现
将语言模型归类为“小型”在自然语言处理领域表现出时间变化。一个值得注意的例子是 GPT-2,它在 2019 年发布时具有 15 亿个参数,被归类为大型模型,但现在按照当代标准被认为是小型的。当前分类 (2024) 将 SLM 定义为包含 3-100 亿个参数的模型,而大型语言模型 (LLM) 通常包含数千亿个参数。SLM 针对资源效率和边缘部署场景进行了优化,代表性模型包括 Phi 3 [8]、Galactica 和 Gemma。
SLM 的架构多样性与其较大的同类产品相似,包含各种注意力机制:
- 多头注意力 (MHA)
- 多查询注意力 (MQA)
- 组查询注意力 (GQA)
- 多头潜在注意力 (MLA)
这些模型在其结构组件中表现出显著的变化 [26],包括:
- 前馈神经网络实现(标准 FFN、门控 FFN)
- 激活函数选择(ReLU、GELU、GELUtanh、SiLU)
- 词汇量范围(<50K 到 250K+)
- 训练数据量(从数百万到 6T 令牌不等)
本研究重点是利用这些 SLM 架构进行合成数据生成应用。
本研究评估了使用小型语言模型 (SLM) 进行合成表格数据生成的四种高级方法,并简要概述了具有结构化约束的快速工程。针对特定制造目标对 SLM 进行微调、结合自定义损失函数的高级微调、将 Transformers 与 GAN 相结合的混合架构以及新颖的语言模型 GAN (LM-GAN) 框架。该实现利用 Microsoft 的 Phi-3.5 mini [13] 进行微调实验,使用 DistilGPT-2 [23] 进行自定义损失函数集成,并为 Transformer GAN 和 LM-GAN 模型开发专用架构。每种方法都逐步建立在 SLM 的功能之上,最终形成 LM-GAN 架构,该架构展示了合成制造数据中统计属性的卓越保存。
1、提示工程
提示工程代表了一种基本而有效的合成数据生成方法。该方法需要为语言模型构建结构化输入,包括字段描述和规范、领域特定约束和具有少量学习的样本数据样本。该方法允许指定生成参数,包括输出量和条件约束。研究表明,有两种主要的提示范式:自然语言描述 [10] 和结构化 CSV 格式 [9],每种范式都优化了 token 利用效率的不同方面。高级提示技术结合了随机值替换以增强数据多样性和分层分组机制,以实现条件生成 [11][12]。虽然提示工程为合成数据生成提供了巨大的潜力,但本文重点介绍更先进的方法框架,将提示技术的全面探索留给现有文献和读者进一步探索提示工程
2、SLM 微调
我们探索的第一种技术是微调语言模型。本文为此考虑了 Microsoft Phi 3.5 mini instruct 模型 [13]。它是一个仅解码器的 transformer 架构模型,具有 3.82B 参数,在 3.4 万亿个 token 上进行训练。它的上下文长度为 128k 到kens,并针对基本推理任务、代码生成和数学问题解决进行了专门优化。
为了进一步说明,让我们使用表格数值数据生成来正式构建问题。
假设D为具有 n 个样本的训练数据集 D = {X1、X2、X3、…、Xn}
。每个样本都有一组 m 个键值对,以字段名称作为键。
这些分布在 n 行中,其中对由一个特殊标记分隔。对于此练习,选择 :::
作为标记。连接运算符 C 执行此操作并生成组合字符串。
其中 δ 表示特殊标记。生成的文本被进一步标记。标记化函数 T 将连接的字符串映射到一系列标记:
最后,这些标记用于微调语言模型,该模型根据种子标记预测下一个标记。
对于合成数据生成,给定种子标记 s,模型从学习的分布 X^ 中采样。
上面的图 1 解释了这个过程。标有 F{i}
的步骤用于微调过程,标有 G{i}
的步骤用于生成。
种子标记的选择取决于模型的微调方式。如果字段在连接运算符之前被随机化,则种子标记可以是任何字段名称。或者,它可以用一组确定其他字段分布的固定字段名称进行调节。你可以在文献中找到这两种技术都经过了评估,结果因数据集而异[10]。虽然这篇文章利用了通用字段名称,但有研究表明,使用描述性字段名称可以进一步提高合成样本的保真度[14]。
为了微调模型,公开可用的(通过 Huggingface)Phi 3.5 mini模型的权重[15]被用作预训练权重。使用 AdamW 优化器[16]对模型进行微调,该优化器具有恒定学习率调度程序和 2e − 4 的学习率。与标准 Adam 相比,AdamW 优化器提供了更好的泛化性能,并且一直在微调任务[17][18]。这是因为它将权重衰减与跟踪一阶和二阶矩及其各自的权重衰减分离。对于微调,Huggingface transformers[19] 库的训练器类与低秩自适应[20] 和 BitsAndBytes[21] 量化一起使用。
以下配置提供了微调设置的快速快照
model_id: &model_id "microsoft/Phi-3.5-mini-instruct"
tokenizer_config:
max_length: 350
truncation: True
padding: "max_length"
training_env:
model_dir: "opt/ml/phi-model" #directory for fine tuned model weights
cache_dir: &cache_dir "/tmp/.cache" #directory for storing pretrained model weights(downloaded from huggingface)
merge_dir: "/tmp/phi-model" #directory for storing merged model weights
model_config:
trust_remote_code: True
cache_dir: *cache_dir
device_map: "auto"
torch_dtype: "float16"
attn_implementation: "flash_attention_2"
bnb_config:
load_in_4bit: True
bnb_4bit_use_double_quant: True
bnb_4bit_quant_type: "nf4"
bnb_4bit_compute_dtype: "bfloat16"
lora_config:
r: 8
lora_alpha: 16
lora_dropout: 0.1
bias: "none"
task_type: "CAUSAL_LM"
training_config:
per_device_train_batch_size: 4
per_device_eval_batch_size: 1
gradient_accumulation_steps: 2
gradient_checkpointing: True
learning_rate: 0.0002
lr_scheduler_type: "constant"
num_train_epochs: 1
logging_strategy: "steps"
logging_steps: 10
log_on_each_node: False
bf16: True
ddp_find_unused_parameters: False
fsdp: "" #fsdp turned off
fsdp_config: null
save_strategy: "no"
output_dir: "outputs"
report_to: none
optim: adamw_torch
save_strategy: epoch
max_grad_norm: 0.3
warmup_ratio: 0.03
下图 2 提供了原始(测试)分布和合成分布之间的 Kullback-Leibler (KL) 散度。KL 散度值跨度约为两个数量级,范围从 ~0.5 到 ~40,表明合成数据质量在不同设置下存在显著差异。对于其他字段,观察到的 KL 散度值较低,表示保真度较高。
3、高级微调
接下来,这篇文章研究了在训练过程中考虑 KL 散度的影响,并将其视为损失函数的一部分。在本节中,我们将研究使用 DistilGPT-2 语言模型生成合成数据的高级微调策略的实现。它是 GPT2[22] 的压缩版本,通过知识蒸馏开发而成。它有 8200 万个参数,是使用知识蒸馏开发的,在 1.24 亿个参数版本的 GPT-2 的监督下进行了预训练,大约是父模型的一半大小。其架构包含 6 个转换器块(GPT2:12),同时保留每个块的 12 个注意力头和 768 的嵌入维度。该模型保留了 1024 个标记的上下文窗口和 50,257 个标记的词汇量[23]。
本文探讨了使用两种策略对 GPT2 进行微调
3.1 数值分词器
数值分词器的动机有两个方面。首先,它为数值中的所有数字提供相同的权重,而默认标记器可以在不同的数字处拆分以生成标记。第二个动机是为列值、分隔符标记( :::
)和字段分隔符( ,
)创建专用的单数标记。它可以正式定义为基本 GPT-2 标记器 Tb 的扩展,如下所示:
Tn 是拆分数值的数值标记函数。
标记器的词汇表定义为:
其中 V 是结果词汇表,Vb 是基本 GPT-2 词汇表,F 是字段名称集,D 是数字和小数点集。
可以在下面找到此代码:
from transformers import AutoTokenizer
import re
from typing import List, Union, Dict
class NumTokenizer:
def __init__(self,model_id,reqd_cols,sep_token):
self.base_tokenizer = AutoTokenizer.from_pretrained(model_id)
self.base_tokenizer.pad_token = self.base_tokenizer.eos_token # Set Tokenizer pad Token
self.base_tokenizer.special_tokens = [sep_token]
self.base_tokenizer.special_tokens.extend(reqd_cols)
self.base_tokenizer.special_tokens.extend(',')
self.num_pattern = re.compile(rf'{sep_token.strip()}\d+(?:\.\d+)?')
def tokenize_num(self,num_text):
return [t for t in num_text]
def __call__(self,text: Union[str,List[str]],padding:bool=True,truncation:bool=True,max_length:int=None,return_tensors:str=None, **kwargs) -> Dict:
if isinstance(text,str):
tl = [text]
else:
tl = text
encoded_inputs = []
for t in tl:
encoded = self.encode(t)
encoded_inputs.append(encoded)
if max_length is None:
max_length = max(len(e) for e in encoded_inputs)
if padding:
encoded_inputs = [enc + [self.base_tokenizer.pad_token_id] * (max_length - len(enc)) for enc in encoded_inputs]
if truncation:
encoded_inputs = [enc[:max_length] for enc in encoded_inputs]
if isinstance(text,str):
output = {"input_ids": encoded_inputs[0], "attention_mask" : [1] * len(encoded_inputs[0])}
else:
output = {"input_ids": encoded_inputs, "attention_mask" : [[1] * len(enc) for enc in encoded_inputs]}
if return_tensors:
output = {k: torch.tensor(v) for k,v in output.items()}
return output
def tokenize(self,text):
col_names = self.num_pattern.split(text)
col_values = [n.replace(sep_token.strip(),'').strip() for n in self.num_pattern.findall(text)]
tokens = []
for col_name, col_value in zip (col_names,col_values + ['']):
tokens.extend(self.base_tokenizer.tokenize(col_name))
tokens.extend(self.tokenize_num(col_value))
return tokens
def encode(self,text, **kwargs):
tokens = self.tokenize(text)
return self.base_tokenizer.convert_tokens_to_ids(tokens)
def decode(self,token_ids, **kwargs):
return self.base_tokenizer.decode(token_ids)
def __getattr__(self,name):
return getattr(self.base_tokenizer,name)
3.2 带有 KL 散度损失的自定义训练器
第二个动机引入了一种混合损失函数,旨在平衡维护语言结构(字段名称)和数值结构(字段值)。损失函数将文本连贯性的交叉熵损失与数值准确性的 KL 散度相结合,由参数 α 和 β 加权,值分别设置为 0.6 和 0.4。这种公式使模型能够学习表格数据表示的结构模式和数值的底层分布,这种方法的目标是解决合成数据生成中的基本挑战:保留字段之间的语义关系和数值的统计属性。
文本损失 Ltext 是用交叉熵函数计算的,该函数根据每个元素的预测概率分布计算真实类标签的负对数似然,然后计算所有元素中这些损失的平均值。
对于数值,预测分布和真实分布之间的 KL 散度 Lnum:
实现如下:
import torch.nn.functional as F
def kl_div(p, q):
return (p * torch.log(p / q)).sum(-1)
class CustomTrainer(transformers.Trainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def compute_loss(self,model,inputs,num_items_in_batch,return_outputs=False):
labels = inputs.pop("labels")
labels = torch.where(labels == -100, tokenizer.pad_token_id,labels)
outputs = model(**inputs)
logits = outputs.logits
logits_m = torch.argmax(logits[:, :-1, :], dim=-1)
labels_m = labels[:, 1:].contiguous()
pred_texts = tokenizer.batch_decode(logits_m,skip_special_tokens=False)
label_texts = tokenizer.batch_decode(labels_m,skip_special_tokens=False)
pred_list = text2dict(pred_texts,df_meta_data,sep_token)
label_list = text2dict(label_texts,df_meta_data,sep_token)
shift_logits = logits[:, :-1, :]
shift_labels = labels[:, 1:]
#column names losses
text_loss = F.cross_entropy(
shift_logits.reshape(-1, shift_logits.size(-1)),
shift_labels.reshape(-1),
ignore_index=tokenizer.pad_token_id,)
#column val loss with KL diveregence
num_losses = []
for pred_dict, label_dict in zip(pred_list, label_list):
for key in pred_dict.keys():
if key != '' and key in label_dict:
pred_value = pred_dict[key]
true_value = label_dict[key]
if pred_value is not None and true_value is not None:
num_losses.append(kl_div(torch.tensor([pred_value], device=logits.device), torch.tensor([true_value], device=logits.device)))
num_loss = torch.mean(torch.stack(num_losses)) if num_losses else 0
alpha = 0.6 # Weightage for text loss
beta = 0.4 # Weightage for numerical loss
combined_loss = alpha * text_loss + beta * num_loss
if return_outputs:
return combined_loss, outputs
else:
return combined_loss
不同列的 KL 散度测量结果揭示了模型捕获底层分布的能力存在显著差异。总体而言,与原始微调相比,这些分布的质量有所提高,但字段 s16 和 s18 的异常值表明这些特征的原始分布和合成分布之间存在相当大的差异。需要注意的是,该模型没有使用不同的超参数进行评估,而是在 3 个时期内进行了微调。
4、Transformer GAN
已经有研究利用 Transformer 模型生成合成表格数据[30][31]。这篇文章介绍了将 Transformer 架构与生成对抗网络 (GAN) 相结合以生成合成表格数据的方法。这个提出的 Transformer GAN 框架利用 Transformer 作为生成器的强大自注意力机制,并与自定义鉴别器网络配对。
我们探索了两种不同的变体:原始 TransformerGAN 和条件 Transformer GAN,每种变体都提供了独特的合成数据生成功能。原始架构采用基于 Transformer 的生成器,通过多个自注意层处理输入噪声向量,而条件变体则结合了额外的上下文信息来指导生成过程。该架构还包括位置编码的集成,以帮助模型更好地理解表格数据中的顺序模式。
下图概述了 Transformer GAN 及其实现的组件:
4.1 原始 Transformer GAN
为了正式解释这一点,训练涉及两个神经网络,一个生成器和一个鉴别器,以对抗的方式进行。Transformer 充当生成器网络,表示为 G,以潜在向量 z 作为输入并生成合成数据样本 x^ = G(z):
这里 Femb 表示输入嵌入层,PE 表示位置编码,Fout 表示输出线性层。
为了生成表格数据等结构化输出,我们的模型需要使用正确的字段顺序及其序列进行训练[28]。这是通过提供相对和绝对位置信息的位置编码完成的。从上面可以看出,位置编码具有与输入嵌入层相同的维度,并且可以添加两个嵌入。对于给定的模型维度 dmodel,位置 pos 和维度 i 的位置编码[25]定义为:
在此实现中,i 的范围从 0 到 dmodel/2。这采用正弦位置编码,其中每个位置和维度对的编码都是使用交替的正弦和余弦函数计算的。对于偶数维度 (2i),编码使用正弦函数,而奇数维度 (2i+1) 使用余弦函数。对每个位置重复此操作。这种方法在我们的表格数据环境中特别有价值,因为它使模型能够捕获局部和全局位置关系。
class PositionalEncoding(nn.Module):
def __init__(self,d_model,max_positions=1024,n=10000):
super(PositionalEncoding,self).__init__()
pe = torch.zeros(max_positions*d_model).reshape(max_positions, d_model)
k = torch.arange(0,max_positions).unsqueeze(1)
i = torch.arange(d_model//2)
div_term = (n ** ((2*i)/d_model))
theta = 1/div_term
pe[:, 0::2] = torch.sin(k * theta)
pe[:, 1::2] = torch.cos(k * theta)
self.pe = pe.to(device)
def forward(self,x):
x = x + self.pe[:x.size()[0],:]
return x
鉴别器网络(表示为 D)将数据样本 x 作为输入,并预测其为实数或从集合 (0,1) 中生成的概率 D(x)。本质上,鉴别器可以定义为一系列线性变换:
其中 σ 表示 LeakyReLU 激活,但最后一层除外,该层使用 sigmoid。
class TransformerGenerator(nn.Module):
def __init__(self, input_dim, model_dim,num_heads,num_layers,feedforward_dim):
super(TransformerGenerator, self).__init__()
self.embedding = nn.Linear(input_dim,model_dim)
self.pos_encoding = PositionalEncoding(d_model=model_dim)
encoder_layer = nn.TransformerEncoderLayer(model_dim,num_heads,feedforward_dim,dropout=0.2)
self.transformer_encoder = nn.TransformerEncoder(encoder_layer,num_layers)
self.fc_out = nn.Linear(model_dim,input_dim)
def forward(self, x):
emb = self.embedding(x)
pe = self.pos_encoding(emb)
x = emb + pe
x = self.transformer_encoder(x)
return self.fc_out(x)
class Discriminator(nn.Module):
def __init__(self, input_dim):
super(Discriminator,self).__init__()
self.model = nn.Sequential(
nn.Linear(input_dim, 1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 128),
nn.LeakyReLU(0.2),
nn.Linear(128, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.model(x)
以下超参数用于训练网络。
# Set hyperparams
batch_size = 32
num_epochs = 40
lr = 0.0001
condition_dim = 3
input_dim = len(df.columns) - condition_dim
model_dim = 512
num_heads = 16
num_layers = 18
feedforward_dim = 512
训练循环将传统的 GAN 训练与专为表格数据合成而设计的额外优化技术相结合。初始训练后实施的额外策略导致模式崩溃[29]。这导致生成器损失变得非常大且为负,而鉴别器损失变小并收敛到零。为了解决这个问题,对鉴别器损失应用了标签平滑(1 - 平滑)以防止过度自信。鉴别器的损失(d_loss)计算为真实和假样本的二元交叉熵损失之和,优化了鉴别器区分真实和合成制造数据的能力。
生成器训练阶段引入了几种复杂的技术来提高合成数据的质量。输入噪声向量使用 0 到 2 之间的随机因子动态缩放,从而在生成的样本中引入了可变性。生成器的损失函数结合了两个部分:二元交叉熵损失 (g_loss_bce),鼓励生成器生成可以欺骗鉴别器的样本;Kullback-Leibler 散度项 (g_loss_kl),用于测量生成的数据分布与实际数据分布之间的统计距离。KL 散度项由 kl_div_weight 加权,以平衡其对整体损失的贡献。生成器和鉴别器优化器都利用学习率调度程序在训练期间自适应地调整学习率,从而有可能提高收敛稳定性。
for epoch in range(num_epochs):
for batch in dataloader:
real_data = batch[0].to(device)
batch_size = real_data.size(0)
optimizer_D.zero_grad()
real_labels = torch.ones(batch_size, 1,device=device)* (1 - smoothing)
real_outputs = model.discriminator(real_data)
real_loss = criterion_bce(real_outputs, real_labels)
z = torch.randn(batch_size, input_dim,device=device)
fake_data = model.generator(z).to(device)
fake_labels = torch.zeros(batch_size, 1,device=device)* smoothing
fake_outputs = model.discriminator(fake_data)
fake_loss = criterion_bce(fake_outputs, fake_labels)
d_loss = real_loss + fake_loss
d_loss.backward()
optimizer_D.step()
optimizer_G.zero_grad()
z = torch.randn(batch_size, input_dim,device=device)
scale = torch.rand(batch_size, 1, device=device) * 2 # Random scale between 0 and 2
z = z * scale
fake_data = model.generator(z)
fake_outputs = model.discriminator(fake_data)
g_loss_bce = criterion_bce(fake_outputs, real_labels)
# KL divergence
g_loss_kl = torch.abs(criterion_kl(fake_data, real_data))
g_loss = g_loss_bce + kl_div_weight * g_loss_kl
g_loss.backward()
optimizer_G.step()
scheduler_G.step()
scheduler_D.step()
print(f"Epoch [{epoch+1}/{num_epochs}], D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")
4.2 条件 Transformer GAN
给定一个真实数据分布 X 和条件空间 C,以及潜在空间 Z,条件生成器 G 和鉴别器 D 定义为:
条件生成器架构结合了潜在输入和条件信息的双重嵌入策略。
其中 Femb 是输入嵌入层,Fcond 是条件嵌入层,PE 是位置编码。在特定数据集中,settings1、settings2 和 settings3 构成条件空间。此外,它还包括一个可选的基于 CNN 的嵌入路径,可以通过卷积层处理输入数据,提供替代的特征提取机制。
条件鉴别器处理真实/生成的样本和条件:
其中 [x; c] 表示输入和条件向量的连接。条件向量 c 是从输入的前三个分量中提取的。c = x1:3,x′ = x4:n
class ConditionalTransformerGenerator(nn.Module):
def __init__(self, input_dim,condition_dim, model_dim,num_heads,num_layers,feedforward_dim):
super(ConditionalTransformerGenerator, self).__init__()
self.embedding = nn.Linear(input_dim,model_dim)
self.condition_embedding = nn.Linear(condition_dim, model_dim)
self.pos_encoding = PositionalEncoding(d_model=model_dim)
encoder_layer = nn.TransformerEncoderLayer(model_dim,num_heads,feedforward_dim,dropout=0.2)
self.transformer_encoder = nn.TransformerEncoder(encoder_layer,num_layers)
self.fc_out = nn.Linear(model_dim,input_dim)
self.cnn_embedding = nn.Sequential(
nn.Conv1d(1, 32, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.2),
nn.Conv1d(32, 64, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.2),
nn.MaxPool1d(2)
)
self.linear_proj = nn.Linear(640, model_dim)
def forward(self, x,condition):
if cnn_embeddings:
x = x.unsqueeze(1) #for CNN
emb = self.cnn_embedding(x)
emb = emb.view(emb.size(0), -1) # flatten
emb = self.linear_proj(emb)
else:
emb = self.embedding(x)
condition_emb = self.condition_embedding(condition)
x = emb + condition_emb
pe = self.pos_encoding(x)
x = x + pe
x = self.transformer_encoder(x)
return self.fc_out(x)
class ConditionalDiscriminator(nn.Module):
def __init__(self, input_dim,condition_dim):
super(ConditionalDiscriminator,self).__init__()
self.condition_dim = condition_dim
self.input_dim = input_dim
self.model = nn.Sequential(
nn.Linear(input_dim + condition_dim, 1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 128),
nn.LeakyReLU(0.2),
nn.Linear(128, 1),
nn.Sigmoid()
)
def forward(self, x,condition):
if condition.ndim == 1:
condition = condition.unsqueeze(0).repeat(x.size(0), 1)
x = torch.cat([x, condition], dim=1)
return self.model(x)
训练循环类似于 Transformer GAN。Conditional Transformer GAN 和 Vanilla Transformer GAN 训练循环之间的关键区别在于数据的处理方式。在 Conditional Transformer GAN 训练循环中,输入数据被策略性地划分为条件向量(操作设置)和目标特征(传感器值),其中前三个特征用作指导生成过程的条件。在每次训练迭代期间,这些条件向量都明确与真实数据和生成数据配对,从而影响生成器的合成过程和鉴别器的评估。反过来,鉴别器通过考虑生成的/真实的样本及其相应的条件连接在一起来评估数据的真实性,并在操作参数的背景下做出决策。这与标准的 Transformer GAN 训练循环不同,在标准 Transformer GAN 训练循环中,生成器仅对随机噪声输入进行操作,而鉴别器则在没有任何条件上下文的情况下评估数据。
5、LM- GAN
本文接下来将 Transformer GAN 的思想扩展到语言模型,特别是 SLM,并利用预先训练的 distilGPT2 语言模型作为生成器,介绍了语言模型 GAN (LM-GAN) 架构。随后的实证分析包括 KL 散度测量、主成分分析 (PCA) 和直方图,证明了 LM-GAN 的优越性。本文认为,LM GAN 架构生成的合成样本保持了原始传感器测量的统计特性和多模态特征,同时避免了模式崩溃,这是 GAN 训练中的常见挑战。
LM GAN 的公式在很大程度上类似于 Transformer GAN。此语言模型 GAN 架构中的生成器网络利用了预先训练的 DistilGPT2 模型,关键区别在于,作为生成器一部分的 distil GPT2 由 θ 参数化
其中 Z 表示输入标记空间,X 表示输出 logits 空间。生成器接受输入的 token 序列及其对应的注意力掩码,并通过由自注意力机制和前馈神经网络组成的多个转换器块对其进行处理。该模型的架构保留了 DistilGPT2 的原始配置,并修改了 token 嵌入以适应特定于制造业的词汇,包括数值和列标识符。
class Generator(nn.Module):
def __init__(self, model):
super(Generator, self).__init__()
self.model = model # pretrained distil GPT2 model
def forward(self, input_ids, attention_mask):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=input_ids,
#max_length = max_length
)
return outputs.logits.squeeze(1), outputs.loss
生成器输出为:
其中 Xsyn 表示合成数据对数,Lgen 表示语言建模损失。在前向传递过程中,生成器产生两个输出:表示序列中每个位置的词汇表概率分布的对数,以及通过教师强制计算的语言建模损失。对数形状为 (batch_size、sequence_length、vocabulary_size),捕获模型对每个标记位置的预测,而语言建模损失有助于保持在预训练期间学习的语言结构。生成器的参数在 GAN 训练期间进行微调,使其能够将其预学习的表示调整为制造传感器数据中存在的特定模式和分布,同时保留其生成连贯序列的能力。
具有参数 ϕ 的鉴别器 Dϕ 定义为:
其中: — e : Rn×v → Rn×d 是嵌入层 — hlstm : Rn×d → R2d是双向 LSTM[32] 。 fϕ 是分类层,σ 是S 型激活函数。鉴别器是一种混合架构,结合了嵌入层、双向 LSTM(256 个隐藏单元)和密集神经网络,用于区分真实数据和合成数据。最后的分类阶段由一系列具有 LeakyReLU 激活函数的密集层组成。在处理生成的样本时,鉴别器首先通过 argmax 操作将生成器的 logit 转换为 token ID。
class Discriminator(nn.Module):
def __init__(self, vocab_size):
super(Discriminator, self).__init__()
self.embedding = nn.Embedding(vocab_size, 128)
self.lstm = nn.LSTM(128, 256, batch_first=True, bidirectional=True)
self.classifier = nn.Sequential(
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, input_ids):
if input_ids.dim() == 3: # If input is logits (batch_size, sequence_length, vocab_size)
input_ids = torch.argmax(input_ids, dim=-1) # Convert to token ids
embedded = self.embedding(input_ids.int())
lstm_out, _ = self.lstm(embedded)
lstm_out = lstm_out[:, -1, :] # take last hidden state
validity = self.classifier(lstm_out)
return validity
训练过程利用 Adam 优化器[33],生成器和鉴别器网络的初始学习率均为 2e-5,并结合 ReduceLROnPlateau 调度器[34],根据损失轨迹动态调整学习率。
g_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(g_optimizer, mode='min', factor=0.5, patience=2)
d_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(d_optimizer, mode='min', factor=0.5, patience=2)
该过程从上文高级微调部分中介绍的数字标记器开始。它使用自定义分隔标记处理数值和列标识符,确保精确表示制造测量值。在每次训练迭代期间,真实数据序列首先被标记化并输入鉴别器,鉴别器学习为真实的制造数据模式分配高概率。生成器利用 DistilGPT2 架构生成合成序列,鉴别器评估这些序列,并使用二进制交叉熵计算对抗性损失。
for batch_idx, batch in enumerate(dataloader):
real_texts = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
batch_size = real_texts.size(0)
#Generator
g_optimizer.zero_grad()
gen_logits, causal_lm_loss = generator(real_texts, attention_mask)
fake_checks = discriminator(gen_logits)
r = real_texts.squeeze().float()
g = torch.argmax(gen_logits, dim=-1).float()
w_loss = wasserstein_loss(fake_checks)
k_loss = kth_order_loss(g, r, k=2)
g_loss = (
1.0 * causal_lm_loss + # lm loss
0.8 * w_loss + # Wasserstein loss
0.2 * k_loss # k-th order loss for numerical accuracy
)
g_loss.backward()
torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)
g_optimizer.step()
# Discriminator
d_optimizer.zero_grad()
real_checks = discriminator(real_texts)
fake_checks = discriminator(g)
d_loss = wasserstein_loss(fake_checks) - wasserstein_loss(real_checks)
gradient_penalty = compute_gradient_penalty(discriminator, r, g)
d_loss += 10 * gradient_penalty
torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0)
d_optimizer.step()
d_loss.backward()
训练过程结合了生成器的教师强制,其中语言建模损失与对抗性损失相结合,以在适应制造数据分布的同时保持连贯的序列生成。 ReduceLROnPlateau 调度程序监控损失指标(耐心 2 个时期),当改进停滞时将学习率降低 0.5 倍,有助于稳定训练并防止震荡。梯度惩罚计算和梯度剪裁是 LM-GAN 训练过程中至关重要的稳定机制。梯度惩罚是通过在真实样本和生成样本之间进行插值来计算的。此外,torch.nn.utils.clip_grad_norm_ 应用于生成器和鉴别器参数,最大范数阈值为 1.0,防止梯度爆炸。这种双重优化过程与专门的标记化策略相结合,使模型能够学习数据的统计特性和不同传感器测量之间的潜在关系,而调度机制确保整个训练过程中两个网络的稳健收敛。
合成数据和原始制造数据之间的比较分析表明,LM-GAN 模型在捕捉传感器测量的统计特性和潜在模式方面具有出色的能力。分布图揭示了所有 21 个传感器通道中真实数据和合成数据之间的一致性,其中合成数据(以橙色虚线显示)紧跟原始测量(蓝色实线)的时间动态。
直方图比较显示出不同传感器范围内的一致分布匹配,在操作设置(s1-s3)和关键传感器测量中尤为明显。
值得注意的是,PCA 可视化展示了数据流形的出色保存,合成样本(红点)与主成分中的真实数据点(蓝点)紧密混合,表明该模型已成功捕获不同传感器测量之间的复杂相关性。
这项综合评估验证了 LM-GAN 架构不仅保留了各个传感器的边际分布,而且还保持了制造过程数据中存在的复杂关系和操作模式。
6、消融研究
为了定量评估合成数据生成方法的保真度,我们进行了全面的消融研究,测量原始数据集和合成数据集之间的统计相似性。分析采用三个主要评估框架:
比较原始数据集和合成数据集之间的统计测量(最小值、最大值、平均值和标准差):
聚合保真度得分 Kolmogorov-Smirnov 检验和 Kolomogorov Smirnov 得分:
通过相关矩阵比较评估特征关系的保存。以下是 LM-GAN 架构生成的原始分布和合成分布之间的相关矩阵:
7、结束语
这个由两部分组成的研究系列对工业制造应用的合成数据生成技术进行了全面分析。第一部分通过评估包括 GAN、VAE、高斯 Copula、贝叶斯网络和 CTGAN 在内的传统生成模型建立了基线性能指标,展示了贝叶斯网络和高斯 Copula 等概率方法在保留统计分布方面的卓越性能。
第二部分通过介绍用于合成数据生成的小型语言模型 (SLM) 的新应用来推动该领域的发展。研究从基础的即时工程发展到像 LM-GAN 这样的复杂架构,展示了合成数据保真度不断发展的能力。比较分析表明,使用自定义损失函数的高级微调显著改善了基本的 SLM 方法,而提出的 LM-GAN 架构在保留边际分布和复杂的特征间关系方面实现了最先进的性能。
两项研究的主要发现表明:
- 传统概率模型为制造数据合成提供了强大的基线性能
- 基于 SLM 的方法在捕获特定领域约束方面提供了增强的灵活性
- 将语言模型与对抗性训练相结合的混合架构在保持统计特性方面表现出卓越的性能
- 提出的 LM-GAN 框架成功解决了传统 GAN 中常见的模式崩溃问题,同时保持了数据保真度
汇智网翻译整理,转载请标明出处