Begin to support https://github.com/usefulsensors/moonshine (#1470)
This commit is contained in:
274
scripts/moonshine/test.py
Executable file
@@ -0,0 +1,274 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

import datetime as dt

import librosa
import numpy as np
import onnxruntime as ort
import soundfile as sf


def display(sess, name):
    print(f"=========={name} Input==========")
    for i in sess.get_inputs():
        print(i)
    print(f"=========={name} Output==========")
    for i in sess.get_outputs():
        print(i)


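# The exported moonshine model consists of four ONNX graphs:
#   preprocess       raw 16 kHz audio  -> features
#   encode           features          -> encoder output
#   uncached_decode  first decoder step (no cached states yet)
#   cached_decode    later decoder steps that reuse the states returned
#                    by the previous step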
class OnnxModel:
    def __init__(
        self,
        preprocess: str,
        encode: str,
        uncached_decode: str,
        cached_decode: str,
    ):
        self.init_preprocess(preprocess)
        display(self.preprocess, "preprocess")

        self.init_encode(encode)
        display(self.encode, "encode")

        self.init_uncached_decode(uncached_decode)
        display(self.uncached_decode, "uncached_decode")

        self.init_cached_decode(cached_decode)
        display(self.cached_decode, "cached_decode")

    def init_preprocess(self, preprocess):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.preprocess = ort.InferenceSession(
            preprocess,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_encode(self, encode):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.encode = ort.InferenceSession(
            encode,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_uncached_decode(self, uncached_decode):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.uncached_decode = ort.InferenceSession(
            uncached_decode,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_cached_decode(self, cached_decode):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.cached_decode = ort.InferenceSession(
            cached_decode,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def run_preprocess(self, audio):
        """
        Args:
          audio: (batch_size, num_samples), float32
        Returns:
          A tensor of shape (batch_size, T, dim), float32
        """
        return self.preprocess.run(
            [
                self.preprocess.get_outputs()[0].name,
            ],
            {
                self.preprocess.get_inputs()[0].name: audio,
            },
        )[0]

    def run_encode(self, features):
        """
        Args:
          features: (batch_size, T, dim)
        Returns:
          A tensor of shape (batch_size, T, dim)
        """
        features_len = np.array([features.shape[1]], dtype=np.int32)

        return self.encode.run(
            [
                self.encode.get_outputs()[0].name,
            ],
            {
                self.encode.get_inputs()[0].name: features,
                self.encode.get_inputs()[1].name: features_len,
            },
        )[0]

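    # The decoder is run in two flavors: run_uncached_decode() for the first
    # step, which returns a list of states along with the logits, and
    # run_cached_decode() for every later step, which feeds those states back
    # as extra inputs (presumably the decoder's attention caches) so that only
    # the newest token has to be processed.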
    def run_uncached_decode(self, token: int, token_len: int, encoder_out: np.ndarray):
        """
        Args:
          token: The current token
          token_len: Number of predicted tokens so far
          encoder_out: A tensor of shape (batch_size, T, dim)
        Returns:
          A tuple:
            - a tensor of shape (batch_size, 1, dim)
            - a list of states
        """
        token_tensor = np.array([[token]], dtype=np.int32)
        token_len_tensor = np.array([token_len], dtype=np.int32)

        num_outs = len(self.uncached_decode.get_outputs())
        out_names = [
            self.uncached_decode.get_outputs()[i].name for i in range(num_outs)
        ]

        out = self.uncached_decode.run(
            out_names,
            {
                self.uncached_decode.get_inputs()[0].name: token_tensor,
                self.uncached_decode.get_inputs()[1].name: encoder_out,
                self.uncached_decode.get_inputs()[2].name: token_len_tensor,
            },
        )

        logits = out[0]
        states = out[1:]

        return logits, states

    def run_cached_decode(
        self, token: int, token_len: int, encoder_out: np.ndarray, states
    ):
        """
        Args:
          token: The current token
          token_len: Number of predicted tokens so far
          encoder_out: A tensor of shape (batch_size, T, dim)
          states: previous states
        Returns:
          A tuple:
            - a tensor of shape (batch_size, 1, dim)
            - a list of states
        """
        token_tensor = np.array([[token]], dtype=np.int32)
        token_len_tensor = np.array([token_len], dtype=np.int32)

        num_outs = len(self.cached_decode.get_outputs())
        out_names = [self.cached_decode.get_outputs()[i].name for i in range(num_outs)]

        states_inputs = {}
        for i in range(3, len(self.cached_decode.get_inputs())):
            name = self.cached_decode.get_inputs()[i].name
            states_inputs[name] = states[i - 3]

        out = self.cached_decode.run(
            out_names,
            {
                self.cached_decode.get_inputs()[0].name: token_tensor,
                self.cached_decode.get_inputs()[1].name: encoder_out,
                self.cached_decode.get_inputs()[2].name: token_len_tensor,
                **states_inputs,
            },
        )

        logits = out[0]
        states = out[1:]

        return logits, states


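# The code below expects the following files in the current directory:
#   1.wav                     a test wave file (resampled to 16 kHz if needed)
#   tokens.txt                one "token<TAB>id" pair per line
#   preprocess.onnx, encode.int8.onnx,
#   uncached_decode.int8.onnx, cached_decode.int8.onnx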
def main():
    wave = "./1.wav"
    id2token = dict()
    token2id = dict()
    with open("./tokens.txt", encoding="utf-8") as f:
        for k, line in enumerate(f):
            t, idx = line.split("\t")
            id2token[int(idx)] = t
            token2id[t] = int(idx)

    model = OnnxModel(
        preprocess="./preprocess.onnx",
        encode="./encode.int8.onnx",
        uncached_decode="./uncached_decode.int8.onnx",
        cached_decode="./cached_decode.int8.onnx",
    )

    audio, sample_rate = sf.read(wave, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel
    if sample_rate != 16000:
        audio = librosa.resample(
            audio,
            orig_sr=sample_rate,
            target_sr=16000,
        )
        sample_rate = 16000
    audio = audio[None]  # (1, num_samples)
    print("audio.shape", audio.shape)  # (1, 159414)

    start_t = dt.datetime.now()

    features = model.run_preprocess(audio)  # (1, 413, 288)
    print("features", features.shape)

    sos = token2id["<s>"]
    eos = token2id["</s>"]

    tokens = [sos]

    encoder_out = model.run_encode(features)
    print("encoder_out.shape", encoder_out.shape)  # (1, 413, 288)

    logits, states = model.run_uncached_decode(
        token=tokens[-1],
        token_len=len(tokens),
        encoder_out=encoder_out,
    )

    print("logits.shape", logits.shape)  # (1, 1, 32768)
    print("len(states)", len(states))  # 24

    max_len = int((audio.shape[-1] / 16000) * 6)

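    # Greedy search: at every step take the argmax token, stop on </s>, and
    # cap the output length at max_len (roughly 6 tokens per second of audio
    # in this script).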
    for i in range(max_len):
        token = logits.squeeze().argmax()
        if token == eos:
            break
        tokens.append(token)

        logits, states = model.run_cached_decode(
            token=tokens[-1],
            token_len=len(tokens),
            encoder_out=encoder_out,
            states=states,
        )

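    # Convert token ids back to text: "▁" (U+2581) marks word boundaries, as in
    # SentencePiece-style vocabularies, so it is replaced by a space.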
    tokens = tokens[1:]  # remove sos
    words = [id2token[i] for i in tokens]
    underline = "▁"
    # underline = b"\xe2\x96\x81".decode()
    text = "".join(words).replace(underline, " ").strip()

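    # Real-time factor: processing time divided by the audio duration
    # (audio.shape[-1] / 16000 seconds).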
    end_t = dt.datetime.now()
    t = (end_t - start_t).total_seconds()
    rtf = t * 16000 / audio.shape[-1]

    print(text)
    print("RTF:", rtf)


if __name__ == "__main__":
    main()