#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
Please first run ./export-onnx.py
before you run this script
"""
import argparse
import base64
from typing import Tuple
import kaldi_native_fbank as knf
import onnxruntime as ort
import torch
import torchaudio


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--encoder",
        type=str,
        required=True,
        help="Path to the encoder",
    )

    parser.add_argument(
        "--decoder",
        type=str,
        required=True,
        help="Path to the decoder",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        required=True,
        help="Path to the tokens",
    )

    parser.add_argument(
        "--language",
        type=str,
        help="""The actual spoken language in the audio.
        Example values: en, de, zh, ja, fr.
        If None, we will detect the language using the first 30s of the
        input audio.
        """,
    )

    parser.add_argument(
        "--task",
        choices=["transcribe", "translate"],
        type=str,
        default="transcribe",
        help="Valid values are: transcribe, translate",
    )

    parser.add_argument(
        "sound_file",
        type=str,
        help="Path to the test wave",
    )

    return parser.parse_args()


class OnnxModel:
    def __init__(
        self,
        encoder: str,
        decoder: str,
    ):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 4
        self.session_opts = session_opts

        self.init_encoder(encoder)
        self.init_decoder(decoder)

    def init_encoder(self, encoder: str):
        self.encoder = ort.InferenceSession(
            encoder,
            sess_options=self.session_opts,
        )

        meta = self.encoder.get_modelmeta().custom_metadata_map
        self.n_text_layer = int(meta["n_text_layer"])
        self.n_text_ctx = int(meta["n_text_ctx"])
        self.n_text_state = int(meta["n_text_state"])
        self.sot = int(meta["sot"])
        self.eot = int(meta["eot"])
        self.translate = int(meta["translate"])
        self.transcribe = int(meta["transcribe"])
        self.no_timestamps = int(meta["no_timestamps"])
        self.no_speech = int(meta["no_speech"])
        self.blank = int(meta["blank_id"])
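
        # sot_sequence from the metadata is [sot, language, task]; appending
        # no_timestamps yields the 4-token decoding prefix
        # [sot, lang, task, notimestamps] that main() patches below.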
        self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
        self.sot_sequence.append(self.no_timestamps)

        self.all_language_tokens = list(
            map(int, meta["all_language_tokens"].split(","))
        )
        self.all_language_codes = meta["all_language_codes"].split(",")
        self.lang2id = dict(zip(self.all_language_codes, self.all_language_tokens))
        self.id2lang = dict(zip(self.all_language_tokens, self.all_language_codes))

        self.is_multilingual = int(meta["is_multilingual"]) == 1

    def init_decoder(self, decoder: str):
        self.decoder = ort.InferenceSession(
            decoder,
            sess_options=self.session_opts,
        )
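
    # The encoder consumes the 80-bin mel features and emits per-layer
    # cross-attention keys/values (n_layer_cross_k, n_layer_cross_v),
    # which the decoder reuses unchanged at every decoding step.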
    def run_encoder(
        self,
        mel: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        n_layer_cross_k, n_layer_cross_v = self.encoder.run(
            [
                self.encoder.get_outputs()[0].name,
                self.encoder.get_outputs()[1].name,
            ],
            {
                self.encoder.get_inputs()[0].name: mel.numpy(),
            },
        )
        return torch.from_numpy(n_layer_cross_k), torch.from_numpy(n_layer_cross_v)
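
    # run_decoder performs one incremental step: it feeds the new token(s)
    # plus the current self-attention KV caches and returns updated caches,
    # so previous tokens are not re-processed. `offset` tells the model how
    # many tokens the caches already hold.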
    def run_decoder(
        self,
        tokens: torch.Tensor,
        n_layer_self_k_cache: torch.Tensor,
        n_layer_self_v_cache: torch.Tensor,
        n_layer_cross_k: torch.Tensor,
        n_layer_cross_v: torch.Tensor,
        offset: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder.run(
            [
                self.decoder.get_outputs()[0].name,
                self.decoder.get_outputs()[1].name,
                self.decoder.get_outputs()[2].name,
            ],
            {
                self.decoder.get_inputs()[0].name: tokens.numpy(),
                self.decoder.get_inputs()[1].name: n_layer_self_k_cache.numpy(),
                self.decoder.get_inputs()[2].name: n_layer_self_v_cache.numpy(),
                self.decoder.get_inputs()[3].name: n_layer_cross_k.numpy(),
                self.decoder.get_inputs()[4].name: n_layer_cross_v.numpy(),
                self.decoder.get_inputs()[5].name: offset.numpy(),
            },
        )
        return (
            torch.from_numpy(logits),
            torch.from_numpy(out_n_layer_self_k_cache),
            torch.from_numpy(out_n_layer_self_v_cache),
        )
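
    # The self-attention caches are pre-allocated with zeros for the full
    # text context (n_text_ctx); run_decoder returns updated copies each step.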
    def get_self_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = 1
        n_layer_self_k_cache = torch.zeros(
            self.n_text_layer,
            batch_size,
            self.n_text_ctx,
            self.n_text_state,
        )
        n_layer_self_v_cache = torch.zeros(
            self.n_text_layer,
            batch_size,
            self.n_text_ctx,
            self.n_text_state,
        )
        return n_layer_self_k_cache, n_layer_self_v_cache
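
    # Greedy search must never emit special tokens, so their logits are
    # forced to -inf before argmax. eot and blank are suppressed only on
    # the very first step; afterwards eot must stay reachable so that
    # decoding can terminate.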
    def suppress_tokens(self, logits, is_initial: bool) -> None:
        # Note: logits is changed in-place
        if is_initial:
            # suppress blank and eot
            logits[self.eot] = float("-inf")
            logits[self.blank] = float("-inf")

        # suppress <|notimestamps|>
        logits[self.no_timestamps] = float("-inf")

        logits[self.sot] = float("-inf")
        logits[self.no_speech] = float("-inf")
        logits[self.translate] = float("-inf")

    def detect_language(
        self, n_layer_cross_k: torch.Tensor, n_layer_cross_v: torch.Tensor
    ) -> int:
        tokens = torch.tensor([[self.sot]], dtype=torch.int64)
        offset = torch.zeros(1, dtype=torch.int64)
        n_layer_self_k_cache, n_layer_self_v_cache = self.get_self_cache()
        logits, n_layer_self_k_cache, n_layer_self_v_cache = self.run_decoder(
            tokens=tokens,
            n_layer_self_k_cache=n_layer_self_k_cache,
            n_layer_self_v_cache=n_layer_self_v_cache,
            n_layer_cross_k=n_layer_cross_k,
            n_layer_cross_v=n_layer_cross_v,
            offset=offset,
        )
        logits = logits.reshape(-1)
        # keep only language tokens; everything else gets -inf
        mask = torch.ones(logits.shape[0], dtype=torch.bool)
        mask[self.all_language_tokens] = False
        logits[mask] = float("-inf")
        lang_id = logits.argmax().item()
        print("detected language: ", self.id2lang[lang_id])
        return lang_id
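

# The tokens file is expected to contain one "<base64-encoded-token> <id>"
# pair per line; main() base64-decodes each token back to raw bytes before
# printing the final text.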
def load_tokens(filename):
    tokens = dict()
    with open(filename, "r") as f:
        for line in f:
            t, i = line.split()
            tokens[int(i)] = t
    return tokens


def compute_features(filename: str) -> torch.Tensor:
    """
    Args:
      filename:
        Path to an audio file.
    Returns:
      Return a 3-D float32 tensor of shape (1, 80, T) containing the
      features, where T is at most 3000.
    """
    wave, sample_rate = torchaudio.load(filename)
    audio = wave[0].contiguous()  # only use the first channel
    if sample_rate != 16000:
        audio = torchaudio.functional.resample(
            audio, orig_freq=sample_rate, new_freq=16000
        )

    features = []
    online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
    online_whisper_fbank.accept_waveform(16000, audio.numpy())
    online_whisper_fbank.input_finished()
    for i in range(online_whisper_fbank.num_frames_ready):
        f = online_whisper_fbank.get_frame(i)
        f = torch.from_numpy(f)
        features.append(f)

    features = torch.stack(features)
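
    # Convert the stacked fbank frames to Whisper-style log-mel features:
    # clamp, take log10, cap the dynamic range at 8, then rescale.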
    log_spec = torch.clamp(features, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    mel = (log_spec + 4.0) / 4.0
    # mel: (T, 80)

    # We pad 50 frames at the end so that the model is able to detect eot.
    # You can use another value instead of 50.
    mel = torch.nn.functional.pad(mel, (0, 0, 0, 50), "constant", 0)
    # Note: if decoding throws an error for a multilingual model,
    # use a larger value, say 300.

    target = 3000
    if mel.shape[0] > target:
        mel = mel[:target]

    # We don't need to pad it to 30 seconds now!
    # mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)

    mel = mel.t().unsqueeze(0)
    return mel


def main():
    args = get_args()

    mel = compute_features(args.sound_file)

    model = OnnxModel(args.encoder, args.decoder)
    n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)

    if args.language is not None:
        if model.is_multilingual is False and args.language != "en":
            print(f"This model supports only English. Given: {args.language}")
            return

        if args.language not in model.lang2id:
            print(f"Invalid language: {args.language}")
            print(f"Valid values are: {list(model.lang2id.keys())}")
            return

        # [sot, lang, task, notimestamps]
        model.sot_sequence[1] = model.lang2id[args.language]
    elif model.is_multilingual is True:
        print("detecting language")
        lang = model.detect_language(n_layer_cross_k, n_layer_cross_v)
        model.sot_sequence[1] = lang

    if args.task is not None:
        if model.is_multilingual is False and args.task != "transcribe":
            print("This model supports only English. Please use --task=transcribe")
            return
        assert args.task in ["transcribe", "translate"], args.task
        if args.task == "translate":
            model.sot_sequence[2] = model.translate

    n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()

    tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
    offset = torch.zeros(1, dtype=torch.int64)
    logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
        tokens=tokens,
        n_layer_self_k_cache=n_layer_self_k_cache,
        n_layer_self_v_cache=n_layer_self_v_cache,
        n_layer_cross_k=n_layer_cross_k,
        n_layer_cross_v=n_layer_cross_v,
        offset=offset,
    )
    offset += len(model.sot_sequence)

    # logits.shape (batch_size, tokens.shape[1], vocab_size)
    logits = logits[0, -1]
    model.suppress_tokens(logits, is_initial=True)
    # logits = logits.softmax(dim=-1)
    # for greedy search, we don't need to compute softmax or log_softmax
    max_token_id = logits.argmax(dim=-1)
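
    # Greedy search loop: feed back only the most recent token each step,
    # advance the cache offset by 1, and stop once eot is predicted
    # (or after at most n_text_ctx steps).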
    results = []
    for i in range(model.n_text_ctx):
        if max_token_id == model.eot:
            break
        results.append(max_token_id.item())
        tokens = torch.tensor([[results[-1]]])

        logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
            tokens=tokens,
            n_layer_self_k_cache=n_layer_self_k_cache,
            n_layer_self_v_cache=n_layer_self_v_cache,
            n_layer_cross_k=n_layer_cross_k,
            n_layer_cross_v=n_layer_cross_v,
            offset=offset,
        )
        offset += 1
        logits = logits[0, -1]
        model.suppress_tokens(logits, is_initial=False)
        max_token_id = logits.argmax(dim=-1)

    token_table = load_tokens(args.tokens)
    s = b""
    for i in results:
        if i in token_table:
            s += base64.b64decode(token_table[i])

    print(s.decode().strip())


if __name__ == "__main__":
    main()