Support whisper large/large-v1/large-v2/large-v3 and distil-large-v2 (#1114)
This commit is contained in:
6
scripts/whisper/.gitignore
vendored
6
scripts/whisper/.gitignore
vendored
@@ -2,3 +2,9 @@
|
||||
*.config
|
||||
*.ort
|
||||
*-tokens.txt
|
||||
*.bias
|
||||
*.weights
|
||||
*.weight
|
||||
*.*embedding
|
||||
_Const*
|
||||
onnx__*
|
||||
|
||||
@@ -32,6 +32,9 @@ from whisper.model import (
|
||||
TextDecoder,
|
||||
)
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -43,8 +46,9 @@ def get_args():
|
||||
choices=[
|
||||
"tiny", "tiny.en", "base", "base.en",
|
||||
"small", "small.en", "medium", "medium.en",
|
||||
"large", "large-v1", "large-v2",
|
||||
"large", "large-v1", "large-v2", "large-v3",
|
||||
"distil-medium.en", "distil-small.en", "distil-large-v2",
|
||||
# "distil-large-v3", # distil-large-v3 is not supported!
|
||||
# for fine-tuned models from icefall
|
||||
"medium-aishell",
|
||||
],
|
||||
@@ -63,12 +67,26 @@ def add_meta_data(filename: str, meta_data: Dict[str, Any]):
|
||||
Key-value pairs.
|
||||
"""
|
||||
model = onnx.load(filename)
|
||||
|
||||
while len(model.metadata_props):
|
||||
model.metadata_props.pop()
|
||||
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = str(value)
|
||||
|
||||
onnx.save(model, filename)
|
||||
if "large" in filename:
|
||||
external_filename = filename.split(".onnx")[0]
|
||||
onnx.save(
|
||||
model,
|
||||
filename,
|
||||
save_as_external_data=True,
|
||||
all_tensors_to_one_file=True,
|
||||
location=external_filename + ".weights",
|
||||
)
|
||||
else:
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
def modified_audio_encoder_forward(self: AudioEncoder, x: torch.Tensor):
|
||||
@@ -376,7 +394,9 @@ def main():
|
||||
|
||||
# write tokens
|
||||
|
||||
tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
|
||||
tokenizer = whisper.tokenizer.get_tokenizer(
|
||||
model.is_multilingual, num_languages=model.num_languages
|
||||
)
|
||||
|
||||
model.eval()
|
||||
print(model.dims)
|
||||
@@ -384,10 +404,15 @@ def main():
|
||||
audio = whisper.pad_or_trim(audio)
|
||||
assert audio.shape == (16000 * 30,), audio.shape
|
||||
|
||||
# make log-Mel spectrogram and move to the same device as the model
|
||||
mel = whisper.log_mel_spectrogram(audio).to(model.device).unsqueeze(0)
|
||||
if args.model in ("large", "large-v3"):
|
||||
n_mels = 128
|
||||
else:
|
||||
n_mels = 80
|
||||
mel = (
|
||||
whisper.log_mel_spectrogram(audio, n_mels=n_mels).to(model.device).unsqueeze(0)
|
||||
)
|
||||
batch_size = 1
|
||||
assert mel.shape == (batch_size, 80, 30 * 100)
|
||||
assert mel.shape == (batch_size, n_mels, 30 * 100), mel.shape
|
||||
|
||||
encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
|
||||
|
||||
@@ -546,6 +571,17 @@ def main():
|
||||
},
|
||||
)
|
||||
|
||||
if "large" in args.model:
|
||||
decoder_external_filename = decoder_filename.split(".onnx")[0]
|
||||
decoder_model = onnx.load(decoder_filename)
|
||||
onnx.save(
|
||||
decoder_model,
|
||||
decoder_filename,
|
||||
save_as_external_data=True,
|
||||
all_tensors_to_one_file=True,
|
||||
location=decoder_external_filename + ".weights",
|
||||
)
|
||||
|
||||
if "large" in args.model:
|
||||
# it causes errors for large models, so skip it.
|
||||
return
|
||||
|
||||
@@ -9,9 +9,10 @@ import base64
|
||||
from typing import Tuple
|
||||
|
||||
import kaldi_native_fbank as knf
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import soundfile as sf
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
|
||||
def get_args():
|
||||
@@ -98,7 +99,6 @@ class OnnxModel:
|
||||
self.blank = int(meta["blank_id"])
|
||||
|
||||
self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
|
||||
|
||||
self.sot_sequence.append(self.no_timestamps)
|
||||
|
||||
self.all_language_tokens = list(
|
||||
@@ -226,7 +226,18 @@ def load_tokens(filename):
|
||||
return tokens
|
||||
|
||||
|
||||
def compute_features(filename: str) -> torch.Tensor:
|
||||
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
|
||||
data, sample_rate = sf.read(
|
||||
filename,
|
||||
always_2d=True,
|
||||
dtype="float32",
|
||||
)
|
||||
data = data[:, 0] # use only the first channel
|
||||
samples = np.ascontiguousarray(data)
|
||||
return samples, sample_rate
|
||||
|
||||
|
||||
def compute_features(filename: str, dim: int = 80) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
filename:
|
||||
@@ -234,16 +245,18 @@ def compute_features(filename: str) -> torch.Tensor:
|
||||
Returns:
|
||||
Return a 1-D float32 tensor of shape (1, 80, 3000) containing the features.
|
||||
"""
|
||||
wave, sample_rate = torchaudio.load(filename)
|
||||
audio = wave[0].contiguous() # only use the first channel
|
||||
wave, sample_rate = load_audio(filename)
|
||||
if sample_rate != 16000:
|
||||
audio = torchaudio.functional.resample(
|
||||
audio, orig_freq=sample_rate, new_freq=16000
|
||||
)
|
||||
import librosa
|
||||
|
||||
wave = librosa.resample(wave, orig_sr=sample_rate, target_sr=16000)
|
||||
sample_rate = 16000
|
||||
|
||||
features = []
|
||||
online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
|
||||
online_whisper_fbank.accept_waveform(16000, audio.numpy())
|
||||
opts = knf.WhisperFeatureOptions()
|
||||
opts.dim = dim
|
||||
online_whisper_fbank = knf.OnlineWhisperFbank(opts)
|
||||
online_whisper_fbank.accept_waveform(16000, wave)
|
||||
online_whisper_fbank.input_finished()
|
||||
for i in range(online_whisper_fbank.num_frames_ready):
|
||||
f = online_whisper_fbank.get_frame(i)
|
||||
@@ -280,8 +293,9 @@ def compute_features(filename: str) -> torch.Tensor:
|
||||
def main():
|
||||
args = get_args()
|
||||
|
||||
mel = compute_features(args.sound_file)
|
||||
model = OnnxModel(args.encoder, args.decoder)
|
||||
dim = 80 if "large-v3" not in args.encoder else 128
|
||||
mel = compute_features(args.sound_file, dim=dim)
|
||||
|
||||
n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
|
||||
|
||||
@@ -313,6 +327,7 @@ def main():
|
||||
|
||||
n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()
|
||||
|
||||
print(model.sot_sequence)
|
||||
tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
|
||||
offset = torch.zeros(1, dtype=torch.int64)
|
||||
logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
|
||||
|
||||
Reference in New Issue
Block a user