Support whisper large/large-v1/large-v2/large-v3 and distil-large-v2 (#1114)

Fangjun Kuang
2024-07-12 23:47:39 +08:00
committed by GitHub
parent d928f77d0e
commit 117cd7bb8c
23 changed files with 152 additions and 85 deletions


@@ -2,3 +2,9 @@
*.config
*.ort
*-tokens.txt
*.bias
*.weights
*.weight
*.*embedding
_Const*
onnx__*


@@ -32,6 +32,9 @@ from whisper.model import (
TextDecoder,
)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
parser = argparse.ArgumentParser()
@@ -43,8 +46,9 @@ def get_args():
choices=[
"tiny", "tiny.en", "base", "base.en",
"small", "small.en", "medium", "medium.en",
"large", "large-v1", "large-v2",
"large", "large-v1", "large-v2", "large-v3",
"distil-medium.en", "distil-small.en", "distil-large-v2",
# "distil-large-v3", # distil-large-v3 is not supported!
# for fine-tuned models from icefall
"medium-aishell",
],
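For the checkpoints added here, the functional difference that matters later in the script is the number of mel bins. A hedged summary of what main() selects further down (model_name stands in for args.model):

# 128 mel bins for large-v3 and for "large", which openai-whisper currently aliases to large-v3;
# 80 mel bins for every other supported checkpoint, including distil-large-v2.
n_mels = 128 if model_name in ("large", "large-v3") else 80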
@@ -63,12 +67,26 @@ def add_meta_data(filename: str, meta_data: Dict[str, Any]):
Key-value pairs.
"""
model = onnx.load(filename)
while len(model.metadata_props):
model.metadata_props.pop()
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
onnx.save(model, filename)
if "large" in filename:
external_filename = filename.split(".onnx")[0]
onnx.save(
model,
filename,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=external_filename + ".weights",
)
else:
onnx.save(model, filename)
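Models whose name contains "large" exceed the 2 GB protobuf limit, so their tensors are spilled into a separate <name>.weights file via save_as_external_data, which is what the new *.weights entry in the .gitignore hunk above covers. A small sketch for confirming that the metadata survives the round trip, with a placeholder file name:

import onnx

# onnx.load pulls in external tensor data automatically when the *.weights
# file sits next to the model, so the same check works for small and large
# exports alike.
m = onnx.load("tiny.en-encoder.onnx")
print({p.key: p.value for p in m.metadata_props})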
def modified_audio_encoder_forward(self: AudioEncoder, x: torch.Tensor):
@@ -376,7 +394,9 @@ def main():
# write tokens
tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
tokenizer = whisper.tokenizer.get_tokenizer(
model.is_multilingual, num_languages=model.num_languages
)
model.eval()
print(model.dims)
@@ -384,10 +404,15 @@ def main():
audio = whisper.pad_or_trim(audio)
assert audio.shape == (16000 * 30,), audio.shape
# make log-Mel spectrogram and move to the same device as the model
mel = whisper.log_mel_spectrogram(audio).to(model.device).unsqueeze(0)
if args.model in ("large", "large-v3"):
n_mels = 128
else:
n_mels = 80
mel = (
whisper.log_mel_spectrogram(audio, n_mels=n_mels).to(model.device).unsqueeze(0)
)
batch_size = 1
assert mel.shape == (batch_size, 80, 30 * 100)
assert mel.shape == (batch_size, n_mels, 30 * 100), mel.shape
encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
@@ -546,6 +571,17 @@ def main():
},
)
if "large" in args.model:
decoder_external_filename = decoder_filename.split(".onnx")[0]
decoder_model = onnx.load(decoder_filename)
onnx.save(
decoder_model,
decoder_filename,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=decoder_external_filename + ".weights",
)
if "large" in args.model:
# it causes errors for large models, so skip it.
return
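Everything after this early return is skipped for the large models, so a quick manual sanity check is to load the exported pair directly; onnxruntime resolves the external *.weights files on its own as long as they stay next to their .onnx files. A minimal sketch with placeholder file names:

import onnxruntime as ort

encoder = ort.InferenceSession("large-v3-encoder.onnx", providers=["CPUExecutionProvider"])
decoder = ort.InferenceSession("large-v3-decoder.onnx", providers=["CPUExecutionProvider"])
print([i.name for i in encoder.get_inputs()])
print([i.name for i in decoder.get_inputs()])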


@@ -9,9 +9,10 @@ import base64
from typing import Tuple
import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch
import torchaudio
def get_args():
@@ -98,7 +99,6 @@ class OnnxModel:
self.blank = int(meta["blank_id"])
self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
self.sot_sequence.append(self.no_timestamps)
self.all_language_tokens = list(
@@ -226,7 +226,18 @@ def load_tokens(filename):
return tokens
def compute_features(filename: str) -> torch.Tensor:
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
data, sample_rate = sf.read(
filename,
always_2d=True,
dtype="float32",
)
data = data[:, 0] # use only the first channel
samples = np.ascontiguousarray(data)
return samples, sample_rate
def compute_features(filename: str, dim: int = 80) -> torch.Tensor:
"""
Args:
filename:
@@ -234,16 +245,18 @@ def compute_features(filename: str) -> torch.Tensor:
Returns:
Return a 3-D float32 tensor of shape (1, dim, 3000) containing the features.
"""
wave, sample_rate = torchaudio.load(filename)
audio = wave[0].contiguous() # only use the first channel
wave, sample_rate = load_audio(filename)
if sample_rate != 16000:
audio = torchaudio.functional.resample(
audio, orig_freq=sample_rate, new_freq=16000
)
import librosa
wave = librosa.resample(wave, orig_sr=sample_rate, target_sr=16000)
sample_rate = 16000
features = []
online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
online_whisper_fbank.accept_waveform(16000, audio.numpy())
opts = knf.WhisperFeatureOptions()
opts.dim = dim
online_whisper_fbank = knf.OnlineWhisperFbank(opts)
online_whisper_fbank.accept_waveform(16000, wave)
online_whisper_fbank.input_finished()
for i in range(online_whisper_fbank.num_frames_ready):
f = online_whisper_fbank.get_frame(i)
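With dim plumbed through, the caller has to request the feature size that matches the model: 128 mel bins for large-v3, 80 for all other checkpoints. A small usage sketch with a placeholder test.wav; the (1, dim, 3000) output shape follows the docstring above:

mel_80 = compute_features("test.wav", dim=80)    # tiny ... large-v2, distil-*
mel_128 = compute_features("test.wav", dim=128)  # large-v3
assert mel_80.shape == (1, 80, 3000), mel_80.shape
assert mel_128.shape == (1, 128, 3000), mel_128.shape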
@@ -280,8 +293,9 @@ def compute_features(filename: str) -> torch.Tensor:
def main():
args = get_args()
mel = compute_features(args.sound_file)
model = OnnxModel(args.encoder, args.decoder)
dim = 80 if "large-v3" not in args.encoder else 128
mel = compute_features(args.sound_file, dim=dim)
n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
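The dim heuristic above keys off the encoder file name containing "large-v3". An alternative sketch, not what the script does, is to read the mel-bin count straight from the encoder's declared input shape, reusing the onnxruntime import already at the top of the script; this assumes the mel axis was exported with a fixed size rather than a dynamic one:

sess = ort.InferenceSession(args.encoder, providers=["CPUExecutionProvider"])
# The encoder input is laid out as (batch, n_mels, num_frames); shape[1] is a
# plain int only if that axis is static, otherwise it comes back as a symbolic name.
n_mels = sess.get_inputs()[0].shape[1]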
@@ -313,6 +327,7 @@ def main():
n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()
print(model.sot_sequence)
tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
offset = torch.zeros(1, dtype=torch.int64)
logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(