This repository has been archived on 2025-08-26. You can view files and clone it, but cannot push or open issues or pull requests.
Files
enginex-mr_series-sherpa-onnx/scripts/wenet/test-onnx-streaming.py

175 lines
5.3 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
import kaldi_native_fbank as knf
import onnxruntime as ort
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
class OnnxModel:
    """Wrapper around a streaming WeNet CTC encoder exported to ONNX.

    The model's configuration is read from the ONNX custom metadata map. The
    wrapper owns the streaming state (attention cache, convolution cache and
    frame offset) and advances it on every call.
    """

    def __init__(
        self,
        filename: str,
    ):
        """Load the ONNX model and initialize the streaming state.

        Args:
          filename:
            Path to the streaming ONNX model. Its custom metadata must contain
            the keys left_chunks, num_blocks, chunk_size, head, output_size,
            cnn_module_kernel, right_context and subsampling_factor.
        """
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 4
        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        meta = self.model.get_modelmeta().custom_metadata_map
        self.left_chunks = int(meta["left_chunks"])
        self.num_blocks = int(meta["num_blocks"])
        self.chunk_size = int(meta["chunk_size"])
        self.head = int(meta["head"])
        self.output_size = int(meta["output_size"])
        self.cnn_module_kernel = int(meta["cnn_module_kernel"])
        self.right_context = int(meta["right_context"])
        self.subsampling_factor = int(meta["subsampling_factor"])

        # The graph's input/output names are fixed, so look them up once
        # instead of on every chunk.
        self._input_names = [node.name for node in self.model.get_inputs()]
        self._output_names = [node.name for node in self.model.get_outputs()[:3]]

        self._init_cache()

    def _init_cache(self):
        """Reset the streaming state: zeroed caches and the initial offset.

        Sets:
          attn_cache: float32 array of shape
            (num_blocks, head, chunk_size * left_chunks, output_size // head * 2).
            The factor 2 presumably holds keys and values together
            (TODO confirm against the exporter).
          conv_cache: float32 array of shape
            (num_blocks, 1, output_size, cnn_module_kernel - 1).
          offset, required_cache_size: int64 arrays of shape (1,), both equal
            to chunk_size * left_chunks.
        """
        required_cache_size = self.chunk_size * self.left_chunks
        self.attn_cache = torch.zeros(
            self.num_blocks,
            self.head,
            required_cache_size,
            self.output_size // self.head * 2,
            dtype=torch.float32,
        ).numpy()

        self.conv_cache = torch.zeros(
            self.num_blocks,
            1,
            self.output_size,
            self.cnn_module_kernel - 1,
            dtype=torch.float32,
        ).numpy()

        # The offset starts at the cache size, so the first chunk has
        # chunk_idx == 0 in _make_attn_mask.
        self.offset = torch.tensor([required_cache_size], dtype=torch.int64).numpy()

        self.required_cache_size = torch.tensor(
            [self.chunk_size * self.left_chunks], dtype=torch.int64
        ).numpy()

    def _make_attn_mask(self) -> torch.Tensor:
        """Return the boolean attention mask for the next chunk.

        Returns:
          A bool tensor of shape (1, 1, required_cache_size + chunk_size).
          While fewer than ``left_chunks`` chunks have been processed, the
          leading ``required_cache_size - chunk_idx * chunk_size`` positions are
          False (presumably the cache slots not yet filled by real frames).
        """
        # Work on plain Python ints: int()/bool() on shape-(1,) numpy arrays is
        # deprecated since NumPy 1.25.
        cache_size = int(self.required_cache_size[0])
        offset = int(self.offset[0])

        attn_mask = torch.ones(1, 1, cache_size + self.chunk_size, dtype=torch.bool)
        chunk_idx = offset // self.chunk_size - self.left_chunks
        if chunk_idx < self.left_chunks:
            attn_mask[:, :, : cache_size - chunk_idx * self.chunk_size] = False
        return attn_mask

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder on one chunk of features and advance the state.

        Args:
          x:
            A 2-D tensor of shape (T, C) holding one window of fbank features.
        Returns:
          Return a 2-D tensor of shape (num_output_frames, vocab_size)
          containing log_probs. num_output_frames is the encoder output length
          after subsampling, not T.
        """
        attn_mask = self._make_attn_mask()

        names = self._input_names
        log_probs, new_attn_cache, new_conv_cache = self.model.run(
            self._output_names,
            {
                names[0]: x.unsqueeze(0).numpy(),
                names[1]: self.offset,
                names[2]: self.required_cache_size,
                names[3]: self.attn_cache,
                names[4]: self.conv_cache,
                names[5]: attn_mask.numpy(),
            },
        )
        self.attn_cache = new_attn_cache
        self.conv_cache = new_conv_cache

        log_probs = torch.from_numpy(log_probs)

        # Advance by the number of encoder frames just produced.
        self.offset += log_probs.shape[1]

        return log_probs.squeeze(0)
def get_features(test_wav_filename):
    """Compute 80-bin fbank features for the first channel of an audio file.

    The audio is resampled to 16 kHz if needed. Dither is disabled and
    snip_edges is False.

    Args:
      test_wav_filename:
        Path to an audio file readable by ``torchaudio.load``.
    Returns:
      A 2-D float tensor of shape (num_frames, 80). It has zero rows if the
      audio is too short to produce any frame.
    """
    wave, sample_rate = torchaudio.load(test_wav_filename)
    audio = wave[0].contiguous()  # only use the first channel
    if sample_rate != 16000:
        audio = torchaudio.functional.resample(
            audio, orig_freq=sample_rate, new_freq=16000
        )

    # Scale the normalized [-1, 1] samples to int16 range (2**15 = 32768).
    # This presumably matches the training frontend (WeNet scales waveforms
    # by 1 << 15). The former 372768 was a typo that added a constant offset
    # to every log-mel bin.
    audio = audio * 32768

    opts = knf.FbankOptions()
    opts.frame_opts.dither = 0
    opts.mel_opts.num_bins = 80
    opts.frame_opts.snip_edges = False
    opts.mel_opts.debug_mel = False

    fbank = knf.OnlineFbank(opts)
    fbank.accept_waveform(16000, audio.numpy())
    # Signal end of input so the trailing frames are flushed and become ready.
    fbank.input_finished()

    frames = [
        torch.from_numpy(fbank.get_frame(i)) for i in range(fbank.num_frames_ready)
    ]
    if not frames:
        # torch.stack([]) raises; return an empty (0, num_bins) tensor instead.
        return torch.empty(0, opts.mel_opts.num_bins)
    return torch.stack(frames)
def main():
    """Decode ./0.wav with ./model-streaming.onnx chunk by chunk and print it.

    Features are fed to the model in overlapping windows. The CTC output is
    greedy-decoded (blank id 0), and token ids are mapped to symbols using
    ./units.txt, which holds one "<symbol> <id>" pair per line.
    """
    model_filename = "./model-streaming.onnx"
    model = OnnxModel(model_filename)

    filename = "./0.wav"
    x = get_features(filename)

    # Append 0.5 s of zero-valued tail padding so the last real frames fall
    # inside a complete window. The padding is counted in *feature frames*.
    # With the fbank default 10 ms frame shift that is 100 frames/s, so 50
    # frames. The former int(16000 * 0.5) used the sample rate and appended
    # 8000 frames (80 s).
    tail_padding_seconds = 0.5
    feature_frames_per_second = 100
    padding = torch.zeros(
        int(feature_frames_per_second * tail_padding_seconds), x.shape[1]
    )
    x = torch.cat([x, padding], dim=0)

    # Number of input frames needed to produce `chunk_size` encoder frames.
    chunk_length = int(
        (model.chunk_size - 1) * model.subsampling_factor + model.right_context + 1
    )
    # Each window yields `chunk_size` encoder frames, i.e. it consumes
    # chunk_size * subsampling_factor input frames, so that is the stride.
    # The former required_cache_size (chunk_size * left_chunks) was only
    # correct when left_chunks == subsampling_factor.
    chunk_shift = int(model.chunk_size * model.subsampling_factor)
    print(chunk_length, chunk_shift)

    num_frames = x.shape[0]
    num_chunks = (num_frames - chunk_length) // chunk_shift + 1

    # Greedy CTC decoding over the whole stream: collapse repeated ids across
    # chunk boundaries too, then drop blanks (id 0). A per-chunk collapse
    # would emit a token twice when it straddles two chunks.
    tokens = []
    prev = None
    for i in range(num_chunks):
        start = i * chunk_shift
        frames = x[start : start + chunk_length, :]
        log_probs = model(frames)
        for idx in log_probs.argmax(dim=1).tolist():
            if idx != prev and idx != 0:
                tokens.append(idx)
            prev = idx

    id2word = {}
    with open("./units.txt", encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:  # tolerate blank lines
                continue
            word, idx = fields
            id2word[int(idx)] = word

    text = "".join(id2word[i] for i in tokens)
    print(text)
# Run the decoding demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()