275 lines
7.8 KiB
Python
Executable File
275 lines
7.8 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
import datetime as dt
|
|
|
|
import librosa
|
|
import numpy as np
|
|
import onnxruntime as ort
|
|
import soundfile as sf
|
|
|
|
|
|
def display(sess, name):
    """Print the input and output node descriptions of an inference session.

    Args:
      sess: Object exposing ``get_inputs()`` and ``get_outputs()``; in this
        script, an ``onnxruntime.InferenceSession``.
      name: Label embedded in the section headers.
    """
    sections = (("Input", "get_inputs"), ("Output", "get_outputs"))
    for label, method_name in sections:
        # Header first, then query the session, matching the original order.
        print(f"=========={name} {label}==========")
        for node in getattr(sess, method_name)():
            print(node)
|
|
|
|
|
|
class OnnxModel:
    """Wrapper around the four ONNX sub-models used by this script.

    Data flow, as used by ``main``::

        audio --preprocess--> features --encode--> encoder_out
        (token, encoder_out)          --uncached_decode--> (logits, states)
        (token, encoder_out, states)  --cached_decode-->   (logits, states)

    Every session runs on ``CPUExecutionProvider`` with one inter-op and one
    intra-op thread.
    """

    def __init__(
        self,
        preprocess: str,
        encode: str,
        uncached_decode: str,
        cached_decode: str,
    ):
        """Load all four models and print their input/output signatures.

        Args:
          preprocess: Path to the preprocessing model (audio -> features).
          encode: Path to the encoder model (features -> encoder_out).
          uncached_decode: Path to the decoder model for the first step,
            which takes no previous states.
          cached_decode: Path to the decoder model for subsequent steps,
            which consumes the states returned by the previous step.
        """
        self.init_preprocess(preprocess)
        display(self.preprocess, "preprocess")

        self.init_encode(encode)
        display(self.encode, "encode")

        self.init_uncached_decode(uncached_decode)
        display(self.uncached_decode, "uncached_decode")

        self.init_cached_decode(cached_decode)
        display(self.cached_decode, "cached_decode")

    @staticmethod
    def _create_session(model: str, num_threads: int = 1):
        """Create a CPU-only ``ort.InferenceSession`` for ``model``.

        Args:
          model: Path to an ``.onnx`` file.
          num_threads: Used for both the inter-op and intra-op thread counts.
        """
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = num_threads
        session_opts.intra_op_num_threads = num_threads
        return ort.InferenceSession(
            model,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    @staticmethod
    def _output_names(sess):
        """Return the names of all outputs of ``sess``, in declaration order."""
        return [node.name for node in sess.get_outputs()]

    def init_preprocess(self, preprocess):
        """Load the preprocessing model into ``self.preprocess``."""
        self.preprocess = self._create_session(preprocess)

    def init_encode(self, encode):
        """Load the encoder model into ``self.encode``."""
        self.encode = self._create_session(encode)

    def init_uncached_decode(self, uncached_decode):
        """Load the first-step decoder model into ``self.uncached_decode``."""
        self.uncached_decode = self._create_session(uncached_decode)

    def init_cached_decode(self, cached_decode):
        """Load the incremental decoder model into ``self.cached_decode``."""
        self.cached_decode = self._create_session(cached_decode)

    def run_preprocess(self, audio: np.ndarray) -> np.ndarray:
        """Run the preprocessing model.

        Args:
          audio: (batch_size, num_samples), float32
        Returns:
          A tensor of shape (batch_size, T, dim), float32 (first model output).
        """
        sess = self.preprocess
        return sess.run(
            [sess.get_outputs()[0].name],
            {sess.get_inputs()[0].name: audio},
        )[0]

    def run_encode(self, features: np.ndarray) -> np.ndarray:
        """Run the encoder model.

        Args:
          features: (batch_size, T, dim)
        Returns:
          A tensor of shape (batch_size, T, dim) (first model output).
        """
        sess = self.encode
        inputs = sess.get_inputs()
        # NOTE(review): a single length value (T) is passed, regardless of
        # batch size; the model presumably expects a 1-element length
        # tensor. Confirm before using batch_size > 1.
        features_len = np.array([features.shape[1]], dtype=np.int32)
        return sess.run(
            [sess.get_outputs()[0].name],
            {
                inputs[0].name: features,
                inputs[1].name: features_len,
            },
        )[0]

    def run_uncached_decode(self, token: int, token_len: int, encoder_out: np.ndarray):
        """Run the first decoder step, which starts without cached states.

        Args:
          token: The current token id.
          token_len: Number of predicted tokens so far.
          encoder_out: A tensor of shape (batch_size, T, dim).
        Returns:
          A tuple:
            - logits, a tensor of shape (batch_size, 1, vocab_size)
            - a list of states (all remaining model outputs), to be passed to
              ``run_cached_decode``
        """
        token_tensor = np.array([[token]], dtype=np.int32)
        token_len_tensor = np.array([token_len], dtype=np.int32)

        sess = self.uncached_decode
        inputs = sess.get_inputs()
        out = sess.run(
            self._output_names(sess),
            {
                inputs[0].name: token_tensor,
                inputs[1].name: encoder_out,
                inputs[2].name: token_len_tensor,
            },
        )

        # The first output is the logits; every remaining output is a state.
        logits, *states = out
        return logits, states

    def run_cached_decode(
        self, token: int, token_len: int, encoder_out: np.ndarray, states
    ):
        """Run one incremental decoder step using states from the previous step.

        Args:
          token: The current token id.
          token_len: Number of predicted tokens so far.
          encoder_out: A tensor of shape (batch_size, T, dim).
          states: The states returned by the previous decode step. There must
            be exactly one per model input after the first three.
        Returns:
          A tuple:
            - logits, a tensor of shape (batch_size, 1, vocab_size)
            - a list of updated states
        Raises:
          ValueError: if ``len(states)`` does not match the number of state
            inputs the model declares.
        """
        token_tensor = np.array([[token]], dtype=np.int32)
        token_len_tensor = np.array([token_len], dtype=np.int32)

        sess = self.cached_decode
        inputs = sess.get_inputs()
        state_inputs = inputs[3:]
        if len(states) != len(state_inputs):
            raise ValueError(
                f"cached_decode expects {len(state_inputs)} states, "
                f"got {len(states)}"
            )

        feeds = {
            inputs[0].name: token_tensor,
            inputs[1].name: encoder_out,
            inputs[2].name: token_len_tensor,
        }
        # States are matched to inputs by position: the i-th state feeds the
        # (3 + i)-th model input.
        feeds.update(
            (node.name, state) for node, state in zip(state_inputs, states)
        )

        out = sess.run(self._output_names(sess), feeds)

        # The first output is the logits; every remaining output is a state.
        logits, *new_states = out
        return logits, new_states
|
|
|
|
|
|
def _read_tokens(filename: str):
    """Parse a ``token<TAB>id`` file into ``(id2token, token2id)`` dicts."""
    id2token = {}
    token2id = {}
    with open(filename, encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\r\n")
            if not line:
                # Tolerate blank lines, e.g. an extra newline at end of file.
                continue
            # Split on the last tab so the id is always the final field.
            t, idx = line.rsplit("\t", 1)
            id2token[int(idx)] = t
            token2id[t] = int(idx)
    return id2token, token2id


def _load_audio(filename: str, sample_rate: int = 16000) -> np.ndarray:
    """Read the first channel of ``filename`` as a (1, num_samples) float32 array at ``sample_rate`` Hz."""
    audio, file_sample_rate = sf.read(filename, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel
    if file_sample_rate != sample_rate:
        audio = librosa.resample(
            audio,
            orig_sr=file_sample_rate,
            target_sr=sample_rate,
        )
    # Contiguous float32 batch of one, matching run_preprocess's
    # (batch_size, num_samples) float32 input.
    return np.ascontiguousarray(audio[None], dtype=np.float32)


def _greedy_search(model, encoder_out, sos: int, eos: int, max_len: int):
    """Greedily decode token ids from ``encoder_out``; the result excludes ``sos`` and ``eos``."""
    tokens = [sos]

    logits, states = model.run_uncached_decode(
        token=tokens[-1],
        token_len=len(tokens),
        encoder_out=encoder_out,
    )

    print("logits.shape", logits.shape)  # (1, 1, 32768)
    print("len(states)", len(states))  # 24

    for _ in range(max_len):
        # Plain int, so the token list holds Python ints rather than np.int64.
        token = int(logits.squeeze().argmax())
        if token == eos:
            break
        tokens.append(token)

        logits, states = model.run_cached_decode(
            token=tokens[-1],
            token_len=len(tokens),
            encoder_out=encoder_out,
            states=states,
        )

    return tokens[1:]  # remove sos


def main():
    """Transcribe ./1.wav with the ONNX models in the current directory.

    Prints intermediate tensor shapes, the decoded text and the real-time
    factor (processing time / audio duration; preprocessing is included,
    file loading is not).
    """
    sample_rate = 16000
    # Upper bound on decoded tokens per second of audio.
    max_tokens_per_second = 6

    wave = "./1.wav"
    id2token, token2id = _read_tokens("./tokens.txt")

    model = OnnxModel(
        preprocess="./preprocess.onnx",
        encode="./encode.int8.onnx",
        uncached_decode="./uncached_decode.int8.onnx",
        cached_decode="./cached_decode.int8.onnx",
    )

    audio = _load_audio(wave, sample_rate)  # (1, num_samples)
    print("audio.shape", audio.shape)  # (1, 159414)
    num_samples = audio.shape[-1]

    start_t = dt.datetime.now()

    features = model.run_preprocess(audio)  # (1, 413, 288)
    print("features", features.shape)

    sos = token2id["<s>"]
    eos = token2id["</s>"]

    encoder_out = model.run_encode(features)
    print("encoder_out.shape", encoder_out.shape)  # (1, 413, 288)

    max_len = int((num_samples / sample_rate) * max_tokens_per_second)
    tokens = _greedy_search(model, encoder_out, sos, eos, max_len)

    words = [id2token[i] for i in tokens]
    # U+2581 ("▁") marks word boundaries in the token strings.
    underline = "\u2581"
    text = "".join(words).replace(underline, " ").strip()

    end_t = dt.datetime.now()
    t = (end_t - start_t).total_seconds()
    rtf = t * sample_rate / num_samples

    print(text)
    print("RTF:", rtf)


if __name__ == "__main__":
    main()
|