Support whisper large/large-v1/large-v2/large-v3 and distil-large-v2 (#1114)

This commit is contained in:
Fangjun Kuang
2024-07-12 23:47:39 +08:00
committed by GitHub
parent d928f77d0e
commit 117cd7bb8c
23 changed files with 152 additions and 85 deletions

View File

@@ -32,6 +32,9 @@ from whisper.model import (
TextDecoder,
)
# Restrict PyTorch to a single intra-op thread and a single inter-op thread
# on CPU for the rest of this script.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
parser = argparse.ArgumentParser()
@@ -43,8 +46,9 @@ def get_args():
choices=[
"tiny", "tiny.en", "base", "base.en",
"small", "small.en", "medium", "medium.en",
"large", "large-v1", "large-v2",
"large", "large-v1", "large-v2", "large-v3",
"distil-medium.en", "distil-small.en", "distil-large-v2",
# "distil-large-v3", # distil-large-v3 is not supported!
# for fine-tuned models from icefall
"medium-aishell",
],
@@ -63,12 +67,26 @@ def add_meta_data(filename: str, meta_data: Dict[str, Any]):
Key-value pairs.
"""
model = onnx.load(filename)
while len(model.metadata_props):
model.metadata_props.pop()
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
onnx.save(model, filename)
if "large" in filename:
external_filename = filename.split(".onnx")[0]
onnx.save(
model,
filename,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=external_filename + ".weights",
)
else:
onnx.save(model, filename)
def modified_audio_encoder_forward(self: AudioEncoder, x: torch.Tensor):
@@ -376,7 +394,9 @@ def main():
# write tokens
tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
tokenizer = whisper.tokenizer.get_tokenizer(
model.is_multilingual, num_languages=model.num_languages
)
model.eval()
print(model.dims)
@@ -384,10 +404,15 @@ def main():
audio = whisper.pad_or_trim(audio)
assert audio.shape == (16000 * 30,), audio.shape
# make log-Mel spectrogram and move to the same device as the model
mel = whisper.log_mel_spectrogram(audio).to(model.device).unsqueeze(0)
if args.model in ("large", "large-v3"):
n_mels = 128
else:
n_mels = 80
mel = (
whisper.log_mel_spectrogram(audio, n_mels=n_mels).to(model.device).unsqueeze(0)
)
batch_size = 1
assert mel.shape == (batch_size, 80, 30 * 100)
assert mel.shape == (batch_size, n_mels, 30 * 100), mel.shape
encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
@@ -546,6 +571,17 @@ def main():
},
)
if "large" in args.model:
decoder_external_filename = decoder_filename.split(".onnx")[0]
decoder_model = onnx.load(decoder_filename)
onnx.save(
decoder_model,
decoder_filename,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=decoder_external_filename + ".weights",
)
if "large" in args.model:
# it causes errors for large models, so skip it.
return