5个Whisper变体实现与比较
最近,我研究自动语音识别 (ASR),以便从语音数据中进行转录。说到开源 ASR 模型,OpenAI 开发的 Whisper [1] 可能是最佳选择,因为它的转录准确度很高。但是,Whisper 有很多变体,所以我想比较一下它们的特点。
在这篇博客中,我将快速回顾 Whisper,并介绍变体以及如何在 Python 中实现它们。我将解释 vanilla Whisper、Faster Whisper、Whisper X、Distil-Whisper 和 Whisper-Medusa。
1、什么是 Whisper?
Whisper [1] 是 OpenAI 开发的自动语音识别 (ASR) 模型。它基于 680,000 小时的多语言和多任务监督数据进行训练,包括转录、翻译、语音活动检测、对齐和语言识别。在 Whisper 出现之前,还没有以监督方式训练过如此大量数据的模型。关于架构,Whisper 采用了编码器-解码器 Transformer 以实现可扩展性。架构图如下所示。
首先,Whisper 将音频数据转换为对数梅尔频谱图。对数梅尔频谱图是梅尔标度中信号频率频谱的视觉表示,常用于语音处理和机器学习任务。有关更多信息,你可以查看此博客 [2]。在 Whisper 将对数梅尔频谱图输入到一些 1-D 卷积层和位置编码后,它会以类似于自然语言处理 Transformer 的方式处理数据。 Whisper 可以在多语言环境中工作,以利用 GPT-2 使用的字节级 BPE 标记器。得益于多任务学习,Whisper 还可以执行转录、时间戳检测和翻译。
官方 Whisper 有六种模型大小,其中四种只有英文版本,提供速度和准确性的权衡。较小的模型仅提供英文功能。
就在最近(2024/10),OpenAI 发布了新版本“turbo”,它具有与大尺寸模型几乎相同的功能,但通过微调修剪后的大尺寸模型,提供了显着的速度提升(8 倍!)。所有 Whisper 模型都与 HuggingFace Transformer 库兼容。
现在,我们快速回顾一下 Whisper。它基于编码器-解码器 Transformer 架构,性能出色,甚至包括在商业模型中。在下一节中,我们将讨论 Whisper 变体。
2、Whisper 变体
在本节中,我们将介绍 Whisper 变体及其功能。我重点介绍 Python 和 Pytorch 实现。虽然 Whisper.cpp 和 Whisper JAX 是流行的变体,但我不会研究它们。此外,Whisper-streaming 也是实时推理的流行变体,但它需要高端 GPU,所以我也不会讨论它。我们将检查 Faster-Whisper、Whisper X、Distil-Whisper 和 Whisper-Medusa。
2.1 Faster-Whisper
Faster-Whisper 是使用 CTranslate2 对 Whisper 的重新实现,CTranslate2 是一个 C++ 和 Python 库,用于使用 Transformer 模型进行高效推理。因此,架构没有变化。
根据官方存储库,Faster-Whisper 的速度可以比原始实现快约 4 倍,同时保持相同的准确率,同时占用更少的内存。简而言之,Ctranslate2 采用了许多优化技术,例如权重量化、层融合、批量重新排序等。我们可以根据机器类型选择类型选项,例如 float16 或 int8;例如,当我们选择 int8 时,我们甚至可以在 CPU 上运行 Whisper。
2.2 WhisperX
WhisperX [3] 也是一个集成 Faster-Whisper 的高效语音转录系统。虽然 vanilla Whisper 经过多项任务训练,包括时间戳预测,但它很容易对单词级时间戳不准确。此外,由于其顺序推理性质,vanilla Whisper 通常需要计算时间来处理长格式的音频输入。为了克服这些弱点,WhisperX 引入了三个额外的阶段:(1)语音活动检测 (VAD),(2)VAD 的剪切和合并结果,以及(3)与外部音素模型强制对齐以提供准确的单词级时间戳。架构图如下所示:
首先,WhisperX 通过 VAD 层处理输入音频。顾名思义,VAD 检测语音片段。WhisperX 利用 pyannote-audio 库中的分段模型进行 VAD。接下来,WhisperX 剪切和合并检测到的语音分段。此过程允许我们根据每个剪切结果运行批量推理。最后,WhisperX 应用强制对齐来测量单词级准确的时间戳。让我们看一个具体的例子,如下所示:
它利用 Whisper 进行转录,利用 Phoneme 模型进行音素级转录。音素模型可以检测每个音素的时间戳;因此,如果我们从 Whisper 转录中的下一个最近的音素分配时间戳,我们可以为每个单词获得更准确的时间戳。
尽管与 vanilla Whisper 相比,WhisperX 增加了三个额外的过程,但由于批量推理,它可以有效地转录更长的音频。下表显示了性能比较,你可以检查 WhisperX 保持较低的 WER 但提高了推理速度:
2.3 Distil-Whisper
Distil-Whisper [4] 由 HuggingFace 于 2023 年开发。它是一个使用知识蒸馏压缩 Whipser Large 模型的模型。它利用常识蒸馏技术来训练较小的模型,例如来自 Whisper Large 模型的伪标签和 Kullback-Leibler 散度损失。架构图如下所示:
该架构与 vanilla Whisper 配对,但层数减少了。对于数据集,作者从互联网上收集了 21,170 小时的公开数据来训练 Distil-Whisper。Distil-Whisper 的速度比 Whisper Large 模型快 5.8 倍,参数减少了 51%,同时在分布外数据上的字错误率 (WER) 为 1%。下表显示了性能比较:
如你所见,Distil-Whisper 将字错误率保持在与 vanilla Whisper 一样低的水平,但可以降低延迟。
2.4 Whisper-Medusa
Whisper-Medusa [5] 是利用 Medusa 提高 Whisper 推理速度的变体。Medusa 是一种高效的 LLM 推理方法,它增加了额外的解码头来并行预测多个后续 token。你可以使用下图很好地理解。
在左侧部分,Medusa 有三个额外的头来预测后续 token。如果原始模型输出 y1 token,则三个额外的头会预测 y2、y3 和 y4 token。Medusa 可以通过添加额外的头来增加预测数量,并减少整体推理时间。请注意,由于增加了头,所需的 VRAM 量会增加。
Whisper-Medusa 将 Medusa 的理念应用于 Whisper,如右侧所示。由于 Whisper 具有顺序推理特性,因此推理速度较慢,而 Medusa 的功能有助于加快推理速度。Whisper-Medusa 和 vanilla Whisper 之间的比较结果如下所示:
对于几种语言数据集,Whisper-Medusa 记录的单词错误率 (WER) 低于 vanilla Whisper。它的平均速度也可以提高 1.5 倍。
在本节中,我们将检查 Whisper 变体及其功能。下一节将探讨如何在 Python 中实现它们并检查它们对真实音频的能力。
3、Whisper 变体的实现
在本节中,我们将学习如何在 Python 中实现 Whisper 和 Whisper 变体。对于真实音频数据,我将使用手动下载的这个 YouTube 视频中的音频。视频大小约为 14 分钟。稍后我将附上如何将 mp4 文件转换为 mp3 文件的代码。
3.1 环境设置
由于库不兼容,我们创建了两个环境:一个用于 Whipser、Faster-Whisper、WhisperX 和 Distil-Whisper,另一个用于 Whisper-Medusa。
对于前一个环境,我使用了带有 Python 3.10 的 conda 环境。我在 Ubuntu 20.04 上使用 cuda 12.0、16 GB VRAM 进行了实验:
conda create -n audioenv python=3.10 -y
conda activate audioenv
接下来,我们需要通过 pip 和 conda 安装下面的库。完成下面的安装后,你需要将 numpy 降级到 1.26.3:
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
pip install python-dotenv moviepy openai-whisper accelerate datasets[audio]
pip install numpy==1.26.3
接下来,我们需要安装 whisperX 存储库。但是,到目前为止,whisperX 不再经常维护。因此,我们使用名为 BetterWhisperX 的分叉存储库:
git clone https://github.com/federicotorrielli/BetterWhisperX.git
cd BetterWhisperX
pip install -e .
第一个环境准备已完成。
对于 Whisper-Medusa 环境,我使用了带有 Python 3.11 的 conda 环境。我还在 Ubuntu 20.04 上进行了实验,使用 cuda 12.0、24 GB VRAM:
conda create -n medusa python=3.11 -y
conda activate medusa
你需要通过 pip 安装以下库:
pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118
pip install wandb
git clone https://github.com/aiola-lab/whisper-medusa.git
cd whisper-medusa
pip install -e .
所有准备工作都已完成。现在,让我们检查一下 Whisper 变体的功能!
3.2 Whisper turbo
我们使用最新版本的 Whisper turbo。感谢官方存储库,我们只需几行代码即可实现 vanilla Whisper。
import whisper
model = whisper.load_model("turbo")
result = model.transcribe("audio.mp3")
Whisper 只能处理 30 秒内的音频数据,但 transcribe 方法会读取整个文件并使用滑动的 30 秒窗口处理音频,因此我们不关心如何提供数据。
3.3 Faster-Whisper
我们使用 Faster-Whisper 的 Whisper turbo 主干。Faster-Whisper 有原始存储库,我们可以按如下方式实现它:
from faster_whisper import WhisperModel
model_size = "deepdml/faster-whisper-large-v3-turbo-ct2"
# Run on GPU with FP16
model = WhisperModel(model_size_or_path=model_size, device="cuda", compute_type="float16")
segments, info = model.transcribe('audio.mp3', beam_size=5)
beam_size
用于解码时的波束搜索。由于 Faster-Whisper 的功能与 vanilla Whisper 相同,我们可以使用滑动窗口处理长格式音频。
3.4 WhisperX
我们使用 WhisperX 的 Whisper turbo 主干。由于 WhisperX 使用 Faster-Whisper 作为主干,因此部分代码是共享的:
import whisperx
model_size = "deepdml/faster-whisper-large-v3-turbo-ct2"
# Transcribe with original whisper (batched)
model = whisperx.load_model(model_size, 'cuda', compute_type="float16")
model_a, metadata = whisperx.load_align_model(language_code='en', device='cuda')
# inference
audio = whisperx.load_audio('audio.mp3')
whisper_result = model.transcribe(audio, batch_size=16)
result = whisperx.align(whisper_result["segments"], model_a, metadata, audio, 'cuda', return_char_alignments=False)
WhisperX 与 Faster-Whisper 集成,并添加了处理 VAD 和强制对齐的附加层。由于剪切和合并,我们还可以处理超过 30 秒的长格式音频。
3.5 Distil-Whisper
我们将使用 large-v3 模型的精简版本,因为最新的 turbo 版本尚未发布。Distil-Whisper 与 HuggingFace Transformer 库兼容,因此我们可以轻松实现它:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
return_timestamps=True
)
result = pipe('audio.mp3')
pipeline 类使用滑动窗口自动处理长格式音频。请注意,此方法仅输出相对时间戳。
3.6 Whisper-Medusa
我们使用大型模型作为 Whisper 主干。按照官方的实现,我们可以按如下方式实现:
import torch
import torchaudio
from whisper_medusa import WhisperMedusaModel
from transformers import WhisperProcessor
SAMPLING_RATE = 16000
language = "en"
regulation_factor=1.01
regulation_start=140
device = 'cuda'
model_name = "aiola/whisper-medusa-linear-libri"
model = WhisperMedusaModel.from_pretrained(model_name)
processor = WhisperProcessor.from_pretrained(model_name)
model = model.to(device)
input_speech, sr = torchaudio.load(audio_path)
if input_speech.shape[0] > 1: # If stereo, average the channels
input_speech = input_speech.mean(dim=0, keepdim=True)
if sr != SAMPLING_RATE:
input_speech = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(input_speech)
exponential_decay_length_penalty = (regulation_start, regulation_factor)
input_features = processor(input_speech.squeeze(), return_tensors="pt", sampling_rate=SAMPLING_RATE).input_features
input_features = input_features.to(device)
model_output = model.generate(
input_features,
language=language,
exponential_decay_length_penalty=exponential_decay_length_penalty,
)
predict_ids = model_output[0]
pred = processor.decode(predict_ids, skip_special_tokens=True)
遗憾的是,Whisper-Medusa 目前不支持长格式音频转录,因此我们只能将其用于最长 30 秒的音频数据。当我检查 30 秒转录的质量时,它不如其他变体那么好。因此,我跳过了与其他 Whisper 变体的比较结果。
4、Whisper 变体性能比较
正如我之前提到的,我使用大约 14 分钟的音频文件作为输入。下表比较了每个模型的结果:
总结一下,
- Whisper turbo 有时倾向于把相同的句子和幻觉放在一起。
- Faster-Whisper 转录几乎不错,计算速度最好。
- WhisperX 转录最好,它记录了一个非常准确的时间戳。
- Distil-Whisper 转录几乎不错。但是,它只记录相对时间戳。
如果你可以允许细微的错误转录,并且不关心时间戳,那么你应该使用 Faster-Whisper。同时,如果你想知道准确的时间戳和转录,你应该使用 WhisperX。
WhisperX 和 Faster-Whisper 可以比 vanilla Whisper 获得更好的结果,可能是因为 Faster-Whisper 有波束搜索以获得更好的推理结果,而 Whisper X 有强制对齐。因此,他们有机会在后处理中修复错误转录。
5、结束语
在这篇博客中,我们了解了 Whisper 变体的架构及其在 Python 中的实现。许多研究人员使用各种优化技术来最大限度地降低实际应用的推理速度。根据我的调查,Faster-Whisper 和 WhisperX 保留了该功能,但成功降低了推理速度。
这是我在这次实验中使用的代码笔记本。
汇智网翻译整理,转载请标明出处