This repository has been archived on 2025-08-26. You can view files and clone it, but cannot push or open issues or pull requests.
Files
enginex-mr_series-sherpa-onnx/scripts/moonshine/test.py

275 lines
7.8 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import datetime as dt
import librosa
import numpy as np
import onnxruntime as ort
import soundfile as sf
def display(sess, name):
    """Dump the input and output node descriptions of an inference session.

    Args:
      sess: An object exposing ``get_inputs()`` and ``get_outputs()``
        (here, an ``onnxruntime.InferenceSession``).
      name: Label shown in the banner printed before each section.
    """
    sections = (("Input", sess.get_inputs), ("Output", sess.get_outputs))
    for section, list_nodes in sections:
        print(f"=========={name} {section}==========")
        for node in list_nodes():
            print(node)
class OnnxModel:
    """The four ONNX Runtime sessions used by this script.

    The sessions are:
      - ``preprocess``: raw audio -> features
      - ``encode``: features -> encoder output
      - ``uncached_decode``: first decoder step; takes no previous states and
        returns logits plus states
      - ``cached_decode``: later decoder steps; consumes the previous states
        and returns logits plus updated states

    Every session is created on the CPU execution provider with one
    inter-op thread and one intra-op thread.
    """

    def __init__(
        self,
        preprocess: str,
        encode: str,
        uncached_decode: str,
        cached_decode: str,
    ):
        """Load all four sessions and print each one's input/output metadata.

        Args:
          preprocess: Path to the preprocessing ONNX model.
          encode: Path to the encoder ONNX model.
          uncached_decode: Path to the first-step decoder ONNX model.
          cached_decode: Path to the incremental decoder ONNX model.
        """
        self.init_preprocess(preprocess)
        display(self.preprocess, "preprocess")

        self.init_encode(encode)
        display(self.encode, "encode")

        self.init_uncached_decode(uncached_decode)
        display(self.uncached_decode, "uncached_decode")

        self.init_cached_decode(cached_decode)
        display(self.cached_decode, "cached_decode")

    @staticmethod
    def _create_session(filename: str):
        """Create a single-threaded, CPU-only ``ort.InferenceSession``.

        All four sub-models use exactly these settings, so they are kept in
        one place instead of being repeated per model.
        """
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1
        return ort.InferenceSession(
            filename,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_preprocess(self, preprocess: str):
        """Create ``self.preprocess`` from the model file ``preprocess``."""
        self.preprocess = self._create_session(preprocess)

    def init_encode(self, encode: str):
        """Create ``self.encode`` from the model file ``encode``."""
        self.encode = self._create_session(encode)

    def init_uncached_decode(self, uncached_decode: str):
        """Create ``self.uncached_decode`` from the model file ``uncached_decode``."""
        self.uncached_decode = self._create_session(uncached_decode)

    def init_cached_decode(self, cached_decode: str):
        """Create ``self.cached_decode`` from the model file ``cached_decode``."""
        self.cached_decode = self._create_session(cached_decode)

    def run_preprocess(self, audio: np.ndarray) -> np.ndarray:
        """Compute features from raw audio.

        Args:
          audio: (batch_size, num_samples), float32
        Returns:
          A tensor of shape (batch_size, T, dim), float32
        """
        sess = self.preprocess
        return sess.run(
            [sess.get_outputs()[0].name],
            {sess.get_inputs()[0].name: audio},
        )[0]

    def run_encode(self, features: np.ndarray) -> np.ndarray:
        """Run the encoder on the output of :meth:`run_preprocess`.

        Args:
          features: (batch_size, T, dim)
        Returns:
          A tensor of shape (batch_size, T, dim)
        """
        sess = self.encode
        inputs = sess.get_inputs()
        # The length tensor holds a single value (T), so this presumably
        # supports batch_size == 1 only -- TODO confirm before batching.
        features_len = np.array([features.shape[1]], dtype=np.int32)
        return sess.run(
            [sess.get_outputs()[0].name],
            {
                inputs[0].name: features,
                inputs[1].name: features_len,
            },
        )[0]

    @staticmethod
    def _run_decoder(sess, token: int, token_len: int, encoder_out, state_feeds):
        """Run one decoder step on ``sess`` and split its outputs.

        The first three session inputs are (token, encoder_out, token_len);
        any further inputs are supplied through ``state_feeds``.

        Returns:
          ``(logits, states)``, where ``logits`` is the first session output
          and ``states`` is the list of all remaining outputs.
        """
        inputs = sess.get_inputs()
        # The token is fed as a (1, 1) tensor, so only batch_size == 1 is
        # supported by this helper.
        token_tensor = np.array([[token]], dtype=np.int32)
        token_len_tensor = np.array([token_len], dtype=np.int32)
        out_names = [node.name for node in sess.get_outputs()]
        out = sess.run(
            out_names,
            {
                inputs[0].name: token_tensor,
                inputs[1].name: encoder_out,
                inputs[2].name: token_len_tensor,
                **state_feeds,
            },
        )
        return out[0], out[1:]

    def run_uncached_decode(self, token: int, token_len: int, encoder_out: np.ndarray):
        """Run the first decoder step, which takes no previous states.

        Args:
          token: The current token (``main`` passes the start-of-sequence token).
          token_len: Number of tokens so far, including ``token``.
          encoder_out: A tensor of shape (batch_size, T, dim).
        Returns:
          A tuple ``(logits, states)``:
            - logits: a tensor of shape (batch_size, 1, vocab_size)
            - states: a list of the remaining outputs, to be passed to
              :meth:`run_cached_decode`
        """
        return self._run_decoder(
            self.uncached_decode, token, token_len, encoder_out, {}
        )

    def run_cached_decode(
        self, token: int, token_len: int, encoder_out: np.ndarray, states
    ):
        """Run a decoder step that reuses the states from the previous step.

        Args:
          token: The current token.
          token_len: Number of tokens so far, including ``token``.
          encoder_out: A tensor of shape (batch_size, T, dim).
          states: Previous states, in the same order as the session inputs
            from index 3 onward.
        Returns:
          A tuple ``(logits, states)``:
            - logits: a tensor of shape (batch_size, 1, vocab_size)
            - states: a list of updated states for the next step
        Raises:
          ValueError: If ``len(states)`` differs from the number of state
            inputs of the ``cached_decode`` session.
        """
        state_inputs = self.cached_decode.get_inputs()[3:]
        # A mismatch would silently misalign (or drop) states, which yields
        # garbage output rather than a clear failure.
        if len(states) != len(state_inputs):
            raise ValueError(
                f"cached_decode expects {len(state_inputs)} states, "
                f"got {len(states)}"
            )
        state_feeds = {node.name: state for node, state in zip(state_inputs, states)}
        return self._run_decoder(
            self.cached_decode, token, token_len, encoder_out, state_feeds
        )
def _load_tokens(filename):
    """Parse a symbol table where each line is a token, a tab, and its integer id.

    Returns:
      A tuple ``(id2token, token2id)``.
    """
    id2token = dict()
    token2id = dict()
    with open(filename, encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\n")
            if not line:
                # Tolerate blank lines, e.g. a trailing empty line.
                continue
            # Split on the last tab so a token that itself contains a tab
            # is still parsed correctly.
            t, idx = line.rsplit("\t", 1)
            id2token[int(idx)] = t
            token2id[t] = int(idx)
    return id2token, token2id


def _tokens_to_text(token_ids, id2token):
    """Join the strings of ``token_ids`` and turn each U+2581 marker into a space."""
    # U+2581 (UTF-8 bytes e2 96 81) presumably marks word boundaries in
    # tokens.txt. The marker must be non-empty: str.replace("", " ") would
    # insert a space between every character of the text.
    underline = "\u2581"
    words = [id2token[i] for i in token_ids]
    return "".join(words).replace(underline, " ").strip()


def main():
    """Decode ./1.wav with the ONNX models in the current directory.

    Prints the session metadata, intermediate tensor shapes, the recognized
    text, and the real-time factor (RTF).
    """
    wave = "./1.wav"
    id2token, token2id = _load_tokens("./tokens.txt")

    model = OnnxModel(
        preprocess="./preprocess.onnx",
        encode="./encode.int8.onnx",
        uncached_decode="./uncached_decode.int8.onnx",
        cached_decode="./cached_decode.int8.onnx",
    )

    audio, sample_rate = sf.read(wave, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel
    if sample_rate != 16000:
        audio = librosa.resample(
            audio,
            orig_sr=sample_rate,
            target_sr=16000,
        )
        sample_rate = 16000

    audio = audio[None]  # (1, num_samples)
    print("audio.shape", audio.shape)  # e.g. (1, 159414)

    start_t = dt.datetime.now()

    features = model.run_preprocess(audio)  # e.g. (1, 413, 288)
    print("features", features.shape)

    sos = token2id["<s>"]
    eos = token2id["</s>"]
    tokens = [sos]

    encoder_out = model.run_encode(features)
    print("encoder_out.shape", encoder_out.shape)  # e.g. (1, 413, 288)

    logits, states = model.run_uncached_decode(
        token=tokens[-1],
        token_len=len(tokens),
        encoder_out=encoder_out,
    )
    print("logits.shape", logits.shape)  # e.g. (1, 1, 32768)
    print("len(states)", len(states))  # e.g. 24

    # Cap on generated tokens: 6 per second of 16 kHz audio.
    max_len = int((audio.shape[-1] / 16000) * 6)
    for _ in range(max_len):
        # Greedy search; convert the numpy scalar to a plain int id.
        token = int(logits.squeeze().argmax())
        if token == eos:
            break
        tokens.append(token)

        logits, states = model.run_cached_decode(
            token=tokens[-1],
            token_len=len(tokens),
            encoder_out=encoder_out,
            states=states,
        )

    text = _tokens_to_text(tokens[1:], id2token)  # tokens[0] is sos

    end_t = dt.datetime.now()
    t = (end_t - start_t).total_seconds()
    rtf = t * 16000 / audio.shape[-1]

    print(text)
    print("RTF:", rtf)
if __name__ == "__main__":
    # Run the decoding demo only when executed as a script, not on import.
    main()