Support Giga AM transducer V2 (#2136)
This commit is contained in:
16
.github/scripts/test-offline-ctc.sh
vendored
16
.github/scripts/test-offline-ctc.sh
vendored
@@ -47,9 +47,23 @@ for type in base small; do
|
|||||||
rm -rf sherpa-onnx-dolphin-$type-ctc-multi-lang-2025-04-02
|
rm -rf sherpa-onnx-dolphin-$type-ctc-multi-lang-2025-04-02
|
||||||
done
|
done
|
||||||
|
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
log "Run NeMo GigaAM Russian models v2"
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19.tar.bz2
|
||||||
|
tar xvf sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19.tar.bz2
|
||||||
|
rm sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19.tar.bz2
|
||||||
|
|
||||||
|
$EXE \
|
||||||
|
--nemo-ctc-model=./sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19/model.int8.onnx \
|
||||||
|
--tokens=./sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19/tokens.txt \
|
||||||
|
--debug=1 \
|
||||||
|
./sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19/test_wavs/example.wav
|
||||||
|
|
||||||
|
rm -rf sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19
|
||||||
|
|
||||||
log "------------------------------------------------------------"
|
log "------------------------------------------------------------"
|
||||||
log "Run NeMo GigaAM Russian models"
|
log "Run NeMo GigaAM Russian models v1"
|
||||||
log "------------------------------------------------------------"
|
log "------------------------------------------------------------"
|
||||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2
|
||||||
tar xvf sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2
|
tar xvf sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2
|
||||||
|
|||||||
18
.github/scripts/test-offline-transducer.sh
vendored
18
.github/scripts/test-offline-transducer.sh
vendored
@@ -15,6 +15,24 @@ echo "PATH: $PATH"
|
|||||||
|
|
||||||
which $EXE
|
which $EXE
|
||||||
|
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
log "Run NeMo GigaAM Russian models v2"
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19.tar.bz2
|
||||||
|
tar xvf sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19.tar.bz2
|
||||||
|
rm sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19.tar.bz2
|
||||||
|
|
||||||
|
$EXE \
|
||||||
|
--encoder=./sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19/encoder.int8.onnx \
|
||||||
|
--decoder=./sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19/decoder.onnx \
|
||||||
|
--joiner=./sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19/joiner.onnx \
|
||||||
|
--tokens=./sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19/tokens.txt \
|
||||||
|
--model-type=nemo_transducer \
|
||||||
|
./sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19/test_wavs/example.wav
|
||||||
|
|
||||||
|
rm sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19
|
||||||
|
|
||||||
|
|
||||||
log "------------------------------------------------------------------------"
|
log "------------------------------------------------------------------------"
|
||||||
log "Run zipformer transducer models (Russian) "
|
log "Run zipformer transducer models (Russian) "
|
||||||
log "------------------------------------------------------------------------"
|
log "------------------------------------------------------------------------"
|
||||||
|
|||||||
112
.github/workflows/export-nemo-giga-am-to-onnx.yaml
vendored
112
.github/workflows/export-nemo-giga-am-to-onnx.yaml
vendored
@@ -43,7 +43,8 @@ jobs:
|
|||||||
mv -v scripts/nemo/GigaAM/tokens.txt $d/
|
mv -v scripts/nemo/GigaAM/tokens.txt $d/
|
||||||
mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/
|
mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/
|
||||||
mv -v scripts/nemo/GigaAM/run-ctc.sh $d/
|
mv -v scripts/nemo/GigaAM/run-ctc.sh $d/
|
||||||
mv -v scripts/nemo/GigaAM/*-ctc.py $d/
|
mv -v scripts/nemo/GigaAM/export-onnx-ctc.py $d/
|
||||||
|
cp -v scripts/nemo/GigaAM/test-onnx-ctc.py $d/
|
||||||
|
|
||||||
ls -lh scripts/nemo/GigaAM/
|
ls -lh scripts/nemo/GigaAM/
|
||||||
|
|
||||||
@@ -71,7 +72,8 @@ jobs:
|
|||||||
mv -v scripts/nemo/GigaAM/tokens.txt $d/
|
mv -v scripts/nemo/GigaAM/tokens.txt $d/
|
||||||
mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/
|
mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/
|
||||||
mv -v scripts/nemo/GigaAM/run-rnnt.sh $d/
|
mv -v scripts/nemo/GigaAM/run-rnnt.sh $d/
|
||||||
mv -v scripts/nemo/GigaAM/*-rnnt.py $d/
|
mv -v scripts/nemo/GigaAM/export-onnx-rnnt.py $d/
|
||||||
|
cp -v scripts/nemo/GigaAM/test-onnx-rnnt.py $d/
|
||||||
|
|
||||||
ls -lh scripts/nemo/GigaAM/
|
ls -lh scripts/nemo/GigaAM/
|
||||||
|
|
||||||
@@ -91,11 +93,12 @@ jobs:
|
|||||||
mkdir $d/test_wavs
|
mkdir $d/test_wavs
|
||||||
rm scripts/nemo/GigaAM/v2_ctc.onnx
|
rm scripts/nemo/GigaAM/v2_ctc.onnx
|
||||||
mv -v scripts/nemo/GigaAM/*.int8.onnx $d/
|
mv -v scripts/nemo/GigaAM/*.int8.onnx $d/
|
||||||
cp -v scripts/nemo/GigaAM/LICENCE $d/
|
cp -v scripts/nemo/GigaAM/LICENSE $d/
|
||||||
mv -v scripts/nemo/GigaAM/tokens.txt $d/
|
mv -v scripts/nemo/GigaAM/tokens.txt $d/
|
||||||
mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/
|
mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/
|
||||||
mv -v scripts/nemo/GigaAM/run-ctc.sh $d/
|
mv -v scripts/nemo/GigaAM/run-ctc-v2.sh $d/
|
||||||
mv -v scripts/nemo/GigaAM/*-ctc-v2.py $d/
|
mv -v scripts/nemo/GigaAM/*-ctc-v2.py $d/
|
||||||
|
cp -v scripts/nemo/GigaAM/test-onnx-ctc.py $d/
|
||||||
|
|
||||||
ls -lh scripts/nemo/GigaAM/
|
ls -lh scripts/nemo/GigaAM/
|
||||||
|
|
||||||
@@ -103,8 +106,36 @@ jobs:
|
|||||||
|
|
||||||
tar cjvf ${d}.tar.bz2 $d
|
tar cjvf ${d}.tar.bz2 $d
|
||||||
|
|
||||||
|
- name: Run Transducer v2
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
pushd scripts/nemo/GigaAM
|
||||||
|
./run-rnnt-v2.sh
|
||||||
|
popd
|
||||||
|
|
||||||
|
d=sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19
|
||||||
|
mkdir $d
|
||||||
|
mkdir $d/test_wavs
|
||||||
|
|
||||||
|
mv -v scripts/nemo/GigaAM/encoder.int8.onnx $d/
|
||||||
|
mv -v scripts/nemo/GigaAM/decoder.onnx $d/
|
||||||
|
mv -v scripts/nemo/GigaAM/joiner.onnx $d/
|
||||||
|
|
||||||
|
cp -v scripts/nemo/GigaAM/*.md $d/
|
||||||
|
cp -v scripts/nemo/GigaAM/LICENSE $d/
|
||||||
|
mv -v scripts/nemo/GigaAM/tokens.txt $d/
|
||||||
|
mv -v scripts/nemo/GigaAM/*.wav $d/test_wavs/
|
||||||
|
mv -v scripts/nemo/GigaAM/run-rnnt-v2.sh $d/
|
||||||
|
cp -v scripts/nemo/GigaAM/test-onnx-rnnt.py $d/
|
||||||
|
|
||||||
|
ls -lh scripts/nemo/GigaAM/
|
||||||
|
|
||||||
|
ls -lh $d
|
||||||
|
|
||||||
|
tar cjvf ${d}.tar.bz2 $d
|
||||||
|
|
||||||
- name: Release
|
- name: Release
|
||||||
|
if: github.repository_owner == 'csukuangfj'
|
||||||
uses: svenstaro/upload-release-action@v2
|
uses: svenstaro/upload-release-action@v2
|
||||||
with:
|
with:
|
||||||
file_glob: true
|
file_glob: true
|
||||||
@@ -114,7 +145,16 @@ jobs:
|
|||||||
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
|
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
|
||||||
tag: asr-models
|
tag: asr-models
|
||||||
|
|
||||||
- name: Publish to huggingface (Transducer)
|
- name: Release
|
||||||
|
if: github.repository_owner == 'k2-fsa'
|
||||||
|
uses: svenstaro/upload-release-action@v2
|
||||||
|
with:
|
||||||
|
file_glob: true
|
||||||
|
file: ./*.tar.bz2
|
||||||
|
overwrite: true
|
||||||
|
tag: asr-models
|
||||||
|
|
||||||
|
- name: Publish to huggingface (CTC)
|
||||||
env:
|
env:
|
||||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||||
uses: nick-fields/retry@v3
|
uses: nick-fields/retry@v3
|
||||||
@@ -126,11 +166,66 @@ jobs:
|
|||||||
git config --global user.email "csukuangfj@gmail.com"
|
git config --global user.email "csukuangfj@gmail.com"
|
||||||
git config --global user.name "Fangjun Kuang"
|
git config --global user.name "Fangjun Kuang"
|
||||||
|
|
||||||
|
d=sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24/
|
||||||
|
export GIT_LFS_SKIP_SMUDGE=1
|
||||||
|
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||||
|
rm -rf huggingface
|
||||||
|
git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
|
||||||
|
cp -av $d/* ./huggingface
|
||||||
|
cd huggingface
|
||||||
|
git lfs track "*.onnx"
|
||||||
|
git lfs track "*.wav"
|
||||||
|
git status
|
||||||
|
git add .
|
||||||
|
git status
|
||||||
|
git commit -m "add models"
|
||||||
|
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main
|
||||||
|
|
||||||
|
- name: Publish to huggingface (Transducer)
|
||||||
|
env:
|
||||||
|
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||||
|
uses: nick-fields/retry@v3
|
||||||
|
with:
|
||||||
|
max_attempts: 5
|
||||||
|
timeout_seconds: 200
|
||||||
|
shell: bash
|
||||||
|
command: |
|
||||||
|
git config --global user.email "csukuangfj@gmail.com"
|
||||||
|
git config --global user.name "Fangjun Kuang"
|
||||||
|
|
||||||
d=sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24/
|
d=sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24/
|
||||||
export GIT_LFS_SKIP_SMUDGE=1
|
export GIT_LFS_SKIP_SMUDGE=1
|
||||||
export GIT_CLONE_PROTECTION_ACTIVE=false
|
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||||
|
rm -rf huggingface
|
||||||
git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
|
git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
|
||||||
mv -v $d/* ./huggingface
|
cp -av $d/* ./huggingface
|
||||||
|
cd huggingface
|
||||||
|
git lfs track "*.onnx"
|
||||||
|
git lfs track "*.wav"
|
||||||
|
git status
|
||||||
|
git add .
|
||||||
|
git status
|
||||||
|
git commit -m "add models"
|
||||||
|
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main
|
||||||
|
|
||||||
|
- name: Publish v2 to huggingface (CTC)
|
||||||
|
env:
|
||||||
|
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||||
|
uses: nick-fields/retry@v3
|
||||||
|
with:
|
||||||
|
max_attempts: 5
|
||||||
|
timeout_seconds: 200
|
||||||
|
shell: bash
|
||||||
|
command: |
|
||||||
|
git config --global user.email "csukuangfj@gmail.com"
|
||||||
|
git config --global user.name "Fangjun Kuang"
|
||||||
|
|
||||||
|
d=sherpa-onnx-nemo-ctc-giga-am-v2-russian-2025-04-19/
|
||||||
|
export GIT_LFS_SKIP_SMUDGE=1
|
||||||
|
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||||
|
rm -rf huggingface
|
||||||
|
git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
|
||||||
|
cp -av $d/* ./huggingface
|
||||||
cd huggingface
|
cd huggingface
|
||||||
git lfs track "*.onnx"
|
git lfs track "*.onnx"
|
||||||
git lfs track "*.wav"
|
git lfs track "*.wav"
|
||||||
@@ -145,7 +240,7 @@ jobs:
|
|||||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||||
uses: nick-fields/retry@v3
|
uses: nick-fields/retry@v3
|
||||||
with:
|
with:
|
||||||
max_attempts: 20
|
max_attempts: 5
|
||||||
timeout_seconds: 200
|
timeout_seconds: 200
|
||||||
shell: bash
|
shell: bash
|
||||||
command: |
|
command: |
|
||||||
@@ -155,8 +250,9 @@ jobs:
|
|||||||
d=sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19/
|
d=sherpa-onnx-nemo-transducer-giga-am-v2-russian-2025-04-19/
|
||||||
export GIT_LFS_SKIP_SMUDGE=1
|
export GIT_LFS_SKIP_SMUDGE=1
|
||||||
export GIT_CLONE_PROTECTION_ACTIVE=false
|
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||||
|
rm -rf huggingface
|
||||||
git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
|
git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
|
||||||
mv -v $d/* ./huggingface
|
cp -av $d/* ./huggingface
|
||||||
cd huggingface
|
cd huggingface
|
||||||
git lfs track "*.onnx"
|
git lfs track "*.onnx"
|
||||||
git lfs track "*.wav"
|
git lfs track "*.wav"
|
||||||
|
|||||||
@@ -7,4 +7,4 @@ to sherpa-onnx.
|
|||||||
The ASR models are for Russian speech recognition in this folder.
|
The ASR models are for Russian speech recognition in this folder.
|
||||||
|
|
||||||
Please see the license of the models at
|
Please see the license of the models at
|
||||||
https://github.com/salute-developers/GigaAM/blob/main/GigaAM%20License_NC.pdf
|
https://github.com/salute-developers/GigaAM/blob/main/LICENSE
|
||||||
|
|||||||
11
scripts/nemo/GigaAM/export-onnx-ctc-v2.py
Normal file → Executable file
11
scripts/nemo/GigaAM/export-onnx-ctc-v2.py
Normal file → Executable file
@@ -1,3 +1,4 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
import gigaam
|
import gigaam
|
||||||
import onnx
|
import onnx
|
||||||
import torch
|
import torch
|
||||||
@@ -27,7 +28,13 @@ def add_meta_data(filename: str, meta_data: dict[str, str]):
|
|||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
model_name = "v2_ctc"
|
model_name = "v2_ctc"
|
||||||
model = gigaam.load_model(model_name, fp16_encoder=False, use_flash=False, download_root=".")
|
model = gigaam.load_model(
|
||||||
|
model_name, fp16_encoder=False, use_flash=False, download_root="."
|
||||||
|
)
|
||||||
|
|
||||||
|
# use characters
|
||||||
|
# space is 0
|
||||||
|
# <blk> is the last token
|
||||||
with open("./tokens.txt", "w", encoding="utf-8") as f:
|
with open("./tokens.txt", "w", encoding="utf-8") as f:
|
||||||
for i, s in enumerate(model.cfg["labels"]):
|
for i, s in enumerate(model.cfg["labels"]):
|
||||||
f.write(f"{s} {i}\n")
|
f.write(f"{s} {i}\n")
|
||||||
@@ -53,5 +60,5 @@ def main() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -82,6 +82,9 @@ def main():
|
|||||||
model.load_state_dict(ckpt, strict=False)
|
model.load_state_dict(ckpt, strict=False)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
# use characters
|
||||||
|
# space is 0
|
||||||
|
# <blk> is the last token
|
||||||
with open("tokens.txt", "w", encoding="utf-8") as f:
|
with open("tokens.txt", "w", encoding="utf-8") as f:
|
||||||
for i, t in enumerate(model.cfg.labels):
|
for i, t in enumerate(model.cfg.labels):
|
||||||
f.write(f"{t} {i}\n")
|
f.write(f"{t} {i}\n")
|
||||||
|
|||||||
158
scripts/nemo/GigaAM/export-onnx-rnnt-v2.py
Executable file
158
scripts/nemo/GigaAM/export-onnx-rnnt-v2.py
Executable file
@@ -0,0 +1,158 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
import os
|
||||||
|
|
||||||
|
import gigaam
|
||||||
|
import onnx
|
||||||
|
import torch
|
||||||
|
from gigaam.utils import onnx_converter
|
||||||
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
"""
|
||||||
|
==========Input==========
|
||||||
|
NodeArg(name='audio_signal', type='tensor(float)', shape=['batch_size', 64, 'seq_len'])
|
||||||
|
NodeArg(name='length', type='tensor(int64)', shape=['batch_size'])
|
||||||
|
==========Output==========
|
||||||
|
NodeArg(name='encoded', type='tensor(float)', shape=['batch_size', 768, 'Transposeencoded_dim_2'])
|
||||||
|
NodeArg(name='encoded_len', type='tensor(int32)', shape=['batch_size'])
|
||||||
|
|
||||||
|
==========Input==========
|
||||||
|
NodeArg(name='x', type='tensor(int32)', shape=[1, 1])
|
||||||
|
NodeArg(name='unused_x_len.1', type='tensor(int32)', shape=[1])
|
||||||
|
NodeArg(name='h.1', type='tensor(float)', shape=[1, 1, 320])
|
||||||
|
NodeArg(name='c.1', type='tensor(float)', shape=[1, 1, 320])
|
||||||
|
==========Output==========
|
||||||
|
NodeArg(name='dec', type='tensor(float)', shape=[1, 320, 1])
|
||||||
|
NodeArg(name='unused_x_len', type='tensor(int32)', shape=[1])
|
||||||
|
NodeArg(name='h', type='tensor(float)', shape=[1, 1, 320])
|
||||||
|
NodeArg(name='c', type='tensor(float)', shape=[1, 1, 320])
|
||||||
|
|
||||||
|
==========Input==========
|
||||||
|
NodeArg(name='enc', type='tensor(float)', shape=[1, 768, 1])
|
||||||
|
NodeArg(name='dec', type='tensor(float)', shape=[1, 320, 1])
|
||||||
|
==========Output==========
|
||||||
|
NodeArg(name='joint', type='tensor(float)', shape=[1, 1, 1, 34])
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def add_meta_data(filename: str, meta_data: dict[str, str]):
|
||||||
|
"""Add meta data to an ONNX model. It is changed in-place.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename:
|
||||||
|
Filename of the ONNX model to be changed.
|
||||||
|
meta_data:
|
||||||
|
Key-value pairs.
|
||||||
|
"""
|
||||||
|
model = onnx.load(filename)
|
||||||
|
while len(model.metadata_props):
|
||||||
|
model.metadata_props.pop()
|
||||||
|
|
||||||
|
for key, value in meta_data.items():
|
||||||
|
meta = model.metadata_props.add()
|
||||||
|
meta.key = key
|
||||||
|
meta.value = str(value)
|
||||||
|
|
||||||
|
onnx.save(model, filename)
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderWrapper(torch.nn.Module):
|
||||||
|
def __init__(self, m):
|
||||||
|
super().__init__()
|
||||||
|
self.m = m
|
||||||
|
|
||||||
|
def forward(self, audio_signal: Tensor, length: Tensor):
|
||||||
|
# https://github.com/salute-developers/GigaAM/blob/main/gigaam/encoder.py#L499
|
||||||
|
out, out_len = self.m.encoder(audio_signal, length)
|
||||||
|
|
||||||
|
return out, out_len.to(torch.int64)
|
||||||
|
|
||||||
|
def to_onnx(self, dir_path: str = "."):
|
||||||
|
onnx_converter(
|
||||||
|
model_name=f"{self.m.cfg.model_name}_encoder",
|
||||||
|
out_dir=dir_path,
|
||||||
|
module=self.m.encoder,
|
||||||
|
dynamic_axes=self.m.encoder.dynamic_axes(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DecoderWrapper(torch.nn.Module):
|
||||||
|
def __init__(self, m):
|
||||||
|
super().__init__()
|
||||||
|
self.m = m
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, unused_x_len: Tensor, h: Tensor, c: Tensor):
|
||||||
|
# https://github.com/salute-developers/GigaAM/blob/main/gigaam/decoder.py#L110C17-L110C54
|
||||||
|
emb = self.m.head.decoder.embed(x)
|
||||||
|
g, (h, c) = self.m.head.decoder.lstm(emb.transpose(0, 1), (h, c))
|
||||||
|
return g.permute(1, 2, 0), unused_x_len + 1, h, c
|
||||||
|
|
||||||
|
def to_onnx(self, dir_path: str = "."):
|
||||||
|
label, hidden_h, hidden_c = self.m.head.decoder.input_example()
|
||||||
|
label = label.to(torch.int32)
|
||||||
|
label_len = torch.zeros(1, dtype=torch.int32)
|
||||||
|
|
||||||
|
onnx_converter(
|
||||||
|
model_name=f"{self.m.cfg.model_name}_decoder",
|
||||||
|
out_dir=dir_path,
|
||||||
|
module=self,
|
||||||
|
dynamic_axes=self.m.encoder.dynamic_axes(),
|
||||||
|
inputs=(label, label_len, hidden_h, hidden_c),
|
||||||
|
input_names=["x", "unused_x_len.1", "h.1", "c.1"],
|
||||||
|
output_names=["dec", "unused_x_len", "h", "c"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
model_name = "v2_rnnt"
|
||||||
|
model = gigaam.load_model(
|
||||||
|
model_name, fp16_encoder=False, use_flash=False, download_root="."
|
||||||
|
)
|
||||||
|
|
||||||
|
# use characters
|
||||||
|
# space is 0
|
||||||
|
# <blk> is the last token
|
||||||
|
with open("./tokens.txt", "w", encoding="utf-8") as f:
|
||||||
|
for i, s in enumerate(model.cfg["labels"]):
|
||||||
|
f.write(f"{s} {i}\n")
|
||||||
|
f.write(f"<blk> {i+1}\n")
|
||||||
|
print("Saved to tokens.txt")
|
||||||
|
|
||||||
|
EncoderWrapper(model).to_onnx(".")
|
||||||
|
DecoderWrapper(model).to_onnx(".")
|
||||||
|
|
||||||
|
onnx_converter(
|
||||||
|
model_name=f"{model.cfg.model_name}_joint",
|
||||||
|
out_dir=".",
|
||||||
|
module=model.head.joint,
|
||||||
|
)
|
||||||
|
meta_data = {
|
||||||
|
# vocab_size does not include the blank
|
||||||
|
# we will increase vocab_size by 1 in the c++ code
|
||||||
|
"vocab_size": model.cfg["head"]["decoder"]["num_classes"] - 1,
|
||||||
|
"pred_rnn_layers": model.cfg["head"]["decoder"]["pred_rnn_layers"],
|
||||||
|
"pred_hidden": model.cfg["head"]["decoder"]["pred_hidden"],
|
||||||
|
"normalize_type": "",
|
||||||
|
"subsampling_factor": 4,
|
||||||
|
"model_type": "EncDecRNNTBPEModel",
|
||||||
|
"version": "2",
|
||||||
|
"model_author": "https://github.com/salute-developers/GigaAM",
|
||||||
|
"license": "https://github.com/salute-developers/GigaAM/blob/main/LICENSE",
|
||||||
|
"language": "Russian",
|
||||||
|
"is_giga_am": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
add_meta_data(f"./{model_name}_encoder.onnx", meta_data)
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=f"./{model_name}_encoder.onnx",
|
||||||
|
model_output="./encoder.int8.onnx",
|
||||||
|
weight_type=QuantType.QUInt8,
|
||||||
|
)
|
||||||
|
os.rename(f"./{model_name}_decoder.onnx", "decoder.onnx")
|
||||||
|
os.rename(f"./{model_name}_joint.onnx", "joiner.onnx")
|
||||||
|
os.remove(f"./{model_name}_encoder.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
5
scripts/nemo/GigaAM/export-onnx-rnnt.py
Normal file → Executable file
5
scripts/nemo/GigaAM/export-onnx-rnnt.py
Normal file → Executable file
@@ -83,6 +83,7 @@ def main():
|
|||||||
model.load_state_dict(ckpt, strict=False)
|
model.load_state_dict(ckpt, strict=False)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
# use bpe
|
||||||
with open("./tokens.txt", "w", encoding="utf-8") as f:
|
with open("./tokens.txt", "w", encoding="utf-8") as f:
|
||||||
for i, s in enumerate(model.joint.vocabulary):
|
for i, s in enumerate(model.joint.vocabulary):
|
||||||
f.write(f"{s} {i}\n")
|
f.write(f"{s} {i}\n")
|
||||||
@@ -94,7 +95,9 @@ def main():
|
|||||||
model.joint.export("joiner.onnx")
|
model.joint.export("joiner.onnx")
|
||||||
|
|
||||||
meta_data = {
|
meta_data = {
|
||||||
"vocab_size": model.decoder.vocab_size, # not including the blank
|
# not including the blank
|
||||||
|
# we increase vocab_size in the C++ code
|
||||||
|
"vocab_size": model.decoder.vocab_size,
|
||||||
"pred_rnn_layers": model.decoder.pred_rnn_layers,
|
"pred_rnn_layers": model.decoder.pred_rnn_layers,
|
||||||
"pred_hidden": model.decoder.pred_hidden,
|
"pred_hidden": model.decoder.pred_hidden,
|
||||||
"normalize_type": "",
|
"normalize_type": "",
|
||||||
|
|||||||
@@ -5,11 +5,14 @@ set -ex
|
|||||||
function install_gigaam() {
|
function install_gigaam() {
|
||||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||||
python3 get-pip.py
|
python3 get-pip.py
|
||||||
|
pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
pip install -qq wget text-unidecode "matplotlib>=3.3.2" onnx onnxruntime==1.17.1 pybind11 Cython einops kaldi-native-fbank soundfile librosa
|
||||||
|
|
||||||
BRANCH='main'
|
BRANCH='main'
|
||||||
python3 -m pip install git+https://github.com/salute-developers/GigaAM.git@$BRANCH#egg=gigaam
|
python3 -m pip install git+https://github.com/salute-developers/GigaAM.git@$BRANCH#egg=gigaam
|
||||||
|
|
||||||
python3 -m pip install -qq kaldi-native-fbank
|
python3 -m pip install -qq kaldi-native-fbank
|
||||||
|
pip install numpy==1.26.4
|
||||||
}
|
}
|
||||||
|
|
||||||
function download_files() {
|
function download_files() {
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ function install_nemo() {
|
|||||||
|
|
||||||
pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html
|
pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
|
||||||
pip install -qq wget text-unidecode matplotlib>=3.3.2 onnx onnxruntime pybind11 Cython einops kaldi-native-fbank soundfile librosa
|
pip install -qq wget text-unidecode "matplotlib>=3.3.2" onnx onnxruntime==1.17.1 pybind11 Cython einops kaldi-native-fbank soundfile librosa
|
||||||
pip install -qq ipython
|
pip install -qq ipython
|
||||||
|
|
||||||
# sudo apt-get install -q -y sox libsndfile1 ffmpeg python3-pip ipython
|
# sudo apt-get install -q -y sox libsndfile1 ffmpeg python3-pip ipython
|
||||||
|
|||||||
29
scripts/nemo/GigaAM/run-rnnt-v2.sh
Executable file
29
scripts/nemo/GigaAM/run-rnnt-v2.sh
Executable file
@@ -0,0 +1,29 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
function install_gigaam() {
|
||||||
|
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||||
|
python3 get-pip.py
|
||||||
|
pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
pip install -qq wget text-unidecode "matplotlib>=3.3.2" onnx onnxruntime==1.17.1 pybind11 Cython einops kaldi-native-fbank soundfile librosa
|
||||||
|
|
||||||
|
BRANCH='main'
|
||||||
|
python3 -m pip install git+https://github.com/salute-developers/GigaAM.git@$BRANCH#egg=gigaam
|
||||||
|
|
||||||
|
python3 -m pip install -qq kaldi-native-fbank
|
||||||
|
pip install numpy==1.26.4
|
||||||
|
}
|
||||||
|
|
||||||
|
function download_files() {
|
||||||
|
curl -SL -O https://huggingface.co/csukuangfj/tmp-files/resolve/main/GigaAM/example.wav
|
||||||
|
curl -SL -O https://github.com/salute-developers/GigaAM/blob/main/LICENSE
|
||||||
|
}
|
||||||
|
|
||||||
|
install_gigaam
|
||||||
|
download_files
|
||||||
|
|
||||||
|
python3 ./export-onnx-rnnt-v2.py
|
||||||
|
ls -lh
|
||||||
|
python3 ./test-onnx-rnnt.py
|
||||||
@@ -9,7 +9,7 @@ function install_nemo() {
|
|||||||
|
|
||||||
pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html
|
pip install torch==2.4.0 torchaudio==2.4.0 -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
|
||||||
pip install -qq wget text-unidecode matplotlib>=3.3.2 onnx onnxruntime pybind11 Cython einops kaldi-native-fbank soundfile librosa
|
pip install -qq wget text-unidecode "matplotlib>=3.3.2" onnx onnxruntime==1.17.1 pybind11 Cython einops kaldi-native-fbank soundfile librosa
|
||||||
pip install -qq ipython
|
pip install -qq ipython
|
||||||
|
|
||||||
# sudo apt-get install -q -y sox libsndfile1 ffmpeg python3-pip ipython
|
# sudo apt-get install -q -y sox libsndfile1 ffmpeg python3-pip ipython
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ def create_fbank():
|
|||||||
opts.frame_opts.window_type = "hann"
|
opts.frame_opts.window_type = "hann"
|
||||||
|
|
||||||
# Even though GigaAM uses 400 for fft, here we use 512
|
# Even though GigaAM uses 400 for fft, here we use 512
|
||||||
# since kaldi-native-fbank only support fft for power of 2.
|
# since kaldi-native-fbank only supports fft for power of 2.
|
||||||
opts.frame_opts.round_to_power_of_two = True
|
opts.frame_opts.round_to_power_of_two = True
|
||||||
|
|
||||||
opts.mel_opts.low_freq = 0
|
opts.mel_opts.low_freq = 0
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ def create_fbank():
|
|||||||
opts.frame_opts.window_type = "hann"
|
opts.frame_opts.window_type = "hann"
|
||||||
|
|
||||||
# Even though GigaAM uses 400 for fft, here we use 512
|
# Even though GigaAM uses 400 for fft, here we use 512
|
||||||
# since kaldi-native-fbank only support fft for power of 2.
|
# since kaldi-native-fbank only supports fft for power of 2.
|
||||||
opts.frame_opts.round_to_power_of_two = True
|
opts.frame_opts.round_to_power_of_two = True
|
||||||
|
|
||||||
opts.mel_opts.low_freq = 0
|
opts.mel_opts.low_freq = 0
|
||||||
@@ -166,12 +166,7 @@ class OnnxModel:
|
|||||||
target = torch.tensor([[token]], dtype=torch.int32).numpy()
|
target = torch.tensor([[token]], dtype=torch.int32).numpy()
|
||||||
target_len = torch.tensor([1], dtype=torch.int32).numpy()
|
target_len = torch.tensor([1], dtype=torch.int32).numpy()
|
||||||
|
|
||||||
(
|
(decoder_out, decoder_out_length, state0_next, state1_next,) = self.decoder.run(
|
||||||
decoder_out,
|
|
||||||
decoder_out_length,
|
|
||||||
state0_next,
|
|
||||||
state1_next,
|
|
||||||
) = self.decoder.run(
|
|
||||||
[
|
[
|
||||||
self.decoder.get_outputs()[0].name,
|
self.decoder.get_outputs()[0].name,
|
||||||
self.decoder.get_outputs()[1].name,
|
self.decoder.get_outputs()[1].name,
|
||||||
@@ -213,8 +208,12 @@ def main():
|
|||||||
id2token = dict()
|
id2token = dict()
|
||||||
with open("./tokens.txt", encoding="utf-8") as f:
|
with open("./tokens.txt", encoding="utf-8") as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
t, idx = line.split()
|
fields = line.split()
|
||||||
id2token[int(idx)] = t
|
if len(fields) == 1:
|
||||||
|
id2token[int(fields[0])] = " "
|
||||||
|
else:
|
||||||
|
t, idx = fields
|
||||||
|
id2token[int(idx)] = t
|
||||||
|
|
||||||
fbank = create_fbank()
|
fbank = create_fbank()
|
||||||
audio, sample_rate = sf.read("./example.wav", dtype="float32", always_2d=True)
|
audio, sample_rate = sf.read("./example.wav", dtype="float32", always_2d=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user