Support whisper large/large-v1/large-v2/large-v3 and distil-large-v2 (#1114)
This commit is contained in:
@@ -32,6 +32,9 @@ from whisper.model import (
|
||||
TextDecoder,
|
||||
)
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -43,8 +46,9 @@ def get_args():
|
||||
choices=[
|
||||
"tiny", "tiny.en", "base", "base.en",
|
||||
"small", "small.en", "medium", "medium.en",
|
||||
"large", "large-v1", "large-v2",
|
||||
"large", "large-v1", "large-v2", "large-v3",
|
||||
"distil-medium.en", "distil-small.en", "distil-large-v2",
|
||||
# "distil-large-v3", # distil-large-v3 is not supported!
|
||||
# for fine-tuned models from icefall
|
||||
"medium-aishell",
|
||||
],
|
||||
@@ -63,12 +67,26 @@ def add_meta_data(filename: str, meta_data: Dict[str, Any]):
|
||||
Key-value pairs.
|
||||
"""
|
||||
model = onnx.load(filename)
|
||||
|
||||
while len(model.metadata_props):
|
||||
model.metadata_props.pop()
|
||||
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = str(value)
|
||||
|
||||
onnx.save(model, filename)
|
||||
if "large" in filename:
|
||||
external_filename = filename.split(".onnx")[0]
|
||||
onnx.save(
|
||||
model,
|
||||
filename,
|
||||
save_as_external_data=True,
|
||||
all_tensors_to_one_file=True,
|
||||
location=external_filename + ".weights",
|
||||
)
|
||||
else:
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
def modified_audio_encoder_forward(self: AudioEncoder, x: torch.Tensor):
|
||||
@@ -376,7 +394,9 @@ def main():
|
||||
|
||||
# write tokens
|
||||
|
||||
tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
|
||||
tokenizer = whisper.tokenizer.get_tokenizer(
|
||||
model.is_multilingual, num_languages=model.num_languages
|
||||
)
|
||||
|
||||
model.eval()
|
||||
print(model.dims)
|
||||
@@ -384,10 +404,15 @@ def main():
|
||||
audio = whisper.pad_or_trim(audio)
|
||||
assert audio.shape == (16000 * 30,), audio.shape
|
||||
|
||||
# make log-Mel spectrogram and move to the same device as the model
|
||||
mel = whisper.log_mel_spectrogram(audio).to(model.device).unsqueeze(0)
|
||||
if args.model in ("large", "large-v3"):
|
||||
n_mels = 128
|
||||
else:
|
||||
n_mels = 80
|
||||
mel = (
|
||||
whisper.log_mel_spectrogram(audio, n_mels=n_mels).to(model.device).unsqueeze(0)
|
||||
)
|
||||
batch_size = 1
|
||||
assert mel.shape == (batch_size, 80, 30 * 100)
|
||||
assert mel.shape == (batch_size, n_mels, 30 * 100), mel.shape
|
||||
|
||||
encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
|
||||
|
||||
@@ -546,6 +571,17 @@ def main():
|
||||
},
|
||||
)
|
||||
|
||||
if "large" in args.model:
|
||||
decoder_external_filename = decoder_filename.split(".onnx")[0]
|
||||
decoder_model = onnx.load(decoder_filename)
|
||||
onnx.save(
|
||||
decoder_model,
|
||||
decoder_filename,
|
||||
save_as_external_data=True,
|
||||
all_tensors_to_one_file=True,
|
||||
location=decoder_external_filename + ".weights",
|
||||
)
|
||||
|
||||
if "large" in args.model:
|
||||
# it causes errors for large models, so skip it.
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user