See https://user-images.githubusercontent.com/5284924/282995968-f6d39118-8008-4ce7-9d7c-d1d6387ac183.png
105 lines
2.9 KiB
Python
Executable File
105 lines
2.9 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
|
|
from typing import Tuple

import kaldi_native_fbank as knf
import onnxruntime as ort
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
|
|
|
|
|
|
class OnnxModel:
    """Thin wrapper that runs an ONNX model on the CPU via onnxruntime.

    The model is expected to have (at least) two inputs and two outputs; only
    the first two of each are used.
    """

    def __init__(
        self,
        filename: str,
    ):
        """
        Args:
          filename:
            Path to the ONNX model file. It is passed as is to
            ``ort.InferenceSession``.
        """
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 4

        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

    def __call__(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C)
          x_lens:
            A 1-D tensor of shape (N,). Its dtype is torch.int64
        Returns:
          Return a tuple containing two tensors:

            - log_probs, a 3-D tensor of shape (N, T, C)
            - log_probs_lens, the second output of the model (presumably
              the number of valid frames of each utterance in log_probs;
              verify against the exported model)
        """
        # Query the session only once per call; the node order defines which
        # tensor is fed to which input and which outputs are fetched.
        input_nodes = self.model.get_inputs()
        output_nodes = self.model.get_outputs()

        log_probs, log_probs_lens = self.model.run(
            [output_nodes[0].name, output_nodes[1].name],
            {
                input_nodes[0].name: x.numpy(),
                input_nodes[1].name: x_lens.numpy(),
            },
        )
        return torch.from_numpy(log_probs), torch.from_numpy(log_probs_lens)
|
|
|
|
|
|
def get_features(test_wav_filename):
    """Compute 80-bin fbank features for a wave file.

    Only the first channel of the file is used. If the sample rate of the
    file is not 16000 Hz, the audio is resampled to 16000 Hz first.

    Args:
      test_wav_filename:
        Path to the wave file. It is passed as is to ``torchaudio.load``.
    Returns:
      Return a 2-D tensor of shape (num_frames, 80) with one row per
      feature frame.
    """
    wave, sample_rate = torchaudio.load(test_wav_filename)
    audio = wave[0].contiguous()  # only use the first channel
    if sample_rate != 16000:
        audio = torchaudio.functional.resample(
            audio, orig_freq=sample_rate, new_freq=16000
        )

    # torchaudio.load() returns samples normalized to [-1, 1] by default.
    # Scale them to the int16 range (2**15), which is the scale WeNet applies
    # before computing fbank. The previous factor, 372768, was a typo.
    audio = audio * 32768

    opts = knf.FbankOptions()
    opts.frame_opts.dither = 0
    opts.mel_opts.num_bins = 80
    opts.frame_opts.snip_edges = False
    opts.mel_opts.debug_mel = False

    fbank = knf.OnlineFbank(opts)
    fbank.accept_waveform(16000, audio.numpy())

    frames = [
        torch.from_numpy(fbank.get_frame(i))
        for i in range(fbank.num_frames_ready)
    ]
    return torch.stack(frames)
|
|
|
|
|
|
def main():
    """Recognize ./0.wav with ./model.onnx and print the decoded text.

    The token table is read from ./units.txt, where every line has the form
    ``<token> <id>``. Decoding takes the best token per frame, collapses
    consecutive repeats, and drops id 0 (presumably the CTC blank).
    """
    recognizer = OnnxModel("./model.onnx")

    # Add a leading batch axis so that the features have 3 dimensions.
    features = get_features("./0.wav").unsqueeze(0)

    # Note: It supports only batch size == 1
    features_lens = torch.tensor([features.shape[1]], dtype=torch.int64)

    print(features.shape, features_lens)

    log_probs, _ = recognizer(features, features_lens)
    log_probs = log_probs[0]
    print(log_probs.shape)

    # Best token id of every frame
    best_path = log_probs.argmax(dim=1)
    print(best_path)

    # Merge consecutive duplicates, then remove id 0
    collapsed = torch.unique_consecutive(best_path)
    token_ids = collapsed[collapsed != 0].tolist()

    with open("./units.txt", encoding="utf-8") as f:
        id2word = {
            int(idx): word
            for word, idx in (line.strip().split() for line in f)
        }

    text = "".join(id2word[i] for i in token_ids)
    print(text)
|
|
|
|
|
|
# Run the demo only when this file is executed as a script, so that importing
# it does not trigger model loading or file access.
if __name__ == "__main__":
    main()
|