Add scripts to export ASR models from wenet to ONNX (#425)
See https://user-images.githubusercontent.com/5284924/282995968-f6d39118-8008-4ce7-9d7c-d1d6387ac183.png
This commit is contained in:
104
scripts/wenet/test-onnx.py
Executable file
104
scripts/wenet/test-onnx.py
Executable file
@@ -0,0 +1,104 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
import kaldi_native_fbank as knf
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(
|
||||
self,
|
||||
filename: str,
|
||||
):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 4
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
self.model = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
def __call__(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). Its dtype is torch.int64
|
||||
Returns:
|
||||
Return a 3-D tensor of shape (N, T, C) containing log_probs.
|
||||
"""
|
||||
log_probs, log_probs_lens = self.model.run(
|
||||
[self.model.get_outputs()[0].name, self.model.get_outputs()[1].name],
|
||||
{
|
||||
self.model.get_inputs()[0].name: x.numpy(),
|
||||
self.model.get_inputs()[1].name: x_lens.numpy(),
|
||||
},
|
||||
)
|
||||
return torch.from_numpy(log_probs), torch.from_numpy(log_probs_lens)
|
||||
|
||||
|
||||
def get_features(test_wav_filename):
|
||||
wave, sample_rate = torchaudio.load(test_wav_filename)
|
||||
audio = wave[0].contiguous() # only use the first channel
|
||||
if sample_rate != 16000:
|
||||
audio = torchaudio.functional.resample(
|
||||
audio, orig_freq=sample_rate, new_freq=16000
|
||||
)
|
||||
audio *= 372768
|
||||
|
||||
opts = knf.FbankOptions()
|
||||
opts.frame_opts.dither = 0
|
||||
opts.mel_opts.num_bins = 80
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.mel_opts.debug_mel = False
|
||||
|
||||
fbank = knf.OnlineFbank(opts)
|
||||
fbank.accept_waveform(16000, audio.numpy())
|
||||
frames = []
|
||||
for i in range(fbank.num_frames_ready):
|
||||
frames.append(torch.from_numpy(fbank.get_frame(i)))
|
||||
frames = torch.stack(frames)
|
||||
return frames
|
||||
|
||||
|
||||
def main():
|
||||
model_filename = "./model.onnx"
|
||||
model = OnnxModel(model_filename)
|
||||
|
||||
filename = "./0.wav"
|
||||
x = get_features(filename)
|
||||
x = x.unsqueeze(0)
|
||||
|
||||
# Note: It supports only batch size == 1
|
||||
x_lens = torch.tensor([x.shape[1]], dtype=torch.int64)
|
||||
|
||||
print(x.shape, x_lens)
|
||||
|
||||
log_probs, log_probs_lens = model(x, x_lens)
|
||||
log_probs = log_probs[0]
|
||||
print(log_probs.shape)
|
||||
|
||||
indexes = log_probs.argmax(dim=1)
|
||||
print(indexes)
|
||||
indexes = torch.unique_consecutive(indexes)
|
||||
indexes = indexes[indexes != 0].tolist()
|
||||
|
||||
id2word = dict()
|
||||
with open("./units.txt", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
word, idx = line.strip().split()
|
||||
id2word[int(idx)] = word
|
||||
text = "".join([id2word[i] for i in indexes])
|
||||
print(text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user