Support multilingual whisper models (#274)

This commit is contained in:
Fangjun Kuang
2023-08-16 00:28:52 +08:00
committed by GitHub
parent 496c5dd7f5
commit f709c95c5f
24 changed files with 692 additions and 73 deletions

View File

@@ -11,6 +11,7 @@ for making the onnx export script public.
"""
import argparse
import os
from pathlib import Path
from typing import Any, Dict, Optional
@@ -250,6 +251,7 @@ def main():
# write tokens
tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
model.eval()
print(model.dims)
audio = torch.rand(16000 * 2)
@@ -306,8 +308,12 @@ def main():
"n_text_head": model.dims.n_text_head,
"n_text_layer": model.dims.n_text_layer,
"sot_sequence": ",".join(list(map(str, tokenizer.sot_sequence))),
"all_language_tokens": ",".join(list(map(str, tokenizer.all_language_tokens))),
"all_language_codes": ",".join(tokenizer.all_language_codes),
"all_language_tokens": ",".join(
list(map(str, tokenizer.all_language_tokens))
), # a list of ids
"all_language_codes": ",".join(
tokenizer.all_language_codes
), # e.g., en, de, zh, fr
"sot": tokenizer.sot,
"sot_index": tokenizer.sot_sequence.index(tokenizer.sot),
"eot": tokenizer.eot,
@@ -413,6 +419,9 @@ def main():
},
)
if 'large' in args.model:
# it causes errors for large models, so skip it.
return
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

View File

@@ -38,6 +38,24 @@ def get_args():
help="Path to the tokens",
)
parser.add_argument(
"--language",
type=str,
help="""The actual spoken language in the audio.
Example values, en, de, zh, jp, fr.
If None, we will detect the language using the first 30s of the
input audio
""",
)
parser.add_argument(
"--task",
choices=["transcribe", "translate"],
type=str,
default="transcribe",
help="Valid values are: transcribe, translate",
)
parser.add_argument(
"sound_file",
type=str,
@@ -74,12 +92,22 @@ class OnnxModel:
self.sot = int(meta["sot"])
self.eot = int(meta["eot"])
self.translate = int(meta["translate"])
self.transcribe = int(meta["transcribe"])
self.no_timestamps = int(meta["no_timestamps"])
self.no_speech = int(meta["no_speech"])
self.blank = int(meta["blank_id"])
self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
self.sot_sequence.append(self.no_timestamps)
self.all_language_tokens = list(
map(int, meta["all_language_tokens"].split(","))
)
self.all_language_codes = meta["all_language_codes"].split(",")
self.lang2id = dict(zip(self.all_language_codes, self.all_language_tokens))
self.id2lang = dict(zip(self.all_language_tokens, self.all_language_codes))
self.is_multilingual = int(meta["is_multilingual"]) == 1
def init_decoder(self, decoder: str):
@@ -164,6 +192,29 @@ class OnnxModel:
# logits is changed in-place
logits[self.translate] = float("-inf")
def detect_language(
    self, n_layer_cross_k: torch.Tensor, n_layer_cross_v: torch.Tensor
) -> int:
    """Detect the spoken language from the encoder output.

    Runs the decoder for a single step with only the start-of-transcript
    token and picks the language token with the highest logit.

    Args:
      n_layer_cross_k:
        Cross-attention key cache produced by the encoder.
      n_layer_cross_v:
        Cross-attention value cache produced by the encoder.
    Returns:
      The token id of the detected language (a key of ``self.id2lang``).
    """
    tokens = torch.tensor([[self.sot]], dtype=torch.int64)
    offset = torch.zeros(1, dtype=torch.int64)
    n_layer_self_k_cache, n_layer_self_v_cache = self.get_self_cache()

    logits, n_layer_self_k_cache, n_layer_self_v_cache = self.run_decoder(
        tokens=tokens,
        n_layer_self_k_cache=n_layer_self_k_cache,
        n_layer_self_v_cache=n_layer_self_v_cache,
        n_layer_cross_k=n_layer_cross_k,
        n_layer_cross_v=n_layer_cross_v,
        offset=offset,
    )
    logits = logits.reshape(-1)
    # Suppress every token that is not a language token so that argmax
    # can only select a language.
    #
    # NOTE: the mask must be boolean. An int64 mask of 0s/1s would be
    # interpreted as *indices* by logits[mask], clobbering logits[0] and
    # logits[1] instead of masking the non-language tokens.
    mask = torch.ones(logits.shape[0], dtype=torch.bool)
    mask[self.all_language_tokens] = False
    logits[mask] = float("-inf")
    lang_id = logits.argmax().item()
    print("detected language: ", self.id2lang[lang_id])
    return lang_id
def load_tokens(filename):
tokens = dict()
@@ -200,7 +251,35 @@ def main():
mel = mel.t().unsqueeze(0)
model = OnnxModel(encoder, decoder)
n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
if args.language is not None:
if model.is_multilingual is False and args.language != "en":
print(f"This model supports only English. Given: {args.language}")
return
if args.language not in model.lang2id:
print(f"Invalid language: {args.language}")
print(f"Valid values are: {list(model.lang2id.keys())}")
return
# [sot, lang, task, notimestamps]
model.sot_sequence[1] = model.lang2id[args.language]
elif model.is_multilingual is True:
print("detecting language")
lang = model.detect_language(n_layer_cross_k, n_layer_cross_v)
model.sot_sequence[1] = lang
if args.task is not None:
if model.is_multilingual is False and args.task != "transcribe":
print("This model supports only English. Please use --task=transcribe")
return
assert args.task in ["transcribe", "translate"], args.task
if args.task == "translate":
model.sot_sequence[2] = model.translate
n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()
tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
@@ -213,6 +292,7 @@ def main():
n_layer_cross_v=n_layer_cross_v,
offset=offset,
)
offset += len(model.sot_sequence)
# logits.shape (batch_size, tokens.shape[1], vocab_size)
logits = logits[0, -1]
model.suppress_tokens(logits, is_initial=True)
@@ -225,7 +305,6 @@ def main():
break
results.append(max_token_id.item())
tokens = torch.tensor([[results[-1]]])
offset += 1
logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
tokens=tokens,
@@ -235,6 +314,7 @@ def main():
n_layer_cross_v=n_layer_cross_v,
offset=offset,
)
offset += 1
logits = logits[0, -1]
model.suppress_tokens(logits, is_initial=False)
max_token_id = logits.argmax(dim=-1)