Add meta data to NeMo canary ONNX models (#2351)
This commit is contained in:
@@ -62,22 +62,7 @@ jobs:
|
|||||||
d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
|
d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
|
||||||
mkdir -p $d
|
mkdir -p $d
|
||||||
cp encoder.int8.onnx $d
|
cp encoder.int8.onnx $d
|
||||||
cp decoder.fp16.onnx $d
|
cp decoder.int8.onnx $d
|
||||||
cp tokens.txt $d
|
|
||||||
|
|
||||||
mkdir $d/test_wavs
|
|
||||||
cp de.wav $d/test_wavs
|
|
||||||
cp en.wav $d/test_wavs
|
|
||||||
|
|
||||||
tar cjfv $d.tar.bz2 $d
|
|
||||||
|
|
||||||
- name: Collect files (fp16)
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16
|
|
||||||
mkdir -p $d
|
|
||||||
cp encoder.fp16.onnx $d
|
|
||||||
cp decoder.fp16.onnx $d
|
|
||||||
cp tokens.txt $d
|
cp tokens.txt $d
|
||||||
|
|
||||||
mkdir $d/test_wavs
|
mkdir $d/test_wavs
|
||||||
@@ -101,7 +86,6 @@ jobs:
|
|||||||
models=(
|
models=(
|
||||||
sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr
|
sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr
|
||||||
sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
|
sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
|
||||||
sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for m in ${models[@]}; do
|
for m in ${models[@]}; do
|
||||||
|
|||||||
@@ -1,14 +1,21 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
|
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
"""
|
||||||
|
<|en|>
|
||||||
|
<|pnc|>
|
||||||
|
<|noitn|>
|
||||||
|
<|nodiarize|>
|
||||||
|
<|notimestamp|>
|
||||||
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Tuple
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
import nemo
|
import nemo
|
||||||
import onnxmltools
|
import onnx
|
||||||
import torch
|
import torch
|
||||||
from nemo.collections.common.parts import NEG_INF
|
from nemo.collections.common.parts import NEG_INF
|
||||||
from onnxmltools.utils.float16_converter import convert_float_to_float16
|
|
||||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -64,10 +71,25 @@ nemo.collections.common.parts.form_attention_mask = fixed_form_attention_mask
|
|||||||
from nemo.collections.asr.models import EncDecMultiTaskModel
|
from nemo.collections.asr.models import EncDecMultiTaskModel
|
||||||
|
|
||||||
|
|
||||||
def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
|
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||||
onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
|
"""Add meta data to an ONNX model. It is changed in-place.
|
||||||
onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
|
|
||||||
onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)
|
Args:
|
||||||
|
filename:
|
||||||
|
Filename of the ONNX model to be changed.
|
||||||
|
meta_data:
|
||||||
|
Key-value pairs.
|
||||||
|
"""
|
||||||
|
model = onnx.load(filename)
|
||||||
|
while len(model.metadata_props):
|
||||||
|
model.metadata_props.pop()
|
||||||
|
|
||||||
|
for key, value in meta_data.items():
|
||||||
|
meta = model.metadata_props.add()
|
||||||
|
meta.key = key
|
||||||
|
meta.value = str(value)
|
||||||
|
|
||||||
|
onnx.save(model, filename)
|
||||||
|
|
||||||
|
|
||||||
def lens_to_mask(lens, max_length):
|
def lens_to_mask(lens, max_length):
|
||||||
@@ -222,7 +244,7 @@ def export_decoder(canary_model):
|
|||||||
),
|
),
|
||||||
"decoder.onnx",
|
"decoder.onnx",
|
||||||
dynamo=True,
|
dynamo=True,
|
||||||
opset_version=18,
|
opset_version=14,
|
||||||
external_data=False,
|
external_data=False,
|
||||||
input_names=[
|
input_names=[
|
||||||
"decoder_input_ids",
|
"decoder_input_ids",
|
||||||
@@ -269,6 +291,29 @@ def export_tokens(canary_model):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
canary_model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-180m-flash")
|
canary_model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-180m-flash")
|
||||||
|
canary_model.eval()
|
||||||
|
|
||||||
|
preprocessor = canary_model.cfg["preprocessor"]
|
||||||
|
sample_rate = preprocessor["sample_rate"]
|
||||||
|
normalize_type = preprocessor["normalize"]
|
||||||
|
window_size = preprocessor["window_size"] # ms
|
||||||
|
window_stride = preprocessor["window_stride"] # ms
|
||||||
|
window = preprocessor["window"]
|
||||||
|
features = preprocessor["features"]
|
||||||
|
n_fft = preprocessor["n_fft"]
|
||||||
|
vocab_size = canary_model.tokenizer.vocab_size # 5248
|
||||||
|
|
||||||
|
subsampling_factor = canary_model.cfg["encoder"]["subsampling_factor"]
|
||||||
|
|
||||||
|
assert sample_rate == 16000, sample_rate
|
||||||
|
assert normalize_type == "per_feature", normalize_type
|
||||||
|
assert window_size == 0.025, window_size
|
||||||
|
assert window_stride == 0.01, window_stride
|
||||||
|
assert window == "hann", window
|
||||||
|
assert features == 128, features
|
||||||
|
assert n_fft == 512, n_fft
|
||||||
|
assert subsampling_factor == 8, subsampling_factor
|
||||||
|
|
||||||
export_tokens(canary_model)
|
export_tokens(canary_model)
|
||||||
export_encoder(canary_model)
|
export_encoder(canary_model)
|
||||||
export_decoder(canary_model)
|
export_decoder(canary_model)
|
||||||
@@ -280,7 +325,32 @@ def main():
|
|||||||
weight_type=QuantType.QUInt8,
|
weight_type=QuantType.QUInt8,
|
||||||
)
|
)
|
||||||
|
|
||||||
export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx")
|
meta_data = {
|
||||||
|
"vocab_size": vocab_size,
|
||||||
|
"normalize_type": normalize_type,
|
||||||
|
"subsampling_factor": subsampling_factor,
|
||||||
|
"model_type": "EncDecMultiTaskModel",
|
||||||
|
"version": "1",
|
||||||
|
"model_author": "NeMo",
|
||||||
|
"url": "https://huggingface.co/nvidia/canary-180m-flash",
|
||||||
|
"feat_dim": features,
|
||||||
|
}
|
||||||
|
|
||||||
|
add_meta_data("encoder.onnx", meta_data)
|
||||||
|
add_meta_data("encoder.int8.onnx", meta_data)
|
||||||
|
|
||||||
|
"""
|
||||||
|
To fix the following error with onnxruntime 1.17.1 and 1.16.3:
|
||||||
|
|
||||||
|
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 :FAIL : Load model from ./decoder.int8.onnx failed:/Users/runner/work/1/s/onnxruntime/core/graph/model.cc:150 onnxruntime::Model::Model(onnx::ModelProto &&, const onnxruntime::PathString &, const onnxruntime::IOnnxRuntimeOpSchemaRegistryList *, const logging::Logger &, const onnxruntime::ModelOptions &)
|
||||||
|
Unsupported model IR version: 10, max supported IR version: 9
|
||||||
|
"""
|
||||||
|
for filename in ["./decoder.onnx", "./decoder.int8.onnx"]:
|
||||||
|
model = onnx.load(filename)
|
||||||
|
print("old", model.ir_version)
|
||||||
|
model.ir_version = 9
|
||||||
|
print("new", model.ir_version)
|
||||||
|
onnx.save(model, filename)
|
||||||
|
|
||||||
os.system("ls -lh *.onnx")
|
os.system("ls -lh *.onnx")
|
||||||
|
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ pip install \
|
|||||||
kaldi-native-fbank \
|
kaldi-native-fbank \
|
||||||
librosa \
|
librosa \
|
||||||
onnx==1.17.0 \
|
onnx==1.17.0 \
|
||||||
onnxmltools \
|
|
||||||
onnxruntime==1.17.1 \
|
onnxruntime==1.17.1 \
|
||||||
|
onnxscript \
|
||||||
soundfile
|
soundfile
|
||||||
|
|
||||||
python3 ./export_onnx_180m_flash.py
|
python3 ./export_onnx_180m_flash.py
|
||||||
@@ -66,7 +66,7 @@ log "-----int8------"
|
|||||||
|
|
||||||
python3 ./test_180m_flash.py \
|
python3 ./test_180m_flash.py \
|
||||||
--encoder ./encoder.int8.onnx \
|
--encoder ./encoder.int8.onnx \
|
||||||
--decoder ./decoder.fp16.onnx \
|
--decoder ./decoder.int8.onnx \
|
||||||
--source-lang en \
|
--source-lang en \
|
||||||
--target-lang en \
|
--target-lang en \
|
||||||
--tokens ./tokens.txt \
|
--tokens ./tokens.txt \
|
||||||
@@ -74,7 +74,7 @@ python3 ./test_180m_flash.py \
|
|||||||
|
|
||||||
python3 ./test_180m_flash.py \
|
python3 ./test_180m_flash.py \
|
||||||
--encoder ./encoder.int8.onnx \
|
--encoder ./encoder.int8.onnx \
|
||||||
--decoder ./decoder.fp16.onnx \
|
--decoder ./decoder.int8.onnx \
|
||||||
--source-lang en \
|
--source-lang en \
|
||||||
--target-lang de \
|
--target-lang de \
|
||||||
--tokens ./tokens.txt \
|
--tokens ./tokens.txt \
|
||||||
@@ -82,7 +82,7 @@ python3 ./test_180m_flash.py \
|
|||||||
|
|
||||||
python3 ./test_180m_flash.py \
|
python3 ./test_180m_flash.py \
|
||||||
--encoder ./encoder.int8.onnx \
|
--encoder ./encoder.int8.onnx \
|
||||||
--decoder ./decoder.fp16.onnx \
|
--decoder ./decoder.int8.onnx \
|
||||||
--source-lang de \
|
--source-lang de \
|
||||||
--target-lang de \
|
--target-lang de \
|
||||||
--tokens ./tokens.txt \
|
--tokens ./tokens.txt \
|
||||||
@@ -90,41 +90,7 @@ python3 ./test_180m_flash.py \
|
|||||||
|
|
||||||
python3 ./test_180m_flash.py \
|
python3 ./test_180m_flash.py \
|
||||||
--encoder ./encoder.int8.onnx \
|
--encoder ./encoder.int8.onnx \
|
||||||
--decoder ./decoder.fp16.onnx \
|
--decoder ./decoder.int8.onnx \
|
||||||
--source-lang de \
|
|
||||||
--target-lang en \
|
|
||||||
--tokens ./tokens.txt \
|
|
||||||
--wav ./de.wav
|
|
||||||
|
|
||||||
log "-----fp16------"
|
|
||||||
|
|
||||||
python3 ./test_180m_flash.py \
|
|
||||||
--encoder ./encoder.fp16.onnx \
|
|
||||||
--decoder ./decoder.fp16.onnx \
|
|
||||||
--source-lang en \
|
|
||||||
--target-lang en \
|
|
||||||
--tokens ./tokens.txt \
|
|
||||||
--wav ./en.wav
|
|
||||||
|
|
||||||
python3 ./test_180m_flash.py \
|
|
||||||
--encoder ./encoder.fp16.onnx \
|
|
||||||
--decoder ./decoder.fp16.onnx \
|
|
||||||
--source-lang en \
|
|
||||||
--target-lang de \
|
|
||||||
--tokens ./tokens.txt \
|
|
||||||
--wav ./en.wav
|
|
||||||
|
|
||||||
python3 ./test_180m_flash.py \
|
|
||||||
--encoder ./encoder.fp16.onnx \
|
|
||||||
--decoder ./decoder.fp16.onnx \
|
|
||||||
--source-lang de \
|
|
||||||
--target-lang de \
|
|
||||||
--tokens ./tokens.txt \
|
|
||||||
--wav ./de.wav
|
|
||||||
|
|
||||||
python3 ./test_180m_flash.py \
|
|
||||||
--encoder ./encoder.fp16.onnx \
|
|
||||||
--decoder ./decoder.fp16.onnx \
|
|
||||||
--source-lang de \
|
--source-lang de \
|
||||||
--target-lang en \
|
--target-lang en \
|
||||||
--tokens ./tokens.txt \
|
--tokens ./tokens.txt \
|
||||||
|
|||||||
@@ -79,8 +79,7 @@ class OnnxModel:
|
|||||||
)
|
)
|
||||||
|
|
||||||
meta = self.encoder.get_modelmeta().custom_metadata_map
|
meta = self.encoder.get_modelmeta().custom_metadata_map
|
||||||
# self.normalize_type = meta["normalize_type"]
|
self.normalize_type = meta["normalize_type"]
|
||||||
self.normalize_type = "per_feature"
|
|
||||||
print(meta)
|
print(meta)
|
||||||
|
|
||||||
def init_decoder(self, decoder):
|
def init_decoder(self, decoder):
|
||||||
@@ -267,7 +266,7 @@ def main():
|
|||||||
|
|
||||||
for pos, decoder_input_id in enumerate(decoder_input_ids):
|
for pos, decoder_input_id in enumerate(decoder_input_ids):
|
||||||
logits, decoder_mems_list = model.run_decoder(
|
logits, decoder_mems_list = model.run_decoder(
|
||||||
np.array([[decoder_input_id,pos]], dtype=np.int32),
|
np.array([[decoder_input_id, pos]], dtype=np.int32),
|
||||||
decoder_mems_list,
|
decoder_mems_list,
|
||||||
enc_states,
|
enc_states,
|
||||||
enc_masks,
|
enc_masks,
|
||||||
|
|||||||
Reference in New Issue
Block a user