Support multilingual whisper models (#274)
This commit is contained in:
@@ -11,6 +11,7 @@ for making the onnx export script public.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -250,6 +251,7 @@ def main():
|
||||
# write tokens
|
||||
|
||||
tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
|
||||
|
||||
model.eval()
|
||||
print(model.dims)
|
||||
audio = torch.rand(16000 * 2)
|
||||
@@ -306,8 +308,12 @@ def main():
|
||||
"n_text_head": model.dims.n_text_head,
|
||||
"n_text_layer": model.dims.n_text_layer,
|
||||
"sot_sequence": ",".join(list(map(str, tokenizer.sot_sequence))),
|
||||
"all_language_tokens": ",".join(list(map(str, tokenizer.all_language_tokens))),
|
||||
"all_language_codes": ",".join(tokenizer.all_language_codes),
|
||||
"all_language_tokens": ",".join(
|
||||
list(map(str, tokenizer.all_language_tokens))
|
||||
), # a list of ids
|
||||
"all_language_codes": ",".join(
|
||||
tokenizer.all_language_codes
|
||||
), # e.g., en, de, zh, fr
|
||||
"sot": tokenizer.sot,
|
||||
"sot_index": tokenizer.sot_sequence.index(tokenizer.sot),
|
||||
"eot": tokenizer.eot,
|
||||
@@ -413,6 +419,9 @@ def main():
|
||||
},
|
||||
)
|
||||
|
||||
if 'large' in args.model:
|
||||
# Skip the int8 quantization step below for large models, since it causes errors for them.
|
||||
return
|
||||
# Generate int8 quantization models
|
||||
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||
|
||||
|
||||
Reference in New Issue
Block a user