#!/usr/bin/env python3
|
|
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
"""
|
|
Please first run ./export-onnx.py
|
|
before you run this script
|
|
"""
|
|
import base64
|
|
from typing import Tuple
|
|
|
|
import kaldi_native_fbank as knf
|
|
import onnxruntime as ort
|
|
import torch
|
|
|
|
import whisper
|
|
import argparse
|
|
|
|
|
|
def get_args():
    """Parse and return the command-line arguments for this test script."""
    parser = argparse.ArgumentParser()

    # The three model/resource paths share the same shape: required string flags.
    for flag, desc in (
        ("--encoder", "Path to the encoder"),
        ("--decoder", "Path to the decoder"),
        ("--tokens", "Path to the tokens"),
    ):
        parser.add_argument(flag, type=str, required=True, help=desc)

    parser.add_argument(
        "sound_file",
        type=str,
        help="Path to the test wave",
    )
    return parser.parse_args()
|
|
|
|
|
|
class OnnxModel:
    """Wraps the exported whisper encoder/decoder onnxruntime sessions."""

    def __init__(
        self,
        encoder: str,
        decoder: str,
    ):
        opts = ort.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 4
        self.session_opts = opts

        self.init_encoder(encoder)
        self.init_decoder(decoder)

    def init_encoder(self, encoder: str):
        """Create the encoder session and read model metadata into attributes."""
        self.encoder = ort.InferenceSession(
            encoder,
            sess_options=self.session_opts,
        )

        meta = self.encoder.get_modelmeta().custom_metadata_map

        # Decoder text-transformer dimensions.
        self.n_text_layer = int(meta["n_text_layer"])
        self.n_text_ctx = int(meta["n_text_ctx"])
        self.n_text_state = int(meta["n_text_state"])

        # Special token ids exported by export-onnx.py.
        self.sot = int(meta["sot"])
        self.eot = int(meta["eot"])
        self.translate = int(meta["translate"])
        self.no_timestamps = int(meta["no_timestamps"])
        self.no_speech = int(meta["no_speech"])
        self.blank = int(meta["blank_id"])

        # Comma-separated start-of-transcript token sequence.
        self.sot_sequence = [int(t) for t in meta["sot_sequence"].split(",")]

        self.is_multilingual = int(meta["is_multilingual"]) == 1

    def init_decoder(self, decoder: str):
        """Create the decoder session."""
        self.decoder = ort.InferenceSession(
            decoder,
            sess_options=self.session_opts,
        )

    def run_encoder(
        self,
        mel: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder on mel features.

        Returns the cross-attention (k, v) caches as torch tensors.
        """
        enc = self.encoder
        output_names = [o.name for o in enc.get_outputs()[:2]]
        feeds = {enc.get_inputs()[0].name: mel.numpy()}
        cross_k, cross_v = enc.run(output_names, feeds)
        return torch.from_numpy(cross_k), torch.from_numpy(cross_v)

    def run_decoder(
        self,
        tokens: torch.Tensor,
        n_layer_self_k_cache: torch.Tensor,
        n_layer_self_v_cache: torch.Tensor,
        n_layer_cross_k: torch.Tensor,
        n_layer_cross_v: torch.Tensor,
        offset: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one decoder step.

        Returns (logits, updated self-attention k cache, updated v cache).
        """
        dec = self.decoder
        output_names = [o.name for o in dec.get_outputs()[:3]]
        input_names = [i.name for i in dec.get_inputs()[:6]]
        feeds = dict(
            zip(
                input_names,
                (
                    tokens.numpy(),
                    n_layer_self_k_cache.numpy(),
                    n_layer_self_v_cache.numpy(),
                    n_layer_cross_k.numpy(),
                    n_layer_cross_v.numpy(),
                    offset.numpy(),
                ),
            )
        )
        logits, out_self_k, out_self_v = dec.run(output_names, feeds)
        return (
            torch.from_numpy(logits),
            torch.from_numpy(out_self_k),
            torch.from_numpy(out_self_v),
        )

    def get_self_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return zero-initialized self-attention (k, v) caches for batch 1."""
        batch_size = 1
        cache_shape = (
            self.n_text_layer,
            batch_size,
            self.n_text_ctx,
            self.n_text_state,
        )
        return torch.zeros(cache_shape), torch.zeros(cache_shape)

    def suppress_tokens(self, logits, is_initial: bool) -> None:
        """Mask unwanted token ids in-place with -inf before argmax."""
        neg_inf = float("-inf")

        # At the very first step, never emit end-of-text or blank.
        if is_initial:
            logits[self.eot] = neg_inf
            logits[self.blank] = neg_inf

        # Always suppress <|notimestamps|>, <|startoftranscript|>,
        # <|nospeech|>, and the translate-task token.
        for token_id in (
            self.no_timestamps,
            self.sot,
            self.no_speech,
            self.translate,
        ):
            logits[token_id] = neg_inf
|
|
|
|
|
|
def load_tokens(filename):
    """Load a whisper token table.

    Each non-empty line has the form ``<base64-token> <id>``.

    Args:
      filename: Path to the tokens file written by export-onnx.py.

    Returns:
      A dict mapping int token id -> base64-encoded token string.
    """
    tokens = dict()
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank/trailing lines instead of raising
                # ValueError from the 2-tuple unpacking below.
                continue
            t, i = line.split()
            tokens[int(i)] = t
    return tokens
|
|
|
|
|
|
def main():
    """Decode one wave file with the exported whisper ONNX encoder/decoder
    and print the greedy-search transcription."""
    args = get_args()
    encoder = args.encoder
    decoder = args.decoder

    audio = whisper.load_audio(args.sound_file)

    # Compute whisper-style mel features frame by frame with
    # kaldi-native-fbank.
    features = []
    online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
    online_whisper_fbank.accept_waveform(16000, audio)
    online_whisper_fbank.input_finished()
    for i in range(online_whisper_fbank.num_frames_ready):
        f = online_whisper_fbank.get_frame(i)
        f = torch.from_numpy(f)
        features.append(f)

    features = torch.stack(features)

    # Log-mel normalization: clamp, log10, floor at (max - 8),
    # then shift/scale — mirrors whisper's log_mel_spectrogram.
    log_spec = torch.clamp(features, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    mel = (log_spec + 4.0) / 4.0
    # Zero-pad to exactly 3000 frames (the fixed encoder input length),
    # then transpose to (1, n_mels, 3000).
    target = 3000
    mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
    mel = mel.t().unsqueeze(0)

    model = OnnxModel(encoder, decoder)
    n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
    n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()

    # Prime the decoder with the full start-of-transcript token sequence.
    tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
    offset = torch.zeros(1, dtype=torch.int64)
    logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
        tokens=tokens,
        n_layer_self_k_cache=n_layer_self_k_cache,
        n_layer_self_v_cache=n_layer_self_v_cache,
        n_layer_cross_k=n_layer_cross_k,
        n_layer_cross_v=n_layer_cross_v,
        offset=offset,
    )
    # logits.shape (batch_size, tokens.shape[1], vocab_size)
    logits = logits[0, -1]
    model.suppress_tokens(logits, is_initial=True)
    # logits = logits.softmax(dim=-1)
    # for greedy search, we don't need to compute softmax or log_softmax
    max_token_id = logits.argmax(dim=-1)
    results = []
    # Greedy decode one token per step, for at most n_text_ctx steps or
    # until end-of-text is produced.
    for i in range(model.n_text_ctx):
        if max_token_id == model.eot:
            break
        results.append(max_token_id.item())
        # Feed back only the newly chosen token; the self-attention caches
        # carry the history, with `offset` tracking the position.
        tokens = torch.tensor([[results[-1]]])
        offset += 1

        logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
            tokens=tokens,
            n_layer_self_k_cache=n_layer_self_k_cache,
            n_layer_self_v_cache=n_layer_self_v_cache,
            n_layer_cross_k=n_layer_cross_k,
            n_layer_cross_v=n_layer_cross_v,
            offset=offset,
        )
        logits = logits[0, -1]
        model.suppress_tokens(logits, is_initial=False)
        max_token_id = logits.argmax(dim=-1)
    token_table = load_tokens(args.tokens)
    # Token strings are stored base64-encoded in the table; decode and
    # concatenate the raw bytes before a single UTF-8 decode.
    s = b""
    for i in results:
        if i in token_table:
            s += base64.b64decode(token_table[i])

    print(s.decode().strip())
|
|
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|