This repository has been archived on 2025-08-26. You can view files and clone it, but cannot push or open issues or pull requests.
Files
enginex-mr_series-sherpa-onnx/scripts/wespeaker/add_meta_data.py
2024-01-10 21:09:45 +08:00

144 lines
3.2 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This script adds meta data to a model so that it can be used in sherpa-onnx.
Usage:
./add_meta_data.py --model ./voxceleb_resnet34.onnx --language English
"""
import argparse
from pathlib import Path
from typing import Dict
import onnx
import onnxruntime
def get_args():
    """Parse the command-line options of this script.

    Returns:
      The namespace produced by ``argparse`` with the attributes ``model``,
      ``language``, ``url``, ``comment`` and ``sample_rate``.
    """
    # (flag, keyword arguments) pairs; they are registered in this order.
    option_specs = (
        (
            "--model",
            dict(
                type=str,
                required=True,
                help="Path to the input onnx model. Example value: model.onnx",
            ),
        ),
        (
            "--language",
            dict(
                type=str,
                required=True,
                help=(
                    "Supported language of the input model.\n"
                    "Example value: Chinese, English.\n"
                ),
            ),
        ),
        (
            "--url",
            dict(
                type=str,
                default="https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md",
                help="Where the model is downloaded",
            ),
        ),
        (
            "--comment",
            dict(
                type=str,
                default="no comment",
                help="Comment about the model",
            ),
        ),
        (
            "--sample-rate",
            dict(
                type=int,
                default=16000,
                help="Sample rate expected by the model",
            ),
        ),
    )

    parser = argparse.ArgumentParser()
    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)

    return parser.parse_args()
def add_meta_data(filename: str, meta_data: Dict[str, object]):
    """Add meta data to an ONNX model. It is changed in-place.

    Keys that are already present in the model's metadata are overwritten
    instead of being appended a second time, so running this function again
    on the same file does not produce duplicate entries.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs. Every value is converted with ``str()`` before it
        is stored.
    """
    model = onnx.load(filename)
    pending = {key: str(value) for key, value in meta_data.items()}

    # Refresh the entries that already exist in the model.
    seen = set()
    for prop in model.metadata_props:
        if prop.key in pending:
            prop.value = pending[prop.key]
            seen.add(prop.key)

    # Append only the keys the model did not have yet.
    for key, value in pending.items():
        if key not in seen:
            meta = model.metadata_props.add()
            meta.key = key
            meta.value = value

    # Write back to the same path: the input file is modified in-place.
    onnx.save(model, filename)
def get_output_dim(filename) -> int:
    """Inspect an ONNX model and return the second dimension of its output.

    The model's inputs and outputs are printed, then the layout is checked:
    exactly one input whose shape starts with ``["B", "T"]`` and whose third
    dimension is 80, and exactly one output whose first dimension is ``"B"``.

    Args:
      filename:
        Path of the ONNX model. It is converted with ``str()`` before it is
        handed to onnxruntime.
    Returns:
      ``shape[1]`` of the model's only output.
    """
    options = onnxruntime.SessionOptions()
    options.log_severity_level = 3  # error level
    session = onnxruntime.InferenceSession(str(filename), options)

    inputs = session.get_inputs()
    outputs = session.get_outputs()

    # Dump every input, then every output, each list followed by a separator.
    for nodes in (inputs, outputs):
        for node in nodes:
            print(node)
        print("----------")

    assert len(inputs) == 1
    assert len(outputs) == 1

    model_input = inputs[0]
    model_output = outputs[0]

    assert model_input.shape[:2] == ["B", "T"], model_input.shape
    assert model_output.shape[0] == "B"
    # presumably 80 is the feature dimension the model expects - TODO confirm
    assert model_input.shape[2] == 80, model_input.shape

    return model_output.shape[1]
def main():
    """Validate the command-line options and write the meta data to the model.

    The model file given by ``--model`` is modified in-place.

    Raises:
      ValueError:
        If the model file does not exist, or if ``--language`` or ``--url``
        is an empty string.
    """
    args = get_args()
    model = Path(args.model)
    language = args.language
    url = args.url

    if not model.is_file():
        raise ValueError(f"{model} does not exist")

    # Explicit checks instead of ``assert``: assertions are stripped when
    # Python runs with -O, which would silently skip the validation.
    if not language:
        raise ValueError("--language must not be empty")
    if not url:
        raise ValueError("--url must not be empty")

    output_dim = get_output_dim(model)

    # all models from wespeaker expect input samples in the range
    # [-32768, 32767]
    normalize_samples = 0

    meta_data = {
        "framework": "wespeaker",
        "language": language,
        "url": url,
        "comment": args.comment,
        "sample_rate": args.sample_rate,
        "output_dim": output_dim,
        "normalize_samples": normalize_samples,
    }
    print(meta_data)

    add_meta_data(filename=str(model), meta_data=meta_data)
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()