Add metadata to NeMo canary ONNX models (#2351)
This commit is contained in:
@@ -1,14 +1,21 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
<|en|>
|
||||
<|pnc|>
|
||||
<|noitn|>
|
||||
<|nodiarize|>
|
||||
<|notimestamp|>
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Tuple
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import nemo
|
||||
import onnxmltools
|
||||
import onnx
|
||||
import torch
|
||||
from nemo.collections.common.parts import NEG_INF
|
||||
from onnxmltools.utils.float16_converter import convert_float_to_float16
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
|
||||
"""
|
||||
@@ -64,10 +71,25 @@ nemo.collections.common.parts.form_attention_mask = fixed_form_attention_mask
|
||||
from nemo.collections.asr.models import EncDecMultiTaskModel
|
||||
|
||||
|
||||
def export_onnx_fp16(onnx_fp32_path: str, onnx_fp16_path: str) -> None:
    """Convert an fp32 ONNX model to fp16 and save it to a new file.

    Args:
      onnx_fp32_path:
        Path to the input fp32 ONNX model.
      onnx_fp16_path:
        Path where the converted fp16 ONNX model is written.
    """
    onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
    # keep_io_types=True keeps the model inputs/outputs as fp32 so callers
    # do not need to convert their tensors to fp16.
    onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
    onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Replace the metadata of an ONNX model. The file is changed in-place.

    Any metadata entries already present in the model are discarded before
    the new key-value pairs are written.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)

    # Drop all existing metadata entries before adding the new ones.
    del model.metadata_props[:]

    for key, value in meta_data.items():
        entry = model.metadata_props.add()
        entry.key = key
        entry.value = str(value)

    onnx.save(model, filename)
|
||||
|
||||
|
||||
def lens_to_mask(lens, max_length):
|
||||
@@ -222,7 +244,7 @@ def export_decoder(canary_model):
|
||||
),
|
||||
"decoder.onnx",
|
||||
dynamo=True,
|
||||
opset_version=18,
|
||||
opset_version=14,
|
||||
external_data=False,
|
||||
input_names=[
|
||||
"decoder_input_ids",
|
||||
@@ -269,6 +291,29 @@ def export_tokens(canary_model):
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
canary_model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-180m-flash")
|
||||
canary_model.eval()
|
||||
|
||||
preprocessor = canary_model.cfg["preprocessor"]
|
||||
sample_rate = preprocessor["sample_rate"]
|
||||
normalize_type = preprocessor["normalize"]
|
||||
window_size = preprocessor["window_size"] # ms
|
||||
window_stride = preprocessor["window_stride"] # ms
|
||||
window = preprocessor["window"]
|
||||
features = preprocessor["features"]
|
||||
n_fft = preprocessor["n_fft"]
|
||||
vocab_size = canary_model.tokenizer.vocab_size # 5248
|
||||
|
||||
subsampling_factor = canary_model.cfg["encoder"]["subsampling_factor"]
|
||||
|
||||
assert sample_rate == 16000, sample_rate
|
||||
assert normalize_type == "per_feature", normalize_type
|
||||
assert window_size == 0.025, window_size
|
||||
assert window_stride == 0.01, window_stride
|
||||
assert window == "hann", window
|
||||
assert features == 128, features
|
||||
assert n_fft == 512, n_fft
|
||||
assert subsampling_factor == 8, subsampling_factor
|
||||
|
||||
export_tokens(canary_model)
|
||||
export_encoder(canary_model)
|
||||
export_decoder(canary_model)
|
||||
@@ -280,7 +325,32 @@ def main():
|
||||
weight_type=QuantType.QUInt8,
|
||||
)
|
||||
|
||||
export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx")
|
||||
meta_data = {
|
||||
"vocab_size": vocab_size,
|
||||
"normalize_type": normalize_type,
|
||||
"subsampling_factor": subsampling_factor,
|
||||
"model_type": "EncDecMultiTaskModel",
|
||||
"version": "1",
|
||||
"model_author": "NeMo",
|
||||
"url": "https://huggingface.co/nvidia/canary-180m-flash",
|
||||
"feat_dim": features,
|
||||
}
|
||||
|
||||
add_meta_data("encoder.onnx", meta_data)
|
||||
add_meta_data("encoder.int8.onnx", meta_data)
|
||||
|
||||
"""
|
||||
To fix the following error with onnxruntime 1.17.1 and 1.16.3:
|
||||
|
||||
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 :FAIL : Load model from ./decoder.int8.onnx failed:/Users/runner/work/1/s/onnxruntime/core/graph/model.cc:150 onnxruntime::Model::Model(onnx::ModelProto &&, const onnxruntime::PathString &, const onnxruntime::IOnnxRuntimeOpSchemaRegistryList *, const logging::Logger &, const onnxruntime::ModelOptions &)
|
||||
Unsupported model IR version: 10, max supported IR version: 9
|
||||
"""
|
||||
for filename in ["./decoder.onnx", "./decoder.int8.onnx"]:
|
||||
model = onnx.load(filename)
|
||||
print("old", model.ir_version)
|
||||
model.ir_version = 9
|
||||
print("new", model.ir_version)
|
||||
onnx.save(model, filename)
|
||||
|
||||
os.system("ls -lh *.onnx")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user