Add meta data to NeMo canary ONNX models (#2351)
This commit is contained in:
@@ -62,22 +62,7 @@ jobs:
|
||||
d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
|
||||
mkdir -p $d
|
||||
cp encoder.int8.onnx $d
|
||||
cp decoder.fp16.onnx $d
|
||||
cp tokens.txt $d
|
||||
|
||||
mkdir $d/test_wavs
|
||||
cp de.wav $d/test_wavs
|
||||
cp en.wav $d/test_wavs
|
||||
|
||||
tar cjfv $d.tar.bz2 $d
|
||||
|
||||
- name: Collect files (fp16)
|
||||
shell: bash
|
||||
run: |
|
||||
d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16
|
||||
mkdir -p $d
|
||||
cp encoder.fp16.onnx $d
|
||||
cp decoder.fp16.onnx $d
|
||||
cp decoder.int8.onnx $d
|
||||
cp tokens.txt $d
|
||||
|
||||
mkdir $d/test_wavs
|
||||
@@ -101,7 +86,6 @@ jobs:
|
||||
models=(
|
||||
sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr
|
||||
sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
|
||||
sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16
|
||||
)
|
||||
|
||||
for m in ${models[@]}; do
|
||||
|
||||
@@ -1,14 +1,21 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
<|en|>
|
||||
<|pnc|>
|
||||
<|noitn|>
|
||||
<|nodiarize|>
|
||||
<|notimestamp|>
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Tuple
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import nemo
|
||||
import onnxmltools
|
||||
import onnx
|
||||
import torch
|
||||
from nemo.collections.common.parts import NEG_INF
|
||||
from onnxmltools.utils.float16_converter import convert_float_to_float16
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
|
||||
"""
|
||||
@@ -64,10 +71,25 @@ nemo.collections.common.parts.form_attention_mask = fixed_form_attention_mask
|
||||
from nemo.collections.asr.models import EncDecMultiTaskModel
|
||||
|
||||
|
||||
def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
|
||||
onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
|
||||
onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
|
||||
onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||
"""Add meta data to an ONNX model. It is changed in-place.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename of the ONNX model to be changed.
|
||||
meta_data:
|
||||
Key-value pairs.
|
||||
"""
|
||||
model = onnx.load(filename)
|
||||
while len(model.metadata_props):
|
||||
model.metadata_props.pop()
|
||||
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = str(value)
|
||||
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
def lens_to_mask(lens, max_length):
|
||||
@@ -222,7 +244,7 @@ def export_decoder(canary_model):
|
||||
),
|
||||
"decoder.onnx",
|
||||
dynamo=True,
|
||||
opset_version=18,
|
||||
opset_version=14,
|
||||
external_data=False,
|
||||
input_names=[
|
||||
"decoder_input_ids",
|
||||
@@ -269,6 +291,29 @@ def export_tokens(canary_model):
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
canary_model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-180m-flash")
|
||||
canary_model.eval()
|
||||
|
||||
preprocessor = canary_model.cfg["preprocessor"]
|
||||
sample_rate = preprocessor["sample_rate"]
|
||||
normalize_type = preprocessor["normalize"]
|
||||
window_size = preprocessor["window_size"] # ms
|
||||
window_stride = preprocessor["window_stride"] # ms
|
||||
window = preprocessor["window"]
|
||||
features = preprocessor["features"]
|
||||
n_fft = preprocessor["n_fft"]
|
||||
vocab_size = canary_model.tokenizer.vocab_size # 5248
|
||||
|
||||
subsampling_factor = canary_model.cfg["encoder"]["subsampling_factor"]
|
||||
|
||||
assert sample_rate == 16000, sample_rate
|
||||
assert normalize_type == "per_feature", normalize_type
|
||||
assert window_size == 0.025, window_size
|
||||
assert window_stride == 0.01, window_stride
|
||||
assert window == "hann", window
|
||||
assert features == 128, features
|
||||
assert n_fft == 512, n_fft
|
||||
assert subsampling_factor == 8, subsampling_factor
|
||||
|
||||
export_tokens(canary_model)
|
||||
export_encoder(canary_model)
|
||||
export_decoder(canary_model)
|
||||
@@ -280,7 +325,32 @@ def main():
|
||||
weight_type=QuantType.QUInt8,
|
||||
)
|
||||
|
||||
export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx")
|
||||
meta_data = {
|
||||
"vocab_size": vocab_size,
|
||||
"normalize_type": normalize_type,
|
||||
"subsampling_factor": subsampling_factor,
|
||||
"model_type": "EncDecMultiTaskModel",
|
||||
"version": "1",
|
||||
"model_author": "NeMo",
|
||||
"url": "https://huggingface.co/nvidia/canary-180m-flash",
|
||||
"feat_dim": features,
|
||||
}
|
||||
|
||||
add_meta_data("encoder.onnx", meta_data)
|
||||
add_meta_data("encoder.int8.onnx", meta_data)
|
||||
|
||||
"""
|
||||
To fix the following error with onnxruntime 1.17.1 and 1.16.3:
|
||||
|
||||
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 :FAIL : Load model from ./decoder.int8.onnx failed:/Users/runner/work/1/s/onnxruntime/core/graph/model.cc:150 onnxruntime::Model::Model(onnx::ModelProto &&, const onnxruntime::PathString &, const onnxruntime::IOnnxRuntimeOpSchemaRegistryList *, const logging::Logger &, const onnxruntime::ModelOptions &)
|
||||
Unsupported model IR version: 10, max supported IR version: 9
|
||||
"""
|
||||
for filename in ["./decoder.onnx", "./decoder.int8.onnx"]:
|
||||
model = onnx.load(filename)
|
||||
print("old", model.ir_version)
|
||||
model.ir_version = 9
|
||||
print("new", model.ir_version)
|
||||
onnx.save(model, filename)
|
||||
|
||||
os.system("ls -lh *.onnx")
|
||||
|
||||
|
||||
@@ -19,8 +19,8 @@ pip install \
|
||||
kaldi-native-fbank \
|
||||
librosa \
|
||||
onnx==1.17.0 \
|
||||
onnxmltools \
|
||||
onnxruntime==1.17.1 \
|
||||
onnxscript \
|
||||
soundfile
|
||||
|
||||
python3 ./export_onnx_180m_flash.py
|
||||
@@ -66,7 +66,7 @@ log "-----int8------"
|
||||
|
||||
python3 ./test_180m_flash.py \
|
||||
--encoder ./encoder.int8.onnx \
|
||||
--decoder ./decoder.fp16.onnx \
|
||||
--decoder ./decoder.int8.onnx \
|
||||
--source-lang en \
|
||||
--target-lang en \
|
||||
--tokens ./tokens.txt \
|
||||
@@ -74,7 +74,7 @@ python3 ./test_180m_flash.py \
|
||||
|
||||
python3 ./test_180m_flash.py \
|
||||
--encoder ./encoder.int8.onnx \
|
||||
--decoder ./decoder.fp16.onnx \
|
||||
--decoder ./decoder.int8.onnx \
|
||||
--source-lang en \
|
||||
--target-lang de \
|
||||
--tokens ./tokens.txt \
|
||||
@@ -82,7 +82,7 @@ python3 ./test_180m_flash.py \
|
||||
|
||||
python3 ./test_180m_flash.py \
|
||||
--encoder ./encoder.int8.onnx \
|
||||
--decoder ./decoder.fp16.onnx \
|
||||
--decoder ./decoder.int8.onnx \
|
||||
--source-lang de \
|
||||
--target-lang de \
|
||||
--tokens ./tokens.txt \
|
||||
@@ -90,41 +90,7 @@ python3 ./test_180m_flash.py \
|
||||
|
||||
python3 ./test_180m_flash.py \
|
||||
--encoder ./encoder.int8.onnx \
|
||||
--decoder ./decoder.fp16.onnx \
|
||||
--source-lang de \
|
||||
--target-lang en \
|
||||
--tokens ./tokens.txt \
|
||||
--wav ./de.wav
|
||||
|
||||
log "-----fp16------"
|
||||
|
||||
python3 ./test_180m_flash.py \
|
||||
--encoder ./encoder.fp16.onnx \
|
||||
--decoder ./decoder.fp16.onnx \
|
||||
--source-lang en \
|
||||
--target-lang en \
|
||||
--tokens ./tokens.txt \
|
||||
--wav ./en.wav
|
||||
|
||||
python3 ./test_180m_flash.py \
|
||||
--encoder ./encoder.fp16.onnx \
|
||||
--decoder ./decoder.fp16.onnx \
|
||||
--source-lang en \
|
||||
--target-lang de \
|
||||
--tokens ./tokens.txt \
|
||||
--wav ./en.wav
|
||||
|
||||
python3 ./test_180m_flash.py \
|
||||
--encoder ./encoder.fp16.onnx \
|
||||
--decoder ./decoder.fp16.onnx \
|
||||
--source-lang de \
|
||||
--target-lang de \
|
||||
--tokens ./tokens.txt \
|
||||
--wav ./de.wav
|
||||
|
||||
python3 ./test_180m_flash.py \
|
||||
--encoder ./encoder.fp16.onnx \
|
||||
--decoder ./decoder.fp16.onnx \
|
||||
--decoder ./decoder.int8.onnx \
|
||||
--source-lang de \
|
||||
--target-lang en \
|
||||
--tokens ./tokens.txt \
|
||||
|
||||
@@ -79,8 +79,7 @@ class OnnxModel:
|
||||
)
|
||||
|
||||
meta = self.encoder.get_modelmeta().custom_metadata_map
|
||||
# self.normalize_type = meta["normalize_type"]
|
||||
self.normalize_type = "per_feature"
|
||||
self.normalize_type = meta["normalize_type"]
|
||||
print(meta)
|
||||
|
||||
def init_decoder(self, decoder):
|
||||
@@ -267,7 +266,7 @@ def main():
|
||||
|
||||
for pos, decoder_input_id in enumerate(decoder_input_ids):
|
||||
logits, decoder_mems_list = model.run_decoder(
|
||||
np.array([[decoder_input_id,pos]], dtype=np.int32),
|
||||
np.array([[decoder_input_id, pos]], dtype=np.int32),
|
||||
decoder_mems_list,
|
||||
enc_states,
|
||||
enc_masks,
|
||||
|
||||
Reference in New Issue
Block a user