diff --git a/.github/scripts/test-offline-ctc.sh b/.github/scripts/test-offline-ctc.sh
index f0978930..160478f2 100755
--- a/.github/scripts/test-offline-ctc.sh
+++ b/.github/scripts/test-offline-ctc.sh
@@ -47,9 +47,23 @@ for type in base small; do
   rm -rf sherpa-onnx-dolphin-$type-ctc-multi-lang-2025-04-02
 done
 
+log "------------------------------------------------------------"
+log "Run NeMo GigaAM Russian models v2"
+log "------------------------------------------------------------"
+curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19.tar.bz2
+tar xvf sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19.tar.bz2
+rm sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19.tar.bz2
+
+$EXE \
+  --nemo-ctc-model=./sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19/model.int8.onnx \
+  --tokens=./sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19/tokens.txt \
+  --debug=1 \
+  ./sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19/test_wavs/example.wav
+
+rm -rf sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19
 log "------------------------------------------------------------"
-log "Run NeMo GigaAM Russian models"
+log "Run NeMo GigaAM Russian models v1"
 log "------------------------------------------------------------"
 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2
 tar xvf sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2
 
diff --git a/.github/scripts/test-offline-transducer.sh b/.github/scripts/test-offline-transducer.sh
index 7ac72986..9c69bed0 100755
--- a/.github/scripts/test-offline-transducer.sh
+++ b/.github/scripts/test-offline-transducer.sh
@@ -15,6 +15,24 @@ echo "PATH: $PATH"
 
 which $EXE
 
+log "------------------------------------------------------------"
+log "Run NeMo GigaAM Russian models v2"
+log "------------------------------------------------------------"
+curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19.tar.bz2
+tar xvf sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19.tar.bz2
+rm sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19.tar.bz2
+
+$EXE \
+  --encoder=./sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19/encoder.int8.onnx \
+  --decoder=./sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19/decoder.onnx \
+  --joiner=./sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19/joiner.onnx \
+  --tokens=./sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19/tokens.txt \
+  --model-type=nemo_transducer \
+  ./sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19/test_wavs/example.wav
+
+rm -rf sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19
+
+
 log "------------------------------------------------------------------------"
 log "Run zipformer transducer models (Russian) "
 log "------------------------------------------------------------------------"
diff --git a/.github/workflows/export-nemo-giga-am-to-onnx.yaml b/.github/workflows/export-nemo-giga-am-to-onnx.yaml
index 49bc3cf4..2636d73c 100644
--- a/.github/workflows/export-nemo-giga-am-to-onnx.yaml
+++ b/.github/workflows/export-nemo-giga-am-to-onnx.yaml
@@ -43,7 +43,8 @@ jobs:
           mv -v scripts/nemo/GigaAM/tokens.txt $d/
           mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/
           mv -v scripts/nemo/GigaAM/run-ctc.sh $d/
-          mv -v scripts/nemo/GigaAM/*-ctc.py $d/
+          mv -v scripts/nemo/GigaAM/export-onnx-ctc.py $d/
+          cp -v scripts/nemo/GigaAM/test-onnx-ctc.py $d/
 
           ls -lh scripts/nemo/GigaAM/
 
@@ -71,7 +72,8 @@ jobs:
           mv -v scripts/nemo/GigaAM/tokens.txt $d/
           mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/
           mv -v scripts/nemo/GigaAM/run-rnnt.sh $d/
-          mv -v scripts/nemo/GigaAM/*-rnnt.py $d/
+          mv -v scripts/nemo/GigaAM/export-onnx-rnnt.py $d/
+          cp -v scripts/nemo/GigaAM/test-onnx-rnnt.py $d/
 
           ls -lh scripts/nemo/GigaAM/
 
@@ -91,11 +93,12 @@ jobs:
           mkdir $d/test_wavs
           rm scripts/nemo/GigaAM/v2_ctc.onnx
           mv -v scripts/nemo/GigaAM/*.int8.onnx $d/
-          cp -v scripts/nemo/GigaAM/LICENCE $d/
+          cp -v scripts/nemo/GigaAM/LICENSE $d/
           mv -v scripts/nemo/GigaAM/tokens.txt $d/
           mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/
-          mv -v scripts/nemo/GigaAM/run-ctc.sh $d/
+          mv -v scripts/nemo/GigaAM/run-ctc-v2.sh $d/
           mv -v scripts/nemo/GigaAM/*-ctc-v2.py $d/
+          cp -v scripts/nemo/GigaAM/test-onnx-ctc.py $d/
 
           ls -lh scripts/nemo/GigaAM/
 
@@ -103,8 +106,36 @@ jobs:
 
           tar cjvf ${d}.tar.bz2 $d
 
+      - name: Run Transducer v2
+        shell: bash
+        run: |
+          pushd scripts/nemo/GigaAM
+          ./run-rnnt-v2.sh
+          popd
+
+          d=sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19
+          mkdir $d
+          mkdir $d/test_wavs
+
+          mv -v scripts/nemo/GigaAM/encoder.int8.onnx $d/
+          mv -v scripts/nemo/GigaAM/decoder.onnx $d/
+          mv -v scripts/nemo/GigaAM/joiner.onnx $d/
+
+          cp -v scripts/nemo/GigaAM/*.md $d/
+          cp -v scripts/nemo/GigaAM/LICENSE $d/
+          mv -v scripts/nemo/GigaAM/tokens.txt $d/
+          mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/
+          mv -v scripts/nemo/GigaAM/run-rnnt-v2.sh $d/
+          cp -v scripts/nemo/GigaAM/test-onnx-rnnt.py $d/
+
+          ls -lh scripts/nemo/GigaAM/
+
+          ls -lh $d
+
+          tar cjvf ${d}.tar.bz2 $d
 
       - name: Release
+        if: github.repository_owner == 'csukuangfj'
         uses: svenstaro/upload-release-action@v2
         with:
           file_glob: true
@@ -114,7 +145,16 @@ jobs:
           repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
           tag: asr-models
 
-      - name: Publish to huggingface (Transducer)
+      - name: Release
+        if: github.repository_owner == 'k2-fsa'
+        uses: svenstaro/upload-release-action@v2
+        with:
+          file_glob: true
+          file: ./*.tar.bz2
+          overwrite: true
+          tag: asr-models
+
+      - name: Publish to huggingface (CTC)
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
         uses: nick-fields/retry@v3
@@ -126,11 +166,66 @@ jobs:
             git config --global user.email "csukuangfj@gmail.com"
             git config --global user.name "Fangjun Kuang"
 
+            d=sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24/
+            export GIT_LFS_SKIP_SMUDGE=1
+            export GIT_CLONE_PROTECTION_ACTIVE=false
+            rm -rf huggingface
+            git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
+            cp -av $d/* ./huggingface
+            cd huggingface
+            git lfs track "*.onnx"
+            git lfs track "*.wav"
+            git status
+            git add .
+            git status
+            git commit -m "add models"
+            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main
+
+      - name: Publish to huggingface (Transducer)
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        uses: nick-fields/retry@v3
+        with:
+          max_attempts: 5
+          timeout_seconds: 200
+          shell: bash
+          command: |
+            git config --global user.email "csukuangfj@gmail.com"
+            git config --global user.name "Fangjun Kuang"
+
             d=sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24/
             export GIT_LFS_SKIP_SMUDGE=1
             export GIT_CLONE_PROTECTION_ACTIVE=false
+            rm -rf huggingface
             git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
-            mv -v $d/* ./huggingface
+            cp -av $d/* ./huggingface
+            cd huggingface
+            git lfs track "*.onnx"
+            git lfs track "*.wav"
+            git status
+            git add .
+            git status
+            git commit -m "add models"
+            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main
+
+      - name: Publish v2 to huggingface (CTC)
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        uses: nick-fields/retry@v3
+        with:
+          max_attempts: 5
+          timeout_seconds: 200
+          shell: bash
+          command: |
+            git config --global user.email "csukuangfj@gmail.com"
+            git config --global user.name "Fangjun Kuang"
+
+            d=sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19/
+            export GIT_LFS_SKIP_SMUDGE=1
+            export GIT_CLONE_PROTECTION_ACTIVE=false
+            rm -rf huggingface
+            git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
+            cp -av $d/* ./huggingface
             cd huggingface
             git lfs track "*.onnx"
             git lfs track "*.wav"
@@ -145,7 +240,7 @@ jobs:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
         uses: nick-fields/retry@v3
         with:
-          max_attempts: 20
+          max_attempts: 5
           timeout_seconds: 200
           shell: bash
           command: |
@@ -155,8 +250,9 @@ jobs:
             d=sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19/
             export GIT_LFS_SKIP_SMUDGE=1
             export GIT_CLONE_PROTECTION_ACTIVE=false
+            rm -rf huggingface
             git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
-            mv -v $d/* ./huggingface
+            cp -av $d/* ./huggingface
             cd huggingface
             git lfs track "*.onnx"
             git lfs track "*.wav"
diff --git a/scripts/nemo/GigaAM/README.md b/scripts/nemo/GigaAM/README.md
index 583d10a1..287b7907 100644
--- a/scripts/nemo/GigaAM/README.md
+++ b/scripts/nemo/GigaAM/README.md
@@ -7,4 +7,4 @@ to sherpa-onnx.
 The ASR models are for Russian speech recognition in this folder.
 
 Please see the license of the models at
-https://github.com/salute-developers/GigaAM/blob/main/GigaAM%20License_NC.pdf
+https://github.com/salute-developers/GigaAM/blob/main/LICENSE
diff --git a/scripts/nemo/GigaAM/export-onnx-ctc-v2.py b/scripts/nemo/GigaAM/export-onnx-ctc-v2.py
old mode 100644
new mode 100755
index e6f95c0f..047cd220
--- a/scripts/nemo/GigaAM/export-onnx-ctc-v2.py
+++ b/scripts/nemo/GigaAM/export-onnx-ctc-v2.py
@@ -1,3 +1,4 @@
+#!/usr/bin/env python3
 import gigaam
 import onnx
 import torch
@@ -27,7 +28,13 @@ def add_meta_data(filename: str, meta_data: dict[str, str]):
 
 def main() -> None:
     model_name = "v2_ctc"
-    model = gigaam.load_model(model_name, fp16_encoder=False, use_flash=False, download_root=".")
+    model = gigaam.load_model(
+        model_name, fp16_encoder=False, use_flash=False, download_root="."
+    )
+
+    # use characters
+    # space is 0
+    # <blk> is the last token
     with open("./tokens.txt", "w", encoding="utf-8") as f:
         for i, s in enumerate(model.cfg["labels"]):
             f.write(f"{s} {i}\n")
@@ -53,5 +60,5 @@ def main() -> None:
     )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/scripts/nemo/GigaAM/export-onnx-ctc.py b/scripts/nemo/GigaAM/export-onnx-ctc.py
index 81feb3b7..473361a7 100755
--- a/scripts/nemo/GigaAM/export-onnx-ctc.py
+++ b/scripts/nemo/GigaAM/export-onnx-ctc.py
@@ -82,6 +82,9 @@ def main():
     model.load_state_dict(ckpt, strict=False)
     model.eval()
 
+    # use characters
+    # space is 0
+    # <blk> is the last token
     with open("tokens.txt", "w", encoding="utf-8") as f:
         for i, t in enumerate(model.cfg.labels):
             f.write(f"{t} {i}\n")
diff --git a/scripts/nemo/GigaAM/export-onnx-rnnt-v2.py b/scripts/nemo/GigaAM/export-onnx-rnnt-v2.py
new file mode 100755
index 00000000..bb194cf2
--- /dev/null
+++ b/scripts/nemo/GigaAM/export-onnx-rnnt-v2.py
@@ -0,0 +1,158 @@
+#!/usr/bin/env python3
+# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
+import os
+
+import gigaam
+import onnx
+import torch
+from gigaam.utils import onnx_converter
+from onnxruntime.quantization import QuantType, quantize_dynamic
+from torch import Tensor
+
+"""
+==========Input==========
+NodeArg(name='audio_signal', type='tensor(float)', shape=['batch_size', 64, 'seq_len'])
+NodeArg(name='length', type='tensor(int64)', shape=['batch_size'])
+==========Output==========
+NodeArg(name='encoded', type='tensor(float)', shape=['batch_size', 768, 'Transposeencoded_dim_2'])
+NodeArg(name='encoded_len', type='tensor(int32)', shape=['batch_size'])
+
+==========Input==========
+NodeArg(name='x', type='tensor(int32)', shape=[1, 1])
+NodeArg(name='unused_x_len.1', type='tensor(int32)', shape=[1])
+NodeArg(name='h.1', type='tensor(float)', shape=[1, 1, 320])
+NodeArg(name='c.1', type='tensor(float)', shape=[1, 1, 320])
+==========Output==========
+NodeArg(name='dec', type='tensor(float)', shape=[1, 320, 1])
+NodeArg(name='unused_x_len', type='tensor(int32)', shape=[1])
+NodeArg(name='h', type='tensor(float)', shape=[1, 1, 320])
+NodeArg(name='c', type='tensor(float)', shape=[1, 1, 320])
+
+==========Input==========
+NodeArg(name='enc', type='tensor(float)', shape=[1, 768, 1])
+NodeArg(name='dec', type='tensor(float)', shape=[1, 320, 1])
+==========Output==========
+NodeArg(name='joint', type='tensor(float)', shape=[1, 1, 1, 34])
+"""
+
+
+def add_meta_data(filename: str, meta_data: dict[str, str]):
+    """Add meta data to an ONNX model. It is changed in-place.
+
+    Args:
+      filename:
+        Filename of the ONNX model to be changed.
+      meta_data:
+        Key-value pairs.
+    """
+    model = onnx.load(filename)
+    while len(model.metadata_props):
+        model.metadata_props.pop()
+
+    for key, value in meta_data.items():
+        meta = model.metadata_props.add()
+        meta.key = key
+        meta.value = str(value)
+
+    onnx.save(model, filename)
+
+
+class EncoderWrapper(torch.nn.Module):
+    def __init__(self, m):
+        super().__init__()
+        self.m = m
+
+    def forward(self, audio_signal: Tensor, length: Tensor):
+        # https://github.com/salute-developers/GigaAM/blob/main/gigaam/encoder.py#L499
+        out, out_len = self.m.encoder(audio_signal, length)
+
+        return out, out_len.to(torch.int64)
+
+    def to_onnx(self, dir_path: str = "."):
+        onnx_converter(
+            model_name=f"{self.m.cfg.model_name}_encoder",
+            out_dir=dir_path,
+            module=self.m.encoder,
+            dynamic_axes=self.m.encoder.dynamic_axes(),
+        )
+
+
+class DecoderWrapper(torch.nn.Module):
+    def __init__(self, m):
+        super().__init__()
+        self.m = m
+
+    def forward(self, x: Tensor, unused_x_len: Tensor, h: Tensor, c: Tensor):
+        # https://github.com/salute-developers/GigaAM/blob/main/gigaam/decoder.py#L110C17-L110C54
+        emb = self.m.head.decoder.embed(x)
+        g, (h, c) = self.m.head.decoder.lstm(emb.transpose(0, 1), (h, c))
+        return g.permute(1, 2, 0), unused_x_len + 1, h, c
+
+    def to_onnx(self, dir_path: str = "."):
+        label, hidden_h, hidden_c = self.m.head.decoder.input_example()
+        label = label.to(torch.int32)
+        label_len = torch.zeros(1, dtype=torch.int32)
+
+        onnx_converter(
+            model_name=f"{self.m.cfg.model_name}_decoder",
+            out_dir=dir_path,
+            module=self,
+            dynamic_axes=self.m.encoder.dynamic_axes(),
+            inputs=(label, label_len, hidden_h, hidden_c),
+            input_names=["x", "unused_x_len.1", "h.1", "c.1"],
+            output_names=["dec", "unused_x_len", "h", "c"],
+        )
+
+
+def main() -> None:
+    model_name = "v2_rnnt"
+    model = gigaam.load_model(
+        model_name, fp16_encoder=False, use_flash=False, download_root="."
+    )
+
+    # use characters
+    # space is 0
+    # <blk> is the last token
+    with open("./tokens.txt", "w", encoding="utf-8") as f:
+        for i, s in enumerate(model.cfg["labels"]):
+            f.write(f"{s} {i}\n")
+        f.write(f"<blk> {i+1}\n")
+    print("Saved to tokens.txt")
+
+    EncoderWrapper(model).to_onnx(".")
+    DecoderWrapper(model).to_onnx(".")
+
+    onnx_converter(
+        model_name=f"{model.cfg.model_name}_joint",
+        out_dir=".",
+        module=model.head.joint,
+    )
+    meta_data = {
+        # vocab_size does not include the blank
+        # we will increase vocab_size by 1 in the C++ code
+        "vocab_size": model.cfg["head"]["decoder"]["num_classes"] - 1,
+        "pred_rnn_layers": model.cfg["head"]["decoder"]["pred_rnn_layers"],
+        "pred_hidden": model.cfg["head"]["decoder"]["pred_hidden"],
+        "normalize_type": "",
+        "subsampling_factor": 4,
+        "model_type": "EncDecRNNTBPEModel",
+        "version": "2",
+        "model_author": "https://github.com/salute-developers/GigaAM",
+        "license": "https://github.com/salute-developers/GigaAM/blob/main/LICENSE",
+        "language": "Russian",
+        "is_giga_am": 1,
+    }
+
+    add_meta_data(f"./{model_name}_encoder.onnx", meta_data)
+    quantize_dynamic(
+        model_input=f"./{model_name}_encoder.onnx",
+        model_output="./encoder.int8.onnx",
+        weight_type=QuantType.QUInt8,
+    )
+    os.rename(f"./{model_name}_decoder.onnx", "decoder.onnx")
+    os.rename(f"./{model_name}_joint.onnx", "joiner.onnx")
+    os.remove(f"./{model_name}_encoder.onnx")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/nemo/GigaAM/export-onnx-rnnt.py b/scripts/nemo/GigaAM/export-onnx-rnnt.py
old mode 100644
new mode 100755
index 1ac05ff7..1c89c773
--- a/scripts/nemo/GigaAM/export-onnx-rnnt.py
+++ b/scripts/nemo/GigaAM/export-onnx-rnnt.py
@@ -83,6 +83,7 @@ def main():
     model.load_state_dict(ckpt, strict=False)
     model.eval()
 
+    # use bpe
     with open("./tokens.txt", "w", encoding="utf-8") as f:
         for i, s in enumerate(model.joint.vocabulary):
             f.write(f"{s} {i}\n")
@@ -94,7 +95,9 @@ def main():
     model.joint.export("joiner.onnx")
 
     meta_data = {
-        "vocab_size": model.decoder.vocab_size,  # not including the blank
+        # not including the blank
+        # we increase vocab_size in the C++ code
+        "vocab_size": model.decoder.vocab_size,
         "pred_rnn_layers": model.decoder.pred_rnn_layers,
         "pred_hidden": model.decoder.pred_hidden,
         "normalize_type": "",
diff --git a/scripts/nemo/GigaAM/run-ctc-v2.sh b/scripts/nemo/GigaAM/run-ctc-v2.sh
index db133037..4dc4e3f8 100755
--- a/scripts/nemo/GigaAM/run-ctc-v2.sh
+++ b/scripts/nemo/GigaAM/run-ctc-v2.sh
@@ -5,11 +5,14 @@ set -ex
 function install_gigaam() {
   curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
   python3 get-pip.py
+  pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html
+  pip install -qq wget text-unidecode "matplotlib>=3.3.2" onnx onnxruntime==1.17.1 pybind11 Cython einops kaldi-native-fbank soundfile librosa
 
   BRANCH='main'
   python3 -m pip install git+https://github.com/salute-developers/GigaAM.git@$BRANCH#egg=gigaam
 
   python3 -m pip install -qq kaldi-native-fbank
+  pip install numpy==1.26.4
 }
 
 function download_files() {
diff --git a/scripts/nemo/GigaAM/run-ctc.sh b/scripts/nemo/GigaAM/run-ctc.sh
index 03acc88e..26044d2b 100755
--- a/scripts/nemo/GigaAM/run-ctc.sh
+++ b/scripts/nemo/GigaAM/run-ctc.sh
@@ -9,7 +9,7 @@ function install_nemo() {
 
   pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html
 
-  pip install -qq wget text-unidecode matplotlib>=3.3.2 onnx onnxruntime pybind11 Cython einops kaldi-native-fbank soundfile librosa
+  pip install -qq wget text-unidecode "matplotlib>=3.3.2" onnx onnxruntime==1.17.1 pybind11 Cython einops kaldi-native-fbank soundfile librosa
"matplotlib>=3.3.2" onnx onnxruntime==1.17.1 pybind11 Cython einops kaldi-native-fbank soundfile librosa pip install -qq ipython # sudo apt-get install -q -y sox libsndfile1 ffmpeg python3-pip ipython diff --git a/scripts/nemo/GigaAM/run-rnnt-v2.sh b/scripts/nemo/GigaAM/run-rnnt-v2.sh new file mode 100755 index 00000000..bc9fa82e --- /dev/null +++ b/scripts/nemo/GigaAM/run-rnnt-v2.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang) + +set -ex + +function install_gigaam() { + curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py + python3 get-pip.py + pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html + pip install -qq wget text-unidecode "matplotlib>=3.3.2" onnx onnxruntime==1.17.1 pybind11 Cython einops kaldi-native-fbank soundfile librosa + + BRANCH='main' + python3 -m pip install git+https://github.com/salute-developers/GigaAM.git@$BRANCH#egg=gigaam + + python3 -m pip install -qq kaldi-native-fbank + pip install numpy==1.26.4 +} + +function download_files() { + curl -SL -O https://huggingface.co/csukuangfj/tmp-files/resolve/main/GigaAM/example.wav + curl -SL -O https://github.com/salute-developers/GigaAM/blob/main/LICENSE +} + +install_gigaam +download_files + +python3 ./export-onnx-rnnt-v2.py +ls -lh +python3 ./test-onnx-rnnt.py diff --git a/scripts/nemo/GigaAM/run-rnnt.sh b/scripts/nemo/GigaAM/run-rnnt.sh index 209f4f15..a84e6f1c 100755 --- a/scripts/nemo/GigaAM/run-rnnt.sh +++ b/scripts/nemo/GigaAM/run-rnnt.sh @@ -9,7 +9,7 @@ function install_nemo() { pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html - pip install -qq wget text-unidecode matplotlib>=3.3.2 onnx onnxruntime pybind11 Cython einops kaldi-native-fbank soundfile librosa + pip install -qq wget text-unidecode "matplotlib>=3.3.2" onnx onnxruntime==1.17.1 pybind11 Cython einops kaldi-native-fbank soundfile librosa pip install -qq ipython # sudo apt-get install -q -y sox libsndfile1 ffmpeg python3-pip ipython diff --git a/scripts/nemo/GigaAM/test-onnx-ctc.py b/scripts/nemo/GigaAM/test-onnx-ctc.py index 5c181e49..731d4e7e 100755 --- a/scripts/nemo/GigaAM/test-onnx-ctc.py +++ b/scripts/nemo/GigaAM/test-onnx-ctc.py @@ -19,7 +19,7 @@ def create_fbank(): opts.frame_opts.window_type = "hann" # Even though GigaAM uses 400 for fft, here we use 512 - # since kaldi-native-fbank only support fft for power of 2. + # since kaldi-native-fbank only supports fft for power of 2. opts.frame_opts.round_to_power_of_two = True opts.mel_opts.low_freq = 0 diff --git a/scripts/nemo/GigaAM/test-onnx-rnnt.py b/scripts/nemo/GigaAM/test-onnx-rnnt.py index 85c6a5e9..f2bf7b2a 100755 --- a/scripts/nemo/GigaAM/test-onnx-rnnt.py +++ b/scripts/nemo/GigaAM/test-onnx-rnnt.py @@ -20,7 +20,7 @@ def create_fbank(): opts.frame_opts.window_type = "hann" # Even though GigaAM uses 400 for fft, here we use 512 - # since kaldi-native-fbank only support fft for power of 2. + # since kaldi-native-fbank only supports fft for power of 2. 
     opts.frame_opts.round_to_power_of_two = True
 
     opts.mel_opts.low_freq = 0
@@ -166,12 +166,7 @@ class OnnxModel:
         target = torch.tensor([[token]], dtype=torch.int32).numpy()
         target_len = torch.tensor([1], dtype=torch.int32).numpy()
 
-        (
-            decoder_out,
-            decoder_out_length,
-            state0_next,
-            state1_next,
-        ) = self.decoder.run(
+        (decoder_out, decoder_out_length, state0_next, state1_next,) = self.decoder.run(
             [
                 self.decoder.get_outputs()[0].name,
                 self.decoder.get_outputs()[1].name,
@@ -213,8 +208,12 @@ def main():
     id2token = dict()
     with open("./tokens.txt", encoding="utf-8") as f:
         for line in f:
-            t, idx = line.split()
-            id2token[int(idx)] = t
+            fields = line.split()
+            if len(fields) == 1:
+                id2token[int(fields[0])] = " "
+            else:
+                t, idx = fields
+                id2token[int(idx)] = t
 
     fbank = create_fbank()
     audio, sample_rate = sf.read("./example.wav", dtype="float32", always_2d=True)
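
A note on how the exported v2 RNNT pieces fit together: the docstring in export-onnx-rnnt-v2.py records the exact input/output names and shapes of the three ONNX graphs, and test-onnx-rnnt.py consumes them. The sketch below is a minimal illustration, not the shipped test script. It assumes 64-dim fbank features are already computed as a float32 array of shape (1, 64, num_frames) (the real script uses kaldi-native-fbank), it assumes the blank token <blk> (the last id, 33) also serves as the decoder start symbol as is common for NeMo transducers, and it simplifies greedy search to at most one emitted token per encoder frame. It also mirrors the tokens.txt parsing fix above, where the space token's line splits into a single field.

#!/usr/bin/env python3
# Minimal sketch of greedy search over the exported v2 GigaAM RNNT models.
# Names and shapes follow the NodeArg docstring in export-onnx-rnnt-v2.py;
# assumptions are noted in the comments.
import numpy as np
import onnxruntime as ort


def load_id2token(filename="./tokens.txt"):
    """Parse tokens.txt; the space token's line splits into a single field."""
    id2token = {}
    with open(filename, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if len(fields) == 1:
                id2token[int(fields[0])] = " "
            else:
                t, idx = fields
                id2token[int(idx)] = t
    return id2token


def run_decoder(decoder, token, h, c):
    # decoder: x (1, 1) int32, unused_x_len.1 (1,) int32, h/c (1, 1, 320)
    return decoder.run(
        None,
        {
            "x": np.array([[token]], dtype=np.int32),
            "unused_x_len.1": np.array([1], dtype=np.int32),
            "h.1": h,
            "c.1": c,
        },
    )


def greedy_search(encoder, decoder, joiner, features):
    # encoder: (1, 64, num_frames) float32 + (1,) int64
    #       -> (1, 768, T') float32 + (1,) int32
    length = np.array([features.shape[2]], dtype=np.int64)
    encoded, encoded_len = encoder.run(
        None, {"audio_signal": features, "length": length}
    )

    blank = joiner.get_outputs()[0].shape[-1] - 1  # <blk> is the last token
    h = np.zeros((1, 1, 320), dtype=np.float32)
    c = np.zeros((1, 1, 320), dtype=np.float32)
    # Assumption: prime the decoder with blank as the start symbol
    dec, _, h, c = run_decoder(decoder, blank, h, c)

    hyp = []
    for t in range(int(encoded_len[0])):
        enc_t = encoded[:, :, t : t + 1]  # (1, 768, 1)
        joint = joiner.run(None, {"enc": enc_t, "dec": dec})[0]  # (1, 1, 1, 34)
        k = int(joint.squeeze().argmax())
        if k != blank:  # simplified: at most one emission per frame
            hyp.append(k)
            dec, _, h, c = run_decoder(decoder, k, h, c)
    return hyp


if __name__ == "__main__":
    id2token = load_id2token()
    encoder = ort.InferenceSession("./encoder.int8.onnx", providers=["CPUExecutionProvider"])
    decoder = ort.InferenceSession("./decoder.onnx", providers=["CPUExecutionProvider"])
    joiner = ort.InferenceSession("./joiner.onnx", providers=["CPUExecutionProvider"])

    features = np.zeros((1, 64, 100), dtype=np.float32)  # placeholder features
    hyp = greedy_search(encoder, decoder, joiner, features)
    print("".join(id2token[i] for i in hyp))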