Export nvidia/canary-180m-flash to sherpa-onnx (#2272)
This commit is contained in:
132
.github/workflows/export-nemo-canary-180m-flash.yaml
vendored
Normal file
132
.github/workflows/export-nemo-canary-180m-flash.yaml
vendored
Normal file
@@ -0,0 +1,132 @@
|
||||
name: export-nemo-canary-180m-flash
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- export-nemo-canary
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: export-nemo-canary-180m-flash-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
export-nemo-canary-180m-flash:
|
||||
if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
|
||||
name: parakeet nemo canary 180m flash
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos-latest]
|
||||
python-version: ["3.10"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Run
|
||||
shell: bash
|
||||
run: |
|
||||
cd scripts/nemo/canary
|
||||
./run_180m_flash.sh
|
||||
|
||||
ls -lh *.onnx
|
||||
mv -v *.onnx ../../..
|
||||
mv -v tokens.txt ../../..
|
||||
mv de.wav ../../../
|
||||
mv en.wav ../../../
|
||||
|
||||
- name: Collect files (fp32)
|
||||
shell: bash
|
||||
run: |
|
||||
d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr
|
||||
mkdir -p $d
|
||||
cp encoder.onnx $d
|
||||
cp decoder.onnx $d
|
||||
cp tokens.txt $d
|
||||
|
||||
mkdir $d/test_wavs
|
||||
cp de.wav $d/test_wavs
|
||||
cp en.wav $d/test_wavs
|
||||
|
||||
tar cjfv $d.tar.bz2 $d
|
||||
|
||||
- name: Collect files (int8)
|
||||
shell: bash
|
||||
run: |
|
||||
d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
|
||||
mkdir -p $d
|
||||
cp encoder.int8.onnx $d
|
||||
cp decoder.fp16.onnx $d
|
||||
cp tokens.txt $d
|
||||
|
||||
mkdir $d/test_wavs
|
||||
cp de.wav $d/test_wavs
|
||||
cp en.wav $d/test_wavs
|
||||
|
||||
tar cjfv $d.tar.bz2 $d
|
||||
|
||||
- name: Collect files (fp16)
|
||||
shell: bash
|
||||
run: |
|
||||
d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16
|
||||
mkdir -p $d
|
||||
cp encoder.fp16.onnx $d
|
||||
cp decoder.fp16.onnx $d
|
||||
cp tokens.txt $d
|
||||
|
||||
mkdir $d/test_wavs
|
||||
cp de.wav $d/test_wavs
|
||||
cp en.wav $d/test_wavs
|
||||
|
||||
tar cjfv $d.tar.bz2 $d
|
||||
|
||||
- name: Publish to huggingface
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
max_attempts: 20
|
||||
timeout_seconds: 200
|
||||
shell: bash
|
||||
command: |
|
||||
git config --global user.email "csukuangfj@gmail.com"
|
||||
git config --global user.name "Fangjun Kuang"
|
||||
|
||||
models=(
|
||||
sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr
|
||||
sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
|
||||
sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16
|
||||
)
|
||||
|
||||
for m in ${models[@]}; do
|
||||
rm -rf huggingface
|
||||
export GIT_LFS_SKIP_SMUDGE=1
|
||||
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||
git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$m huggingface
|
||||
cp -av $m/* huggingface
|
||||
cd huggingface
|
||||
git lfs track "*.onnx"
|
||||
git lfs track "*.wav"
|
||||
git status
|
||||
git add .
|
||||
git status
|
||||
git commit -m "first commit"
|
||||
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$m main
|
||||
cd ..
|
||||
done
|
||||
|
||||
- name: Release
|
||||
uses: svenstaro/upload-release-action@v2
|
||||
with:
|
||||
file_glob: true
|
||||
file: ./*.tar.bz2
|
||||
overwrite: true
|
||||
repo_name: k2-fsa/sherpa-onnx
|
||||
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
|
||||
tag: asr-models
|
||||
289
scripts/nemo/canary/export_onnx_180m_flash.py
Executable file
289
scripts/nemo/canary/export_onnx_180m_flash.py
Executable file
@@ -0,0 +1,289 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
import os
|
||||
from typing import Tuple
|
||||
|
||||
import nemo
|
||||
import onnxmltools
|
||||
import torch
|
||||
from nemo.collections.common.parts import NEG_INF
|
||||
from onnxmltools.utils.float16_converter import convert_float_to_float16
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
|
||||
"""
|
||||
NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED :
|
||||
Could not find an implementation for Trilu(14) node with name '/Trilu'
|
||||
|
||||
See also https://github.com/microsoft/onnxruntime/issues/16189#issuecomment-1722219631
|
||||
|
||||
So we use fixed_form_attention_mask() to replace
|
||||
the original form_attention_mask()
|
||||
"""
|
||||
|
||||
|
||||
def fixed_form_attention_mask(input_mask, diagonal=None):
|
||||
"""
|
||||
Fixed: Build attention mask with optional masking of future tokens we forbid
|
||||
to attend to (e.g. as it is in Transformer decoder).
|
||||
|
||||
Args:
|
||||
input_mask: binary mask of size B x L with 1s corresponding to valid
|
||||
tokens and 0s corresponding to padding tokens
|
||||
diagonal: diagonal where triangular future mask starts
|
||||
None -- do not mask anything
|
||||
0 -- regular translation or language modeling future masking
|
||||
1 -- query stream masking as in XLNet architecture
|
||||
Returns:
|
||||
attention_mask: mask of size B x 1 x L x L with 0s corresponding to
|
||||
tokens we plan to attend to and -10000 otherwise
|
||||
"""
|
||||
|
||||
if input_mask is None:
|
||||
return None
|
||||
attn_shape = (1, input_mask.shape[1], input_mask.shape[1])
|
||||
attn_mask = input_mask.to(dtype=bool).unsqueeze(1)
|
||||
if diagonal is not None:
|
||||
future_mask = torch.tril(
|
||||
torch.ones(
|
||||
attn_shape,
|
||||
dtype=torch.int64, # it was torch.bool
|
||||
# but onnxruntime does not support torch.int32 or torch.bool
|
||||
# in torch.tril
|
||||
device=input_mask.device,
|
||||
),
|
||||
diagonal,
|
||||
).bool()
|
||||
attn_mask = attn_mask & future_mask
|
||||
attention_mask = (1 - attn_mask.to(torch.float)) * NEG_INF
|
||||
return attention_mask.unsqueeze(1)
|
||||
|
||||
|
||||
nemo.collections.common.parts.form_attention_mask = fixed_form_attention_mask
|
||||
|
||||
from nemo.collections.asr.models import EncDecMultiTaskModel
|
||||
|
||||
|
||||
def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
|
||||
onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
|
||||
onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
|
||||
onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)
|
||||
|
||||
|
||||
def lens_to_mask(lens, max_length):
|
||||
"""
|
||||
Create a mask from a tensor of lengths.
|
||||
"""
|
||||
batch_size = lens.shape[0]
|
||||
arange = torch.arange(max_length, device=lens.device)
|
||||
mask = arange.expand(batch_size, max_length) < lens.unsqueeze(1)
|
||||
return mask
|
||||
|
||||
|
||||
class EncoderWrapper(torch.nn.Module):
|
||||
def __init__(self, m):
|
||||
super().__init__()
|
||||
self.encoder = m.encoder
|
||||
self.encoder_decoder_proj = m.encoder_decoder_proj
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_len: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x: (N, T, C)
|
||||
x_len: (N,)
|
||||
Returns:
|
||||
- enc_states: (N, T, C)
|
||||
- encoded_len: (N,)
|
||||
- enc_mask: (N, T)
|
||||
"""
|
||||
x = x.permute(0, 2, 1)
|
||||
# x: (N, C, T)
|
||||
encoded, encoded_len = self.encoder(audio_signal=x, length=x_len)
|
||||
|
||||
enc_states = encoded.permute(0, 2, 1)
|
||||
|
||||
enc_states = self.encoder_decoder_proj(enc_states)
|
||||
|
||||
enc_mask = lens_to_mask(encoded_len, enc_states.shape[1])
|
||||
|
||||
return enc_states, encoded_len, enc_mask
|
||||
|
||||
|
||||
class DecoderWrapper(torch.nn.Module):
|
||||
def __init__(self, m):
|
||||
super().__init__()
|
||||
self.decoder = m.transf_decoder
|
||||
self.log_softmax = m.log_softmax
|
||||
|
||||
# We use only greedy search, so there is no need to compute log_softmax
|
||||
self.log_softmax.mlp.log_softmax = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
decoder_input_ids: torch.Tensor,
|
||||
decoder_mems_list_0: torch.Tensor,
|
||||
decoder_mems_list_1: torch.Tensor,
|
||||
decoder_mems_list_2: torch.Tensor,
|
||||
decoder_mems_list_3: torch.Tensor,
|
||||
decoder_mems_list_4: torch.Tensor,
|
||||
decoder_mems_list_5: torch.Tensor,
|
||||
enc_states: torch.Tensor,
|
||||
enc_mask: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
decoder_input_ids: (N, num_tokens), torch.int32
|
||||
decoder_mems_list_i: (N, num_tokens, 1024)
|
||||
enc_states: (N, T, 1024)
|
||||
enc_mask: (N, T)
|
||||
Returns:
|
||||
- logits: (N, 1, vocab_size)
|
||||
- decoder_mems_list_i: (N, num_tokens_2, 1024)
|
||||
"""
|
||||
pos = decoder_input_ids[0][-1].item()
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
|
||||
decoder_hidden_states = self.decoder.embedding.forward(
|
||||
decoder_input_ids, start_pos=pos
|
||||
)
|
||||
decoder_input_mask = torch.ones_like(decoder_input_ids).float()
|
||||
|
||||
decoder_mems_list = self.decoder.decoder.forward(
|
||||
decoder_hidden_states,
|
||||
decoder_input_mask,
|
||||
enc_states,
|
||||
enc_mask,
|
||||
[
|
||||
decoder_mems_list_0,
|
||||
decoder_mems_list_1,
|
||||
decoder_mems_list_2,
|
||||
decoder_mems_list_3,
|
||||
decoder_mems_list_4,
|
||||
decoder_mems_list_5,
|
||||
],
|
||||
return_mems=True,
|
||||
)
|
||||
logits = self.log_softmax(hidden_states=decoder_mems_list[-1][:, -1:])
|
||||
|
||||
return logits, decoder_mems_list
|
||||
|
||||
|
||||
def export_encoder(canary_model):
|
||||
encoder = EncoderWrapper(canary_model)
|
||||
x = torch.rand(1, 4000, 128)
|
||||
x_lens = torch.tensor([x.shape[1]], dtype=torch.int64)
|
||||
|
||||
encoder_filename = "encoder.onnx"
|
||||
torch.onnx.export(
|
||||
encoder,
|
||||
(x, x_lens),
|
||||
encoder_filename,
|
||||
input_names=["x", "x_len"],
|
||||
output_names=["enc_states", "enc_len", "enc_mask"],
|
||||
opset_version=14,
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "T"},
|
||||
"x_len": {0: "N"},
|
||||
"enc_states": {0: "N", 1: "T"},
|
||||
"enc_len": {0: "N"},
|
||||
"enc_mask": {0: "N", 1: "T"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def export_decoder(canary_model):
|
||||
decoder = DecoderWrapper(canary_model)
|
||||
decoder_input_ids = torch.tensor([[1, 0]], dtype=torch.int32)
|
||||
|
||||
decoder_mems_list_0 = torch.zeros(1, 1, 1024)
|
||||
decoder_mems_list_1 = torch.zeros(1, 1, 1024)
|
||||
decoder_mems_list_2 = torch.zeros(1, 1, 1024)
|
||||
decoder_mems_list_3 = torch.zeros(1, 1, 1024)
|
||||
decoder_mems_list_4 = torch.zeros(1, 1, 1024)
|
||||
decoder_mems_list_5 = torch.zeros(1, 1, 1024)
|
||||
|
||||
enc_states = torch.zeros(1, 1000, 1024)
|
||||
enc_mask = torch.ones(1, 1000).bool()
|
||||
|
||||
torch.onnx.export(
|
||||
decoder,
|
||||
(
|
||||
decoder_input_ids,
|
||||
decoder_mems_list_0,
|
||||
decoder_mems_list_1,
|
||||
decoder_mems_list_2,
|
||||
decoder_mems_list_3,
|
||||
decoder_mems_list_4,
|
||||
decoder_mems_list_5,
|
||||
enc_states,
|
||||
enc_mask,
|
||||
),
|
||||
"decoder.onnx",
|
||||
opset_version=14,
|
||||
input_names=[
|
||||
"decoder_input_ids",
|
||||
"decoder_mems_list_0",
|
||||
"decoder_mems_list_1",
|
||||
"decoder_mems_list_2",
|
||||
"decoder_mems_list_3",
|
||||
"decoder_mems_list_4",
|
||||
"decoder_mems_list_5",
|
||||
"enc_states",
|
||||
"enc_mask",
|
||||
],
|
||||
output_names=[
|
||||
"logits",
|
||||
"next_decoder_mem_list_0",
|
||||
"next_decoder_mem_list_1",
|
||||
"next_decoder_mem_list_2",
|
||||
"next_decoder_mem_list_3",
|
||||
"next_decoder_mem_list_4",
|
||||
"next_decoder_mem_list_5",
|
||||
],
|
||||
dynamic_axes={
|
||||
"decoder_input_ids": {1: "num_tokens"},
|
||||
"decoder_mems_list_0": {1: "num_tokens"},
|
||||
"decoder_mems_list_1": {1: "num_tokens"},
|
||||
"decoder_mems_list_2": {1: "num_tokens"},
|
||||
"decoder_mems_list_3": {1: "num_tokens"},
|
||||
"decoder_mems_list_4": {1: "num_tokens"},
|
||||
"decoder_mems_list_5": {1: "num_tokens"},
|
||||
"enc_states": {1: "T"},
|
||||
"enc_mask": {1: "T"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def export_tokens(canary_model):
|
||||
with open("./tokens.txt", "w", encoding="utf-8") as f:
|
||||
for i in range(canary_model.tokenizer.vocab_size):
|
||||
s = canary_model.tokenizer.ids_to_text([i])
|
||||
f.write(f"{s} {i}\n")
|
||||
print("Saved to tokens.txt")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
canary_model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-180m-flash")
|
||||
export_tokens(canary_model)
|
||||
export_encoder(canary_model)
|
||||
export_decoder(canary_model)
|
||||
|
||||
for m in ["encoder", "decoder"]:
|
||||
if m == "encoder":
|
||||
# we don't quantize the decoder with int8 since the accuracy drops
|
||||
quantize_dynamic(
|
||||
model_input=f"./{m}.onnx",
|
||||
model_output=f"./{m}.int8.onnx",
|
||||
weight_type=QuantType.QUInt8,
|
||||
)
|
||||
|
||||
export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx")
|
||||
|
||||
os.system("ls -lh *.onnx")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
131
scripts/nemo/canary/run_180m_flash.sh
Executable file
131
scripts/nemo/canary/run_180m_flash.sh
Executable file
@@ -0,0 +1,131 @@
|
||||
#!/usr/bin/env bash
|
||||
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
set -ex
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/de.wav
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/en.wav
|
||||
|
||||
pip install \
|
||||
nemo_toolkit['asr'] \
|
||||
"numpy<2" \
|
||||
ipython \
|
||||
kaldi-native-fbank \
|
||||
librosa \
|
||||
onnx==1.17.0 \
|
||||
onnxmltools \
|
||||
onnxruntime==1.17.1 \
|
||||
soundfile
|
||||
|
||||
python3 ./export_onnx_180m_flash.py
|
||||
ls -lh *.onnx
|
||||
|
||||
|
||||
log "-----fp32------"
|
||||
|
||||
python3 ./test_180m_flash.py \
|
||||
--encoder ./encoder.onnx \
|
||||
--decoder ./decoder.onnx \
|
||||
--source-lang en \
|
||||
--target-lang en \
|
||||
--tokens ./tokens.txt \
|
||||
--wav ./en.wav
|
||||
|
||||
python3 ./test_180m_flash.py \
|
||||
--encoder ./encoder.onnx \
|
||||
--decoder ./decoder.onnx \
|
||||
--source-lang en \
|
||||
--target-lang de \
|
||||
--tokens ./tokens.txt \
|
||||
--wav ./en.wav
|
||||
|
||||
python3 ./test_180m_flash.py \
|
||||
--encoder ./encoder.onnx \
|
||||
--decoder ./decoder.onnx \
|
||||
--source-lang de \
|
||||
--target-lang de \
|
||||
--tokens ./tokens.txt \
|
||||
--wav ./de.wav
|
||||
|
||||
python3 ./test_180m_flash.py \
|
||||
--encoder ./encoder.onnx \
|
||||
--decoder ./decoder.onnx \
|
||||
--source-lang de \
|
||||
--target-lang en \
|
||||
--tokens ./tokens.txt \
|
||||
--wav ./de.wav
|
||||
|
||||
|
||||
log "-----int8------"
|
||||
|
||||
python3 ./test_180m_flash.py \
|
||||
--encoder ./encoder.int8.onnx \
|
||||
--decoder ./decoder.fp16.onnx \
|
||||
--source-lang en \
|
||||
--target-lang en \
|
||||
--tokens ./tokens.txt \
|
||||
--wav ./en.wav
|
||||
|
||||
python3 ./test_180m_flash.py \
|
||||
--encoder ./encoder.int8.onnx \
|
||||
--decoder ./decoder.fp16.onnx \
|
||||
--source-lang en \
|
||||
--target-lang de \
|
||||
--tokens ./tokens.txt \
|
||||
--wav ./en.wav
|
||||
|
||||
python3 ./test_180m_flash.py \
|
||||
--encoder ./encoder.int8.onnx \
|
||||
--decoder ./decoder.fp16.onnx \
|
||||
--source-lang de \
|
||||
--target-lang de \
|
||||
--tokens ./tokens.txt \
|
||||
--wav ./de.wav
|
||||
|
||||
python3 ./test_180m_flash.py \
|
||||
--encoder ./encoder.int8.onnx \
|
||||
--decoder ./decoder.fp16.onnx \
|
||||
--source-lang de \
|
||||
--target-lang en \
|
||||
--tokens ./tokens.txt \
|
||||
--wav ./de.wav
|
||||
|
||||
log "-----fp16------"
|
||||
|
||||
python3 ./test_180m_flash.py \
|
||||
--encoder ./encoder.fp16.onnx \
|
||||
--decoder ./decoder.fp16.onnx \
|
||||
--source-lang en \
|
||||
--target-lang en \
|
||||
--tokens ./tokens.txt \
|
||||
--wav ./en.wav
|
||||
|
||||
python3 ./test_180m_flash.py \
|
||||
--encoder ./encoder.fp16.onnx \
|
||||
--decoder ./decoder.fp16.onnx \
|
||||
--source-lang en \
|
||||
--target-lang de \
|
||||
--tokens ./tokens.txt \
|
||||
--wav ./en.wav
|
||||
|
||||
python3 ./test_180m_flash.py \
|
||||
--encoder ./encoder.fp16.onnx \
|
||||
--decoder ./decoder.fp16.onnx \
|
||||
--source-lang de \
|
||||
--target-lang de \
|
||||
--tokens ./tokens.txt \
|
||||
--wav ./de.wav
|
||||
|
||||
python3 ./test_180m_flash.py \
|
||||
--encoder ./encoder.fp16.onnx \
|
||||
--decoder ./decoder.fp16.onnx \
|
||||
--source-lang de \
|
||||
--target-lang en \
|
||||
--tokens ./tokens.txt \
|
||||
--wav ./de.wav
|
||||
299
scripts/nemo/canary/test_180m_flash.py
Executable file
299
scripts/nemo/canary/test_180m_flash.py
Executable file
@@ -0,0 +1,299 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import kaldi_native_fbank as knf
|
||||
import librosa
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--encoder", type=str, required=True, help="Path to encoder.onnx"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder", type=str, required=True, help="Path to decoder.onnx"
|
||||
)
|
||||
|
||||
parser.add_argument("--tokens", type=str, required=True, help="Path to tokens.txt")
|
||||
|
||||
parser.add_argument(
|
||||
"--source-lang",
|
||||
type=str,
|
||||
help="Language of the input wav. Valid values are: en, de, es, fr",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target-lang",
|
||||
type=str,
|
||||
help="Language of the recognition result. Valid values are: en, de, es, fr",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-pnc",
|
||||
type=int,
|
||||
default=1,
|
||||
help="1 to enable cases and punctuations. 0 to disable that",
|
||||
)
|
||||
|
||||
parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def display(sess, model):
|
||||
print(f"=========={model} Input==========")
|
||||
for i in sess.get_inputs():
|
||||
print(i)
|
||||
print(f"=========={model }Output==========")
|
||||
for i in sess.get_outputs():
|
||||
print(i)
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(
|
||||
self,
|
||||
encoder: str,
|
||||
decoder: str,
|
||||
):
|
||||
self.init_encoder(encoder)
|
||||
display(self.encoder, "encoder")
|
||||
|
||||
self.init_decoder(decoder)
|
||||
display(self.decoder, "decoder")
|
||||
|
||||
def init_encoder(self, encoder):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
self.encoder = ort.InferenceSession(
|
||||
encoder,
|
||||
sess_options=session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
meta = self.encoder.get_modelmeta().custom_metadata_map
|
||||
# self.normalize_type = meta["normalize_type"]
|
||||
self.normalize_type = "per_feature"
|
||||
print(meta)
|
||||
|
||||
def init_decoder(self, decoder):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
self.decoder = ort.InferenceSession(
|
||||
decoder,
|
||||
sess_options=session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
def run_encoder(self, x: np.ndarray, x_lens: np.ndarray):
|
||||
"""
|
||||
Args:
|
||||
x: (N, T, C), np.float
|
||||
x_lens: (N,), np.int64
|
||||
Returns:
|
||||
enc_states: (N, T, C)
|
||||
enc_lens: (N,), np.int64
|
||||
enc_masks: (N, T), np.bool
|
||||
"""
|
||||
enc_states, enc_lens, enc_masks = self.encoder.run(
|
||||
[
|
||||
self.encoder.get_outputs()[0].name,
|
||||
self.encoder.get_outputs()[1].name,
|
||||
self.encoder.get_outputs()[2].name,
|
||||
],
|
||||
{
|
||||
self.encoder.get_inputs()[0].name: x,
|
||||
self.encoder.get_inputs()[1].name: x_lens,
|
||||
},
|
||||
)
|
||||
return enc_states, enc_lens, enc_masks
|
||||
|
||||
def run_decoder(
|
||||
self,
|
||||
decoder_input_ids: np.ndarray,
|
||||
decoder_mems_list: List[np.ndarray],
|
||||
enc_states: np.ndarray,
|
||||
enc_mask: np.ndarray,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
decoder_input_ids: (N, num_tokens), int32
|
||||
decoder_mems_list: a list of tensors, each of which is (N, num_tokens, C)
|
||||
enc_states: (N, T, C), float
|
||||
enc_mask: (N, T), bool
|
||||
Returns:
|
||||
logits: (1, 1, vocab_size), float
|
||||
new_decoder_mems_list:
|
||||
"""
|
||||
(logits, *new_decoder_mems_list) = self.decoder.run(
|
||||
[
|
||||
self.decoder.get_outputs()[0].name,
|
||||
self.decoder.get_outputs()[1].name,
|
||||
self.decoder.get_outputs()[2].name,
|
||||
self.decoder.get_outputs()[3].name,
|
||||
self.decoder.get_outputs()[4].name,
|
||||
self.decoder.get_outputs()[5].name,
|
||||
self.decoder.get_outputs()[6].name,
|
||||
],
|
||||
{
|
||||
self.decoder.get_inputs()[0].name: decoder_input_ids,
|
||||
self.decoder.get_inputs()[1].name: decoder_mems_list[0],
|
||||
self.decoder.get_inputs()[2].name: decoder_mems_list[1],
|
||||
self.decoder.get_inputs()[3].name: decoder_mems_list[2],
|
||||
self.decoder.get_inputs()[4].name: decoder_mems_list[3],
|
||||
self.decoder.get_inputs()[5].name: decoder_mems_list[4],
|
||||
self.decoder.get_inputs()[6].name: decoder_mems_list[5],
|
||||
self.decoder.get_inputs()[7].name: enc_states,
|
||||
self.decoder.get_inputs()[8].name: enc_mask,
|
||||
},
|
||||
)
|
||||
return logits, new_decoder_mems_list
|
||||
|
||||
|
||||
def create_fbank():
|
||||
opts = knf.FbankOptions()
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.remove_dc_offset = False
|
||||
opts.frame_opts.window_type = "hann"
|
||||
|
||||
opts.mel_opts.low_freq = 0
|
||||
opts.mel_opts.num_bins = 128
|
||||
|
||||
opts.mel_opts.is_librosa = True
|
||||
|
||||
fbank = knf.OnlineFbank(opts)
|
||||
return fbank
|
||||
|
||||
|
||||
def compute_features(audio, fbank):
|
||||
assert len(audio.shape) == 1, audio.shape
|
||||
fbank.accept_waveform(16000, audio)
|
||||
ans = []
|
||||
processed = 0
|
||||
while processed < fbank.num_frames_ready:
|
||||
ans.append(np.array(fbank.get_frame(processed)))
|
||||
processed += 1
|
||||
ans = np.stack(ans)
|
||||
return ans
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
assert Path(args.encoder).is_file(), args.encoder
|
||||
assert Path(args.decoder).is_file(), args.decoder
|
||||
assert Path(args.tokens).is_file(), args.tokens
|
||||
assert Path(args.wav).is_file(), args.wav
|
||||
|
||||
print(vars(args))
|
||||
|
||||
id2token = dict()
|
||||
token2id = dict()
|
||||
with open(args.tokens, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
fields = line.split()
|
||||
if len(fields) == 2:
|
||||
t, idx = fields[0], int(fields[1])
|
||||
if line[0] == " ":
|
||||
t = " " + t
|
||||
else:
|
||||
t = " "
|
||||
idx = int(fields[0])
|
||||
|
||||
id2token[idx] = t
|
||||
token2id[t] = idx
|
||||
|
||||
model = OnnxModel(args.encoder, args.decoder)
|
||||
|
||||
fbank = create_fbank()
|
||||
|
||||
start = time.time()
|
||||
audio, sample_rate = sf.read(args.wav, dtype="float32", always_2d=True)
|
||||
audio = audio[:, 0] # only use the first channel
|
||||
if sample_rate != 16000:
|
||||
audio = librosa.resample(
|
||||
audio,
|
||||
orig_sr=sample_rate,
|
||||
target_sr=16000,
|
||||
)
|
||||
sample_rate = 16000
|
||||
|
||||
features = compute_features(audio, fbank)
|
||||
if model.normalize_type != "":
|
||||
assert model.normalize_type == "per_feature", model.normalize_type
|
||||
mean = features.mean(axis=1, keepdims=True)
|
||||
stddev = features.std(axis=1, keepdims=True) + 1e-5
|
||||
features = (features - mean) / stddev
|
||||
|
||||
features = np.expand_dims(features, axis=0)
|
||||
# features.shape: (1, 291, 128)
|
||||
|
||||
features_len = np.array([features.shape[1]], dtype=np.int64)
|
||||
|
||||
enc_states, _, enc_masks = model.run_encoder(features, features_len)
|
||||
|
||||
decoder_input_ids = []
|
||||
decoder_input_ids.append(token2id["<|startofcontext|>"])
|
||||
decoder_input_ids.append(token2id["<|startoftranscript|>"])
|
||||
decoder_input_ids.append(token2id["<|emo:undefined|>"])
|
||||
if args.source_lang in ("en", "es", "de", "fr"):
|
||||
decoder_input_ids.append(token2id[f"<|{args.source_lang}|>"])
|
||||
else:
|
||||
decoder_input_ids.append(token2id[f"<|en|>"])
|
||||
|
||||
if args.target_lang in ("en", "es", "de", "fr"):
|
||||
decoder_input_ids.append(token2id[f"<|{args.target_lang}|>"])
|
||||
else:
|
||||
decoder_input_ids.append(token2id[f"<|en|>"])
|
||||
|
||||
if args.use_pnc:
|
||||
decoder_input_ids.append(token2id[f"<|pnc|>"])
|
||||
else:
|
||||
decoder_input_ids.append(token2id[f"<|nopnc|>"])
|
||||
|
||||
decoder_input_ids.append(token2id[f"<|noitn|>"])
|
||||
decoder_input_ids.append(token2id["<|notimestamp|>"])
|
||||
decoder_input_ids.append(token2id["<|nodiarize|>"])
|
||||
|
||||
decoder_input_ids.append(0)
|
||||
|
||||
decoder_mems_list = [np.zeros((1, 0, 1024), dtype=np.float32) for _ in range(6)]
|
||||
|
||||
logits, decoder_mems_list = model.run_decoder(
|
||||
np.array([decoder_input_ids], dtype=np.int32),
|
||||
decoder_mems_list,
|
||||
enc_states,
|
||||
enc_masks,
|
||||
)
|
||||
tokens = [logits.argmax()]
|
||||
print("decoder_input_ids", decoder_input_ids)
|
||||
eos = token2id["<|endoftext|>"]
|
||||
|
||||
for i in range(1, 200):
|
||||
decoder_input_ids = [tokens[-1], i]
|
||||
logits, decoder_mems_list = model.run_decoder(
|
||||
np.array([decoder_input_ids], dtype=np.int32),
|
||||
decoder_mems_list,
|
||||
enc_states,
|
||||
enc_masks,
|
||||
)
|
||||
t = logits.argmax()
|
||||
if t == eos:
|
||||
break
|
||||
tokens.append(t)
|
||||
print("len(tokens)", len(tokens))
|
||||
print("tokens", tokens)
|
||||
text = "".join([id2token[i] for i in tokens])
|
||||
print("text:", text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user