Support multilingual whisper models (#274)
This commit is contained in:
@@ -11,6 +11,7 @@ for making the onnx export script public.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -250,6 +251,7 @@ def main():
|
||||
# write tokens
|
||||
|
||||
tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
|
||||
|
||||
model.eval()
|
||||
print(model.dims)
|
||||
audio = torch.rand(16000 * 2)
|
||||
@@ -306,8 +308,12 @@ def main():
|
||||
"n_text_head": model.dims.n_text_head,
|
||||
"n_text_layer": model.dims.n_text_layer,
|
||||
"sot_sequence": ",".join(list(map(str, tokenizer.sot_sequence))),
|
||||
"all_language_tokens": ",".join(list(map(str, tokenizer.all_language_tokens))),
|
||||
"all_language_codes": ",".join(tokenizer.all_language_codes),
|
||||
"all_language_tokens": ",".join(
|
||||
list(map(str, tokenizer.all_language_tokens))
|
||||
), # a list of ids
|
||||
"all_language_codes": ",".join(
|
||||
tokenizer.all_language_codes
|
||||
), # e.g., en, de, zh, fr
|
||||
"sot": tokenizer.sot,
|
||||
"sot_index": tokenizer.sot_sequence.index(tokenizer.sot),
|
||||
"eot": tokenizer.eot,
|
||||
@@ -413,6 +419,9 @@ def main():
|
||||
},
|
||||
)
|
||||
|
||||
if 'large' in args.model:
|
||||
# The int8 quantization below causes errors for large models, so skip it.
|
||||
return
|
||||
# Generate int8 quantization models
|
||||
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||
|
||||
|
||||
@@ -38,6 +38,24 @@ def get_args():
|
||||
help="Path to the tokens",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--language",
|
||||
type=str,
|
||||
help="""The actual spoken language in the audio.
|
||||
Example values, en, de, zh, jp, fr.
|
||||
If None, we will detect the language using the first 30s of the
|
||||
input audio
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
choices=["transcribe", "translate"],
|
||||
type=str,
|
||||
default="transcribe",
|
||||
help="Valid values are: transcribe, translate",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_file",
|
||||
type=str,
|
||||
@@ -74,12 +92,22 @@ class OnnxModel:
|
||||
self.sot = int(meta["sot"])
|
||||
self.eot = int(meta["eot"])
|
||||
self.translate = int(meta["translate"])
|
||||
self.transcribe = int(meta["transcribe"])
|
||||
self.no_timestamps = int(meta["no_timestamps"])
|
||||
self.no_speech = int(meta["no_speech"])
|
||||
self.blank = int(meta["blank_id"])
|
||||
|
||||
self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
|
||||
|
||||
self.sot_sequence.append(self.no_timestamps)
|
||||
|
||||
self.all_language_tokens = list(
|
||||
map(int, meta["all_language_tokens"].split(","))
|
||||
)
|
||||
self.all_language_codes = meta["all_language_codes"].split(",")
|
||||
self.lang2id = dict(zip(self.all_language_codes, self.all_language_tokens))
|
||||
self.id2lang = dict(zip(self.all_language_tokens, self.all_language_codes))
|
||||
|
||||
self.is_multilingual = int(meta["is_multilingual"]) == 1
|
||||
|
||||
def init_decoder(self, decoder: str):
|
||||
@@ -164,6 +192,29 @@ class OnnxModel:
|
||||
# logits is changed in-place
|
||||
logits[self.translate] = float("-inf")
|
||||
|
||||
def detect_language(
    self, n_layer_cross_k: torch.Tensor, n_layer_cross_v: torch.Tensor
) -> int:
    """Detect the spoken language with a single decoder step.

    The decoder is run on the one-token sequence ``[[sot]]`` at offset 0,
    using the self-attention cache returned by ``self.get_self_cache()``.
    Every logit that does not belong to a language token is set to
    ``-inf`` and the argmax over the remaining logits is returned.

    Args:
      n_layer_cross_k:
        First value returned by ``run_encoder``; passed unchanged to
        ``run_decoder`` (presumably the cross-attention keys).
      n_layer_cross_v:
        Second value returned by ``run_encoder``; passed unchanged to
        ``run_decoder`` (presumably the cross-attention values).

    Returns:
      The token ID of the detected language. It is always one of
      ``self.all_language_tokens``, so it is a valid key of
      ``self.id2lang``.
    """
    tokens = torch.tensor([[self.sot]], dtype=torch.int64)
    offset = torch.zeros(1, dtype=torch.int64)
    n_layer_self_k_cache, n_layer_self_v_cache = self.get_self_cache()

    # Only the logits are needed; the updated caches are discarded.
    logits, _, _ = self.run_decoder(
        tokens=tokens,
        n_layer_self_k_cache=n_layer_self_k_cache,
        n_layer_self_v_cache=n_layer_self_v_cache,
        n_layer_cross_k=n_layer_cross_k,
        n_layer_cross_v=n_layer_cross_v,
        offset=offset,
    )
    logits = logits.reshape(-1)

    # The mask must be boolean. Indexing with an int64 tensor of 0/1 values
    # is integer indexing, which would overwrite only logits[0] and
    # logits[1] instead of suppressing every non-language token.
    mask = torch.ones(logits.shape[0], dtype=torch.bool)
    mask[self.all_language_tokens] = False
    logits[mask] = float("-inf")

    lang_id = logits.argmax().item()
    print("detected language: ", self.id2lang[lang_id])
    return lang_id
|
||||
|
||||
|
||||
def load_tokens(filename):
|
||||
tokens = dict()
|
||||
@@ -200,7 +251,35 @@ def main():
|
||||
mel = mel.t().unsqueeze(0)
|
||||
|
||||
model = OnnxModel(encoder, decoder)
|
||||
|
||||
n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
|
||||
|
||||
if args.language is not None:
|
||||
if model.is_multilingual is False and args.language != "en":
|
||||
print(f"This model supports only English. Given: {args.language}")
|
||||
return
|
||||
|
||||
if args.language not in model.lang2id:
|
||||
print(f"Invalid language: {args.language}")
|
||||
print(f"Valid values are: {list(model.lang2id.keys())}")
|
||||
return
|
||||
|
||||
# [sot, lang, task, notimestamps]
|
||||
model.sot_sequence[1] = model.lang2id[args.language]
|
||||
elif model.is_multilingual is True:
|
||||
print("detecting language")
|
||||
lang = model.detect_language(n_layer_cross_k, n_layer_cross_v)
|
||||
model.sot_sequence[1] = lang
|
||||
|
||||
if args.task is not None:
|
||||
if model.is_multilingual is False and args.task != "transcribe":
|
||||
print("This model supports only English. Please use --task=transcribe")
|
||||
return
|
||||
assert args.task in ["transcribe", "translate"], args.task
|
||||
|
||||
if args.task == "translate":
|
||||
model.sot_sequence[2] = model.translate
|
||||
|
||||
n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()
|
||||
|
||||
tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
|
||||
@@ -213,6 +292,7 @@ def main():
|
||||
n_layer_cross_v=n_layer_cross_v,
|
||||
offset=offset,
|
||||
)
|
||||
offset += len(model.sot_sequence)
|
||||
# logits.shape (batch_size, tokens.shape[1], vocab_size)
|
||||
logits = logits[0, -1]
|
||||
model.suppress_tokens(logits, is_initial=True)
|
||||
@@ -225,7 +305,6 @@ def main():
|
||||
break
|
||||
results.append(max_token_id.item())
|
||||
tokens = torch.tensor([[results[-1]]])
|
||||
offset += 1
|
||||
|
||||
logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
|
||||
tokens=tokens,
|
||||
@@ -235,6 +314,7 @@ def main():
|
||||
n_layer_cross_v=n_layer_cross_v,
|
||||
offset=offset,
|
||||
)
|
||||
offset += 1
|
||||
logits = logits[0, -1]
|
||||
model.suppress_tokens(logits, is_initial=False)
|
||||
max_token_id = logits.argmax(dim=-1)
|
||||
|
||||
Reference in New Issue
Block a user