Export non-streaming NeMo faster conformer hybrid transducer and ctc to sherpa-onnx (#847)

This commit is contained in:
Fangjun Kuang
2024-05-09 13:59:47 +08:00
committed by GitHub
parent 68b25abf27
commit 5ed3ec1c04
16 changed files with 1055 additions and 20 deletions

1
scripts/nemo/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
!run-*.sh

View File

@@ -6,4 +6,20 @@ This folder contains scripts for exporting models from
- https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_streaming_480ms
- https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_streaming_1040ms
- # https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_ctc_large
- # https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_enes_conformer_transducer_large_codesw
- # https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_transducer_large
- # https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_enzh_fastconformer_transducer_large_codesw
- # https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_fa_fastconformer_hybrid_large
- # https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_it_fastconformer_hybrid_large_pc
- # https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_pl_fastconformer_hybrid_large_pc
- # https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_ua_fastconformer_hybrid_large_pc
- https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_pc
- https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_es_fastconformer_hybrid_large_pc
- https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc_blend_eu
- https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc
to `sherpa-onnx`.

View File

@@ -0,0 +1,94 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
from typing import Dict
import nemo.collections.asr as nemo_asr
import onnx
import torch
def get_args():
    """Parse the command-line options of this export script.

    Returns:
      The parsed ``argparse.Namespace``. It has two string attributes:
      ``model`` (required) and ``doc`` (optional, defaults to ``""``).
    """
    arg_parser = argparse.ArgumentParser()

    # Mandatory: main() passes it to ASRModel.from_pretrained() and appends
    # it to the NGC catalog URL stored in the metadata.
    arg_parser.add_argument("--model", type=str, required=True)

    # Optional free-form text; main() stores it under the "doc" metadata key.
    arg_parser.add_argument("--doc", type=str, default="")

    return arg_parser.parse_args()
def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Replace the metadata of an ONNX model file with ``meta_data``.

    The file is loaded, every existing ``metadata_props`` entry is removed,
    one entry per item of ``meta_data`` is appended, and the model is saved
    back to the same path.

    Args:
      filename:
        Path of the ONNX model. The file is overwritten.
      meta_data:
        Key-value pairs to store. Each value is converted with ``str()``
        before being written.
    """
    onnx_model = onnx.load(filename)
    props = onnx_model.metadata_props

    # Start from a clean slate so that only the given pairs remain.
    while len(props) > 0:
        props.pop()

    for name, content in meta_data.items():
        entry = props.add()
        entry.key = name
        entry.value = str(content)

    onnx.save(onnx_model, filename)
@torch.no_grad()
def main():
    """Export the CTC branch of a pretrained NeMo model to ``model.onnx``.

    Side effects, all in the current working directory:
      - ``tokens.txt``: one "<token> <id>" line per entry of the joint
        network's vocabulary, followed by a final ``<blk>`` line.
      - ``model.onnx``: the exported model, with its metadata replaced by
        the dictionary assembled below.
    """
    args = get_args()
    model_name = args.model

    asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)
    print(asr_model.cfg)
    print(asr_model)

    with open("./tokens.txt", "w", encoding="utf-8") as tokens_file:
        for token_id, token in enumerate(asr_model.joint.vocabulary):
            tokens_file.write(f"{token} {token_id}\n")
        # The blank symbol gets the id right after the last vocabulary entry.
        tokens_file.write(f"<blk> {token_id+1}\n")
        print("Saved to tokens.txt")

    # Switch the model to its CTC decoder before exporting.
    asr_model.change_decoding_strategy(decoder_type="ctc")
    asr_model.eval()
    asr_model.set_export_config({"decoder_type": "ctc"})

    onnx_filename = "model.onnx"
    asr_model.export(onnx_filename)

    # "NA" in the preprocessor config is stored as an empty string.
    configured_normalize = asr_model.cfg.preprocessor.normalize
    normalize_type = "" if configured_normalize == "NA" else configured_normalize

    catalog_url = (
        f"https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/{model_name}"
    )
    meta_data = {
        "vocab_size": asr_model.decoder.vocab_size,
        "normalize_type": normalize_type,
        "subsampling_factor": 8,
        "model_type": "EncDecHybridRNNTCTCBPEModel",
        "version": "1",
        "model_author": "NeMo",
        "url": catalog_url,
        "comment": "Only the CTC branch is exported",
        "doc": args.doc,
    }
    add_meta_data(onnx_filename, meta_data)

    print("preprocessor", asr_model.cfg.preprocessor)
    print(meta_data)


if __name__ == "__main__":
    main()

View File

@@ -91,11 +91,15 @@ def main():
asr_model.export(filename)
normalize_type = asr_model.cfg.preprocessor.normalize
if normalize_type == "NA":
normalize_type = ""
meta_data = {
"vocab_size": asr_model.decoder.vocab_size,
"window_size": window_size,
"chunk_shift": chunk_shift,
"normalize_type": "None",
"normalize_type": normalize_type,
"cache_last_channel_dim1": cache_last_channel_dim1,
"cache_last_channel_dim2": cache_last_channel_dim2,
"cache_last_channel_dim3": cache_last_channel_dim3,

View File

@@ -0,0 +1,97 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
from typing import Dict
import nemo.collections.asr as nemo_asr
import onnx
import torch
def get_args():
    """Parse the command-line options of this export script.

    Returns:
      The parsed ``argparse.Namespace``. It has two string attributes:
      ``model`` (required) and ``doc`` (optional, defaults to ``""``).
    """
    arg_parser = argparse.ArgumentParser()

    # Mandatory: main() passes it to ASRModel.from_pretrained() and appends
    # it to the NGC catalog URL stored in the metadata.
    arg_parser.add_argument("--model", type=str, required=True)

    # Optional free-form text; main() stores it under the "doc" metadata key.
    arg_parser.add_argument("--doc", type=str, default="")

    return arg_parser.parse_args()
def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Replace the metadata of an ONNX model file with ``meta_data``.

    The file is loaded, every existing ``metadata_props`` entry is removed,
    one entry per item of ``meta_data`` is appended, and the model is saved
    back to the same path.

    Args:
      filename:
        Path of the ONNX model. The file is overwritten.
      meta_data:
        Key-value pairs to store. Each value is converted with ``str()``
        before being written.
    """
    onnx_model = onnx.load(filename)
    props = onnx_model.metadata_props

    # Start from a clean slate so that only the given pairs remain.
    while len(props) > 0:
        props.pop()

    for name, content in meta_data.items():
        entry = props.add()
        entry.key = name
        entry.value = str(content)

    onnx.save(onnx_model, filename)
@torch.no_grad()
def main():
    """Export the transducer branch of a pretrained NeMo model to ONNX.

    Side effects, all in the current working directory:
      - ``tokens.txt``: one "<token> <id>" line per entry of the joint
        network's vocabulary, followed by a final ``<blk>`` line.
      - ``encoder.onnx``, ``decoder.onnx`` and ``joiner.onnx``: the three
        sub-modules, each exported separately.
      - The metadata of ``encoder.onnx`` is replaced by the dictionary
        assembled below; the other two files are not modified afterwards.
    """
    args = get_args()
    model_name = args.model

    asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)

    with open("./tokens.txt", "w", encoding="utf-8") as tokens_file:
        for token_id, token in enumerate(asr_model.joint.vocabulary):
            tokens_file.write(f"{token} {token_id}\n")
        # The blank symbol gets the id right after the last vocabulary entry.
        tokens_file.write(f"<blk> {token_id+1}\n")
        print("Saved to tokens.txt")

    # Switch the model to its RNN-T decoder before exporting.
    asr_model.change_decoding_strategy(decoder_type="rnnt")
    asr_model.eval()
    asr_model.set_export_config({"decoder_type": "rnnt"})

    # Export each sub-module on its own instead of calling asr_model.export().
    for attribute, onnx_filename in (
        ("encoder", "encoder.onnx"),
        ("decoder", "decoder.onnx"),
        ("joint", "joiner.onnx"),
    ):
        getattr(asr_model, attribute).export(onnx_filename)

    # "NA" in the preprocessor config is stored as an empty string.
    configured_normalize = asr_model.cfg.preprocessor.normalize
    normalize_type = "" if configured_normalize == "NA" else configured_normalize

    catalog_url = (
        f"https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/{model_name}"
    )
    prediction_network = asr_model.decoder
    meta_data = {
        "vocab_size": prediction_network.vocab_size,
        "normalize_type": normalize_type,
        "pred_rnn_layers": prediction_network.pred_rnn_layers,
        "pred_hidden": prediction_network.pred_hidden,
        "subsampling_factor": 8,
        "model_type": "EncDecHybridRNNTCTCBPEModel",
        "version": "1",
        "model_author": "NeMo",
        "url": catalog_url,
        "comment": "Only the transducer branch is exported",
        "doc": args.doc,
    }
    # Only the encoder file carries the metadata.
    add_meta_data("encoder.onnx", meta_data)

    print(meta_data)


if __name__ == "__main__":
    main()

View File

@@ -96,11 +96,15 @@ def main():
# encoder-model.onnx
# decoder_joint-model.onnx
normalize_type = asr_model.cfg.preprocessor.normalize
if normalize_type == "NA":
normalize_type = ""
meta_data = {
"vocab_size": asr_model.decoder.vocab_size,
"window_size": window_size,
"chunk_shift": chunk_shift,
"normalize_type": "None",
"normalize_type": normalize_type,
"cache_last_channel_dim1": cache_last_channel_dim1,
"cache_last_channel_dim2": cache_last_channel_dim2,
"cache_last_channel_dim3": cache_last_channel_dim3,

View File

@@ -0,0 +1,103 @@
#!/usr/bin/env bash
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
# -e: stop at the first command that fails.
# -x: print every command before it is executed.
set -ex
# Print a message prefixed with the current date/time and with the file name,
# line number and function name of the caller.
#
# Usage: log "some message"
log() {
# This function is from espnet
# BASH_SOURCE[1] is the file of the caller; "##*/" strips its directory part.
local fname=${BASH_SOURCE[1]##*/}
# BASH_LINENO[0] is the line where log was called; FUNCNAME[1] is the caller.
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
# Export the CTC branch of four NeMo models, one after another.
#
# Every section below follows the same steps:
#   1. $url  - catalog page of the model; its last path component is the
#              model name passed to the export script.
#   2. $doc  - description text stored in the metadata of the exported model.
#   3. Run ./export-onnx-ctc-non-streaming.py, which writes model.onnx and
#      tokens.txt into the current directory.
#   4. Move both files into the per-model directory $d and list it.

# 8500 hours of English speech
# NOTE(review): the directory name below ends in "24500" while this comment
# and $doc mention about 8500 hours - confirm which number is intended.
url=https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_pc
name=$(basename $url)
doc="This collection contains the English FastConformer Hybrid (Transducer and CTC) Large model (around 114M parameters) with Punctuation and Capitalization on NeMo ASRSet En PC with around 8500 hours of English speech (SPGI 1k, VoxPopuli, MCV11, Europarl-ASR, Fisher, LibriSpeech, NSC1, MLS). It utilizes a Google SentencePiece [1] tokenizer with a vocabulary size of 1024. It transcribes text in upper and lower case English alphabet along with spaces, periods, commas, question marks, and a few other characters."
log "Process $name at $url"
./export-onnx-ctc-non-streaming.py --model $name --doc "$doc"
d=sherpa-onnx-nemo-fast-conformer-ctc-en-24500
mkdir -p $d
mv -v model.onnx $d/
mv -v tokens.txt $d/
ls -lh $d
# Spanish model
url=https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_es_fastconformer_hybrid_large_pc
name=$(basename $url)
doc="This collection contains the Spanish FastConformer Hybrid (CTC and Transducer) Large model (around 114M parameters) with Punctuation and Capitalization. It is trained on the NeMo PnC ES ASRSET (Fisher, MCV12, MLS, Voxpopuli) containing 1424 hours of Spanish speech. It utilizes a Google SentencePiece [1] tokenizer with vocabulary size 1024, and transcribes text in upper and lower case Spanish alphabet along with spaces, period, comma, question mark and inverted question mark."
./export-onnx-ctc-non-streaming.py --model $name --doc "$doc"
d=sherpa-onnx-nemo-fast-conformer-ctc-es-1424
mkdir -p $d
mv -v model.onnx $d/
mv -v tokens.txt $d/
ls -lh $d
# Multilingual model: English, German, Spanish and French
url=https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc_blend_eu
name=$(basename $url)
doc="This collection contains the Multilingual FastConformer Hybrid (Transducer and CTC) Large model (around 114M parameters) with Punctuation and Capitalization. It is trained on the NeMo PnC German, English, Spanish, and French ASR sets that contain 14,288 hours of speech in total. It utilizes a Google SentencePiece [1] tokenizer with vocabulary size 256 per language and transcribes text in upper and lower case along with spaces, periods, commas, question marks and a few other language-specific characters. The total tokenizer size is 2560, of which 1024 tokens are allocated to English, German, French, and Spanish. The remaining tokens are reserved for future languages."
./export-onnx-ctc-non-streaming.py --model $name --doc "$doc"
d=sherpa-onnx-nemo-fast-conformer-ctc-en-de-es-fr-14288
mkdir -p $d
mv -v model.onnx $d/
mv -v tokens.txt $d/
ls -lh $d
# Multilingual model covering ten languages
url=https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc
name=$(basename $url)
doc="This collection contains the Multilingual FastConformer Hybrid (Transducer and CTC) Large model (around 114M parameters) with Punctuation and Capitalization. It is trained on the NeMo PnC Belarusian, German, English, Spanish, French, Croatian, Italian, Polish, Russian, and Ukrainian ASR sets that contain ~20,000 hours of speech in total. It utilizes a Google SentencePiece [1] tokenizer with vocabulary size 256 per language (2560 total), and transcribes text in upper and lower case along with spaces, periods, commas, question marks and a few other language-specific characters."
./export-onnx-ctc-non-streaming.py --model $name --doc "$doc"
d=sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k
mkdir -p $d
mv -v model.onnx $d/
mv -v tokens.txt $d/
ls -lh $d
# Now test the exported model
#
# Download and unpack the test wave files, then decode the matching files
# with each exported model. Every decoded file is also copied into the
# model's test_wavs/ directory.
log "Download test data"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/spoken-language-identification-test-wavs.tar.bz2
tar xvf spoken-language-identification-test-wavs.tar.bz2
# The archive is no longer needed once it has been unpacked.
rm spoken-language-identification-test-wavs.tar.bz2
data=spoken-language-identification-test-wavs
# English model
d=sherpa-onnx-nemo-fast-conformer-ctc-en-24500
python3 ./test-onnx-ctc-non-streaming.py \
--model $d/model.onnx \
--tokens $d/tokens.txt \
--wav $data/en-english.wav
mkdir -p $d/test_wavs
cp -v $data/en-english.wav $d/test_wavs
# Spanish model
d=sherpa-onnx-nemo-fast-conformer-ctc-es-1424
python3 ./test-onnx-ctc-non-streaming.py \
--model $d/model.onnx \
--tokens $d/tokens.txt \
--wav $data/es-spanish.wav
mkdir -p $d/test_wavs
cp -v $data/es-spanish.wav $d/test_wavs
# Four-language model: decode one file per language
d=sherpa-onnx-nemo-fast-conformer-ctc-en-de-es-fr-14288
mkdir -p $d/test_wavs
for w in en-english.wav de-german.wav es-spanish.wav fr-french.wav; do
python3 ./test-onnx-ctc-non-streaming.py \
--model $d/model.onnx \
--tokens $d/tokens.txt \
--wav $data/$w
cp -v $data/$w $d/test_wavs
done
# Ten-language model. The list below contains no file for Belarusian.
# NOTE(review): the Polish file is spelled "po-polish.wav" although the
# directory name uses "pl" - confirm it matches the file name in the archive.
d=sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k
mkdir -p $d/test_wavs
for w in en-english.wav de-german.wav es-spanish.wav fr-french.wav hr-croatian.wav it-italian.wav po-polish.wav ru-russian.wav uk-ukrainian.wav; do
python3 ./test-onnx-ctc-non-streaming.py \
--model $d/model.onnx \
--tokens $d/tokens.txt \
--wav $data/$w
cp -v $data/$w $d/test_wavs
done

View File

@@ -16,7 +16,7 @@ ms=(
for m in ${ms[@]}; do
./export-onnx-ctc.py --model $m
d=sherpa-onnx-nemo-streaming-fast-conformer-ctc-${m}ms
d=sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-${m}ms
if [ ! -f $d/model.onnx ]; then
mkdir -p $d
mv -v model.onnx $d/
@@ -28,7 +28,7 @@ done
# Now test the exported models
for m in ${ms[@]}; do
d=sherpa-onnx-nemo-streaming-fast-conformer-ctc-${m}ms
d=sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-${m}ms
python3 ./test-onnx-ctc.py \
--model $d/model.onnx \
--tokens $d/tokens.txt \

View File

@@ -0,0 +1,111 @@
#!/usr/bin/env bash
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
# -e: stop at the first command that fails.
# -x: print every command before it is executed.
set -ex
# Print a message prefixed with the current date/time and with the file name,
# line number and function name of the caller.
#
# Usage: log "some message"
log() {
# This function is from espnet
# BASH_SOURCE[1] is the file of the caller; "##*/" strips its directory part.
local fname=${BASH_SOURCE[1]##*/}
# BASH_LINENO[0] is the line where log was called; FUNCNAME[1] is the caller.
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
# Export the transducer branch of four NeMo models, one after another.
#
# Every section below follows the same steps:
#   1. $url  - catalog page of the model; its last path component is the
#              model name passed to the export script.
#   2. $doc  - description text stored in the metadata of the exported model.
#   3. Run ./export-onnx-transducer-non-streaming.py, which writes *.onnx
#      files and tokens.txt into the current directory.
#   4. Move every *.onnx file and tokens.txt into the per-model directory $d
#      and list it. Because of the "*.onnx" glob, no unrelated .onnx file
#      should be present in the current directory.

# 8500 hours of English speech
# NOTE(review): the directory name below ends in "24500" while this comment
# and $doc mention about 8500 hours - confirm which number is intended.
url=https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_pc
name=$(basename $url)
doc="This collection contains the English FastConformer Hybrid (Transducer and CTC) Large model (around 114M parameters) with Punctuation and Capitalization on NeMo ASRSet En PC with around 8500 hours of English speech (SPGI 1k, VoxPopuli, MCV11, Europarl-ASR, Fisher, LibriSpeech, NSC1, MLS). It utilizes a Google SentencePiece [1] tokenizer with a vocabulary size of 1024. It transcribes text in upper and lower case English alphabet along with spaces, periods, commas, question marks, and a few other characters."
log "Process $name at $url"
./export-onnx-transducer-non-streaming.py --model $name --doc "$doc"
d=sherpa-onnx-nemo-fast-conformer-transducer-en-24500
mkdir -p $d
mv -v *.onnx $d/
mv -v tokens.txt $d/
ls -lh $d
# Spanish model
url=https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_es_fastconformer_hybrid_large_pc
name=$(basename $url)
doc="This collection contains the Spanish FastConformer Hybrid (CTC and Transducer) Large model (around 114M parameters) with Punctuation and Capitalization. It is trained on the NeMo PnC ES ASRSET (Fisher, MCV12, MLS, Voxpopuli) containing 1424 hours of Spanish speech. It utilizes a Google SentencePiece [1] tokenizer with vocabulary size 1024, and transcribes text in upper and lower case Spanish alphabet along with spaces, period, comma, question mark and inverted question mark."
./export-onnx-transducer-non-streaming.py --model $name --doc "$doc"
d=sherpa-onnx-nemo-fast-conformer-transducer-es-1424
mkdir -p $d
mv -v *.onnx $d/
mv -v tokens.txt $d/
ls -lh $d
# Multilingual model: English, German, Spanish and French
url=https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc_blend_eu
name=$(basename $url)
doc="This collection contains the Multilingual FastConformer Hybrid (Transducer and CTC) Large model (around 114M parameters) with Punctuation and Capitalization. It is trained on the NeMo PnC German, English, Spanish, and French ASR sets that contain 14,288 hours of speech in total. It utilizes a Google SentencePiece [1] tokenizer with vocabulary size 256 per language and transcribes text in upper and lower case along with spaces, periods, commas, question marks and a few other language-specific characters. The total tokenizer size is 2560, of which 1024 tokens are allocated to English, German, French, and Spanish. The remaining tokens are reserved for future languages."
./export-onnx-transducer-non-streaming.py --model $name --doc "$doc"
d=sherpa-onnx-nemo-fast-conformer-transducer-en-de-es-fr-14288
mkdir -p $d
mv -v *.onnx $d/
mv -v tokens.txt $d/
ls -lh $d
# Multilingual model covering ten languages
url=https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc
name=$(basename $url)
doc="This collection contains the Multilingual FastConformer Hybrid (Transducer and CTC) Large model (around 114M parameters) with Punctuation and Capitalization. It is trained on the NeMo PnC Belarusian, German, English, Spanish, French, Croatian, Italian, Polish, Russian, and Ukrainian ASR sets that contain ~20,000 hours of speech in total. It utilizes a Google SentencePiece [1] tokenizer with vocabulary size 256 per language (2560 total), and transcribes text in upper and lower case along with spaces, periods, commas, question marks and a few other language-specific characters."
./export-onnx-transducer-non-streaming.py --model $name --doc "$doc"
d=sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k
mkdir -p $d
mv -v *.onnx $d/
mv -v tokens.txt $d/
ls -lh $d
# Now test the exported model
#
# Download and unpack the test wave files, then decode the matching files
# with each exported model. Every decoded file is also copied into the
# model's test_wavs/ directory.
log "Download test data"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/spoken-language-identification-test-wavs.tar.bz2
tar xvf spoken-language-identification-test-wavs.tar.bz2
# The archive is no longer needed once it has been unpacked.
rm spoken-language-identification-test-wavs.tar.bz2
data=spoken-language-identification-test-wavs
# English model
d=sherpa-onnx-nemo-fast-conformer-transducer-en-24500
python3 ./test-onnx-transducer-non-streaming.py \
--encoder $d/encoder.onnx \
--decoder $d/decoder.onnx \
--joiner $d/joiner.onnx \
--tokens $d/tokens.txt \
--wav $data/en-english.wav
mkdir -p $d/test_wavs
cp -v $data/en-english.wav $d/test_wavs
# Spanish model
d=sherpa-onnx-nemo-fast-conformer-transducer-es-1424
python3 ./test-onnx-transducer-non-streaming.py \
--encoder $d/encoder.onnx \
--decoder $d/decoder.onnx \
--joiner $d/joiner.onnx \
--tokens $d/tokens.txt \
--wav $data/es-spanish.wav
mkdir -p $d/test_wavs
cp -v $data/es-spanish.wav $d/test_wavs
# Four-language model: decode one file per language
d=sherpa-onnx-nemo-fast-conformer-transducer-en-de-es-fr-14288
mkdir -p $d/test_wavs
for w in en-english.wav de-german.wav es-spanish.wav fr-french.wav; do
python3 ./test-onnx-transducer-non-streaming.py \
--encoder $d/encoder.onnx \
--decoder $d/decoder.onnx \
--joiner $d/joiner.onnx \
--tokens $d/tokens.txt \
--wav $data/$w
cp -v $data/$w $d/test_wavs
done
# Ten-language model. The list below contains no file for Belarusian.
# NOTE(review): the Polish file is spelled "po-polish.wav" although the
# directory name uses "pl" - confirm it matches the file name in the archive.
d=sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k
mkdir -p $d/test_wavs
for w in en-english.wav de-german.wav es-spanish.wav fr-french.wav hr-croatian.wav it-italian.wav po-polish.wav ru-russian.wav uk-ukrainian.wav; do
python3 ./test-onnx-transducer-non-streaming.py \
--encoder $d/encoder.onnx \
--decoder $d/decoder.onnx \
--joiner $d/joiner.onnx \
--tokens $d/tokens.txt \
--wav $data/$w
cp -v $data/$w $d/test_wavs
done

View File

@@ -16,7 +16,7 @@ ms=(
for m in ${ms[@]}; do
./export-onnx-transducer.py --model $m
d=sherpa-onnx-nemo-streaming-fast-conformer-transducer-${m}ms
d=sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-${m}ms
if [ ! -f $d/encoder.onnx ]; then
mkdir -p $d
mv -v encoder.onnx $d/
@@ -30,7 +30,7 @@ done
# Now test the exported models
for m in ${ms[@]}; do
d=sherpa-onnx-nemo-streaming-fast-conformer-transducer-${m}ms
d=sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-${m}ms
python3 ./test-onnx-transducer.py \
--encoder $d/encoder.onnx \
--decoder $d/decoder.onnx \

View File

@@ -0,0 +1,168 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
from pathlib import Path
import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
import torch
import soundfile as sf
import librosa
def get_args():
    """Parse the command-line options of this test script.

    Returns:
      The parsed ``argparse.Namespace`` with the required string attributes
      ``model``, ``tokens`` and ``wav``.
    """
    arg_parser = argparse.ArgumentParser()

    # All three options are required paths; only the help text differs.
    for flag, description in (
        ("--model", "Path to model.onnx"),
        ("--tokens", "Path to tokens.txt"),
        ("--wav", "Path to test.wav"),
    ):
        arg_parser.add_argument(flag, type=str, required=True, help=description)

    return arg_parser.parse_args()
def create_fbank():
    """Create the online filter-bank feature extractor used by this script.

    Returns:
      A ``knf.OnlineFbank`` configured with: no dither, no DC-offset
      removal, a "hann" window, a low cut-off frequency of 0, 80 mel bins
      and librosa-style mel filters.
    """
    options = knf.FbankOptions()

    # Frame-level options.
    options.frame_opts.window_type = "hann"
    options.frame_opts.remove_dc_offset = False
    options.frame_opts.dither = 0

    # Mel filter-bank options.
    options.mel_opts.is_librosa = True
    options.mel_opts.num_bins = 80
    options.mel_opts.low_freq = 0

    return knf.OnlineFbank(options)
def compute_features(audio, fbank):
    """Feed ``audio`` to ``fbank`` and collect all frames that are ready.

    Args:
      audio:
        A 1-D array of samples. It is handed to the extractor with a
        sample rate of 16000.
      fbank:
        The feature extractor, e.g. the object returned by create_fbank().

    Returns:
      A 2-D numpy array with one row per ready frame.
    """
    assert audio.ndim == 1, audio.shape

    fbank.accept_waveform(16000, audio)

    frames = [
        np.array(fbank.get_frame(index)) for index in range(fbank.num_frames_ready)
    ]
    return np.stack(frames)
class OnnxModel:
    """Thin wrapper around an ONNX Runtime session of an exported CTC model."""

    def __init__(
        self,
        filename: str,
    ):
        """Load the model and read ``normalize_type`` from its metadata.

        Args:
          filename:
            Path to the exported ONNX model.
        """
        options = ort.SessionOptions()
        options.inter_op_num_threads = 1
        options.intra_op_num_threads = 1
        self.session_opts = options

        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        print("==========Input==========")
        for node in self.model.get_inputs():
            print(node)

        print("==========Output==========")
        for node in self.model.get_outputs():
            print(node)

        # Example of what the loops above print:
        #
        # ==========Input==========
        # NodeArg(name='audio_signal', type='tensor(float)', shape=['audio_signal_dynamic_axes_1', 80, 'audio_signal_dynamic_axes_2'])
        # NodeArg(name='length', type='tensor(int64)', shape=['length_dynamic_axes_1'])
        # ==========Output==========
        # NodeArg(name='logprobs', type='tensor(float)', shape=['logprobs_dynamic_axes_1', 'logprobs_dynamic_axes_2', 1025])

        metadata = self.model.get_modelmeta().custom_metadata_map
        self.normalize_type = metadata["normalize_type"]
        print(metadata)

    def __call__(self, x: np.ndarray):
        """Run the model on the features of a single utterance.

        Args:
          x:
            A 2-D array of shape (T, C).

        Returns:
          The first model output, converted to a torch tensor.
        """
        # (T, C) -> (1, C, T)
        features = np.expand_dims(x.T, axis=0)
        lengths = np.array([features.shape[-1]], dtype=np.int64)

        input_nodes = self.model.get_inputs()
        output_name = self.model.get_outputs()[0].name

        results = self.model.run(
            [output_name],
            {
                input_nodes[0].name: features,
                input_nodes[1].name: lengths,
            },
        )

        # [batch_size, T, vocab_size]
        return torch.from_numpy(results[0])
def main():
    """Decode one wave file with an exported CTC model and print the text.

    Steps:
      1. Load the ONNX model and the token table.
      2. Read the first channel of the wave file, resampling it to 16 kHz
         if needed.
      3. Compute filter-bank features and, if the model metadata asks for
         it, apply "per_feature" normalization.
      4. Run the model and do CTC greedy search: take the arg-max token of
         every frame, merge consecutive repeats and drop blanks.
      5. Join the tokens, turn the SentencePiece word-boundary marker
         (U+2581) into a space and print the result.
    """
    args = get_args()
    assert Path(args.model).is_file(), args.model
    assert Path(args.tokens).is_file(), args.tokens
    assert Path(args.wav).is_file(), args.wav

    print(vars(args))

    model = OnnxModel(args.model)

    # Each line of tokens.txt is "<token> <id>".
    id2token = dict()
    with open(args.tokens, encoding="utf-8") as f:
        for line in f:
            t, idx = line.split()
            id2token[int(idx)] = t

    fbank = create_fbank()
    audio, sample_rate = sf.read(args.wav, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel
    if sample_rate != 16000:
        audio = librosa.resample(
            audio,
            orig_sr=sample_rate,
            target_sr=16000,
        )
        sample_rate = 16000

    # The blank token is the last entry of tokens.txt.
    blank = len(id2token) - 1
    ans = []
    prev = -1

    print(audio.shape)
    features = compute_features(audio, fbank)  # (T, C)
    if model.normalize_type != "":
        assert model.normalize_type == "per_feature", model.normalize_type
        features = torch.from_numpy(features)
        # "per_feature" means every feature dimension is normalized with
        # statistics computed over time. features is (T, C), so the
        # reduction is over dim 0 (reducing over dim 1 would normalize each
        # frame across its mel bins instead). The small constant guards
        # against a zero standard deviation.
        mean = features.mean(dim=0, keepdim=True)
        stddev = features.std(dim=0, keepdim=True) + 1e-5
        features = (features - mean) / stddev
        features = features.numpy()

    print("features.shape", features.shape)

    log_probs = model(features)

    print("log_probs.shape", log_probs.shape)

    log_probs = log_probs[0, :, :]  # remove batch dim
    ids = torch.argmax(log_probs, dim=1).tolist()
    for k in ids:
        # Keep a token only when it is not blank and differs from the
        # previous frame's token.
        if k != blank and k != prev:
            ans.append(k)
        prev = k

    tokens = [id2token[i] for i in ans]
    # U+2581 (LOWER ONE EIGHTH BLOCK) marks a word boundary in SentencePiece
    # tokens. It is written as an escape so that it cannot be lost when the
    # file is copied or re-encoded; replacing "" instead would insert a
    # space between every character.
    underline = "\u2581"
    text = "".join(tokens).replace(underline, " ").strip()
    print(args.wav)
    print(text)


# Guarded so that importing this module does not start decoding.
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,303 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
from pathlib import Path
import kaldi_native_fbank as knf
import librosa
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch
def get_args():
    """Parse the command-line options of this test script.

    Returns:
      The parsed ``argparse.Namespace`` with the required string attributes
      ``encoder``, ``decoder``, ``joiner``, ``tokens`` and ``wav``.
    """
    arg_parser = argparse.ArgumentParser()

    # All five options are required paths; only the help text differs.
    for flag, description in (
        ("--encoder", "Path to encoder.onnx"),
        ("--decoder", "Path to decoder.onnx"),
        ("--joiner", "Path to joiner.onnx"),
        ("--tokens", "Path to tokens.txt"),
        ("--wav", "Path to test.wav"),
    ):
        arg_parser.add_argument(flag, type=str, required=True, help=description)

    return arg_parser.parse_args()
def create_fbank():
    """Create the online filter-bank feature extractor used by this script.

    Returns:
      A ``knf.OnlineFbank`` configured with: no dither, no DC-offset
      removal, a "hann" window, a low cut-off frequency of 0, 80 mel bins
      and librosa-style mel filters.
    """
    options = knf.FbankOptions()

    # Frame-level options.
    options.frame_opts.window_type = "hann"
    options.frame_opts.remove_dc_offset = False
    options.frame_opts.dither = 0

    # Mel filter-bank options.
    options.mel_opts.is_librosa = True
    options.mel_opts.num_bins = 80
    options.mel_opts.low_freq = 0

    return knf.OnlineFbank(options)
def compute_features(audio, fbank):
    """Feed ``audio`` to ``fbank`` and collect all frames that are ready.

    Args:
      audio:
        A 1-D array of samples. It is handed to the extractor with a
        sample rate of 16000.
      fbank:
        The feature extractor, e.g. the object returned by create_fbank().

    Returns:
      A 2-D numpy array with one row per ready frame.
    """
    assert audio.ndim == 1, audio.shape

    fbank.accept_waveform(16000, audio)

    frames = [
        np.array(fbank.get_frame(index)) for index in range(fbank.num_frames_ready)
    ]
    return np.stack(frames)
def display(sess):
    """Print the input nodes and then the output nodes of ``sess``.

    Args:
      sess:
        An object providing ``get_inputs()`` and ``get_outputs()``, such as
        an ONNX Runtime inference session.
    """
    print("==========Input==========")
    for node in sess.get_inputs():
        print(node)

    print("==========Output==========")
    for node in sess.get_outputs():
        print(node)
"""
encoder
==========Input==========
NodeArg(name='audio_signal', type='tensor(float)', shape=['audio_signal_dynamic_axes_1', 80, 'audio_signal_dynamic_axes_2'])
NodeArg(name='length', type='tensor(int64)', shape=['length_dynamic_axes_1'])
==========Output==========
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 512, 'outputs_dynamic_axes_2'])
NodeArg(name='encoded_lengths', type='tensor(int64)', shape=['encoded_lengths_dynamic_axes_1'])
decoder
==========Input==========
NodeArg(name='targets', type='tensor(int32)', shape=['targets_dynamic_axes_1', 'targets_dynamic_axes_2'])
NodeArg(name='target_length', type='tensor(int32)', shape=['target_length_dynamic_axes_1'])
NodeArg(name='states.1', type='tensor(float)', shape=[1, 'states.1_dim_1', 640])
NodeArg(name='onnx::LSTM_3', type='tensor(float)', shape=[1, 1, 640])
==========Output==========
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 640, 'outputs_dynamic_axes_2'])
NodeArg(name='prednet_lengths', type='tensor(int32)', shape=['prednet_lengths_dynamic_axes_1'])
NodeArg(name='states', type='tensor(float)', shape=[1, 'states_dynamic_axes_1', 640])
NodeArg(name='74', type='tensor(float)', shape=[1, 'LSTM74_dim_1', 640])
joiner
==========Input==========
NodeArg(name='encoder_outputs', type='tensor(float)', shape=['encoder_outputs_dynamic_axes_1', 512, 'encoder_outputs_dynamic_axes_2'])
NodeArg(name='decoder_outputs', type='tensor(float)', shape=['decoder_outputs_dynamic_axes_1', 640, 'decoder_outputs_dynamic_axes_2'])
==========Output==========
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 'outputs_dynamic_axes_2', 'outputs_dynamic_axes_3', 1025])
"""
class OnnxModel:
    """Holds the encoder, decoder and joiner sessions of an exported transducer."""

    def __init__(
        self,
        encoder: str,
        decoder: str,
        joiner: str,
    ):
        """Load the three models and print the nodes of each one.

        Args:
          encoder:
            Path to encoder.onnx. Its metadata must contain
            ``normalize_type``, ``pred_rnn_layers`` and ``pred_hidden``.
          decoder:
            Path to decoder.onnx.
          joiner:
            Path to joiner.onnx.
        """
        self.init_encoder(encoder)
        display(self.encoder)

        self.init_decoder(decoder)
        display(self.decoder)

        self.init_joiner(joiner)
        display(self.joiner)

    @staticmethod
    def _new_session(filename):
        """Create a single-threaded CPU inference session for ``filename``."""
        options = ort.SessionOptions()
        options.inter_op_num_threads = 1
        options.intra_op_num_threads = 1
        return ort.InferenceSession(
            filename,
            sess_options=options,
            providers=["CPUExecutionProvider"],
        )

    def init_encoder(self, encoder):
        """Load the encoder and cache the metadata needed for decoding."""
        self.encoder = self._new_session(encoder)

        metadata = self.encoder.get_modelmeta().custom_metadata_map
        self.normalize_type = metadata["normalize_type"]
        print(metadata)

        # Metadata values are strings; these two define the state shape.
        self.pred_rnn_layers = int(metadata["pred_rnn_layers"])
        self.pred_hidden = int(metadata["pred_hidden"])

    def init_decoder(self, decoder):
        """Load the decoder (prediction network)."""
        self.decoder = self._new_session(decoder)

    def init_joiner(self, joiner):
        """Load the joiner."""
        self.joiner = self._new_session(joiner)

    def get_decoder_state(self):
        """Return two all-zero float32 arrays to start decoding with.

        Each array has shape (pred_rnn_layers, 1, pred_hidden).
        """
        batch_size = 1
        shape = (self.pred_rnn_layers, batch_size, self.pred_hidden)
        return np.zeros(shape, dtype=np.float32), np.zeros(shape, dtype=np.float32)

    def run_encoder(self, x: np.ndarray):
        """Run the encoder on the features of a single utterance.

        Args:
          x:
            A 2-D array of shape (T, C).

        Returns:
          The first encoder output; the second output (the lengths) is
          discarded.
        """
        # (T, C) -> (1, C, T)
        features = np.expand_dims(x.T, axis=0)
        lengths = np.array([features.shape[-1]], dtype=np.int64)

        input_nodes = self.encoder.get_inputs()
        output_nodes = self.encoder.get_outputs()

        encoder_out, _out_len = self.encoder.run(
            [output_nodes[0].name, output_nodes[1].name],
            {
                input_nodes[0].name: features,
                input_nodes[1].name: lengths,
            },
        )
        # [batch_size, dim, T]
        return encoder_out

    def run_decoder(
        self,
        token: int,
        state0: np.ndarray,
        state1: np.ndarray,
    ):
        """Run the decoder for one token.

        Args:
          token:
            The token ID fed to the decoder.
          state0:
            First state array, e.g. from get_decoder_state() or from a
            previous call.
          state1:
            Second state array.

        Returns:
          A tuple (decoder_out, next_state0, next_state1). The output
          length returned by the model is discarded.
        """
        target = np.array([[token]], dtype=np.int32)
        target_len = np.array([1], dtype=np.int32)

        input_nodes = self.decoder.get_inputs()
        output_nodes = self.decoder.get_outputs()

        decoder_out, _out_len, next_state0, next_state1 = self.decoder.run(
            [output_nodes[i].name for i in range(4)],
            {
                input_nodes[0].name: target,
                input_nodes[1].name: target_len,
                input_nodes[2].name: state0,
                input_nodes[3].name: state1,
            },
        )
        return decoder_out, next_state0, next_state1

    def run_joiner(
        self,
        encoder_out: np.ndarray,
        decoder_out: np.ndarray,
    ):
        """Run the joiner on one encoder frame and one decoder output.

        Args:
          encoder_out:
            Array of shape [batch_size, dim, 1].
          decoder_out:
            Array of shape [batch_size, dim, 1].

        Returns:
          The first joiner output, of shape [batch_size, 1, 1, vocab_size].
        """
        input_nodes = self.joiner.get_inputs()
        results = self.joiner.run(
            [self.joiner.get_outputs()[0].name],
            {
                input_nodes[0].name: encoder_out,
                input_nodes[1].name: decoder_out,
            },
        )
        # logit: [batch_size, 1, 1, vocab_size]
        return results[0]
def main():
    """Decode one wave file with an exported transducer and print the text.

    Steps:
      1. Load the three ONNX models and the token table.
      2. Read the first channel of the wave file, resample it to 16 kHz if
         needed and append two seconds of silence.
      3. Compute filter-bank features and, if the model metadata asks for
         it, apply "per_feature" normalization.
      4. Run the encoder once, then do greedy search that emits at most one
         token per encoder frame.
      5. Join the tokens, turn the SentencePiece word-boundary marker
         (U+2581) into a space and print the result.
    """
    args = get_args()
    assert Path(args.encoder).is_file(), args.encoder
    assert Path(args.decoder).is_file(), args.decoder
    assert Path(args.joiner).is_file(), args.joiner
    assert Path(args.tokens).is_file(), args.tokens
    assert Path(args.wav).is_file(), args.wav

    print(vars(args))

    model = OnnxModel(args.encoder, args.decoder, args.joiner)

    # Each line of tokens.txt is "<token> <id>".
    id2token = dict()
    with open(args.tokens, encoding="utf-8") as f:
        for line in f:
            t, idx = line.split()
            id2token[int(idx)] = t

    fbank = create_fbank()
    audio, sample_rate = sf.read(args.wav, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel
    if sample_rate != 16000:
        audio = librosa.resample(
            audio,
            orig_sr=sample_rate,
            target_sr=16000,
        )
        sample_rate = 16000

    # Two seconds of trailing silence.
    tail_padding = np.zeros(sample_rate * 2)

    audio = np.concatenate([audio, tail_padding])

    # The blank token is the last entry of tokens.txt. Decoding starts with
    # a blank as the only previous token.
    blank = len(id2token) - 1
    ans = [blank]
    state0, state1 = model.get_decoder_state()
    decoder_out, state0_next, state1_next = model.run_decoder(ans[-1], state0, state1)

    features = compute_features(audio, fbank)  # (T, C)
    if model.normalize_type != "":
        assert model.normalize_type == "per_feature", model.normalize_type
        features = torch.from_numpy(features)
        # "per_feature" means every feature dimension is normalized with
        # statistics computed over time. features is (T, C), so the
        # reduction is over dim 0 (reducing over dim 1 would normalize each
        # frame across its mel bins instead). The small constant guards
        # against a zero standard deviation.
        mean = features.mean(dim=0, keepdim=True)
        stddev = features.std(dim=0, keepdim=True) + 1e-5
        features = (features - mean) / stddev
        features = features.numpy()
    print(audio.shape)
    print("features.shape", features.shape)

    encoder_out = model.run_encoder(features)
    # encoder_out:[batch_size, dim, T)
    for t in range(encoder_out.shape[2]):
        encoder_out_t = encoder_out[:, :, t : t + 1]
        logits = model.run_joiner(encoder_out_t, decoder_out)
        logits = torch.from_numpy(logits)
        logits = logits.squeeze()
        idx = torch.argmax(logits, dim=-1).item()
        if idx != blank:
            # A real token was emitted: commit the decoder states computed
            # for the previous token and advance the decoder.
            ans.append(idx)
            state0 = state0_next
            state1 = state1_next
            decoder_out, state0_next, state1_next = model.run_decoder(
                ans[-1], state0, state1
            )

    ans = ans[1:]  # remove the first blank
    print(ans)
    tokens = [id2token[i] for i in ans]
    # U+2581 (LOWER ONE EIGHTH BLOCK) marks a word boundary in SentencePiece
    # tokens. It is written as an escape so that it cannot be lost when the
    # file is copied or re-encoded; replacing "" instead would insert a
    # space between every character.
    underline = "\u2581"
    text = "".join(tokens).replace(underline, " ").strip()
    print(args.wav)
    print(text)


# Guarded so that importing this module does not start decoding.
if __name__ == "__main__":
    main()