Export speaker verification models from NeMo to ONNX (#526)
This commit is contained in:
104
scripts/nemo/speaker-verification/export-onnx.py
Executable file
104
scripts/nemo/speaker-verification/export-onnx.py
Executable file
@@ -0,0 +1,104 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
import argparse
|
||||
from typing import Dict
|
||||
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import onnx
|
||||
import torch
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=[
|
||||
"speakerverification_speakernet",
|
||||
"titanet_large",
|
||||
"titanet_small",
|
||||
"ecapa_tdnn",
|
||||
],
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||
"""Add meta data to an ONNX model. It is changed in-place.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename of the ONNX model to be changed.
|
||||
meta_data:
|
||||
Key-value pairs.
|
||||
"""
|
||||
model = onnx.load(filename)
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = str(value)
|
||||
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_args()
|
||||
speaker_model_config = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
|
||||
model_name=args.model, return_config=True
|
||||
)
|
||||
preprocessor_config = speaker_model_config["preprocessor"]
|
||||
|
||||
print(args.model)
|
||||
print(speaker_model_config)
|
||||
print(preprocessor_config)
|
||||
|
||||
assert preprocessor_config["n_fft"] == 512, preprocessor_config
|
||||
|
||||
assert (
|
||||
preprocessor_config["_target_"]
|
||||
== "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor"
|
||||
), preprocessor_config
|
||||
|
||||
assert preprocessor_config["frame_splicing"] == 1, preprocessor_config
|
||||
|
||||
speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
|
||||
model_name=args.model
|
||||
)
|
||||
speaker_model.eval()
|
||||
filename = f"nemo_en_{args.model}.onnx"
|
||||
speaker_model.export(filename)
|
||||
|
||||
print(f"Adding metadata to {filename}")
|
||||
|
||||
comment = "This model is from NeMo."
|
||||
url = {
|
||||
"titanet_large": "https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/titanet_large",
|
||||
"titanet_small": "https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/titanet_small",
|
||||
"speakerverification_speakernet": "https://ngc.nvidia.com/catalog/models/nvidia:nemo:speakerverification_speakernet",
|
||||
"ecapa_tdnn": "https://ngc.nvidia.com/catalog/models/nvidia:nemo:ecapa_tdnn",
|
||||
}[args.model]
|
||||
|
||||
language = "English"
|
||||
|
||||
meta_data = {
|
||||
"framework": "nemo",
|
||||
"language": language,
|
||||
"url": url,
|
||||
"comment": comment,
|
||||
"sample_rate": preprocessor_config["sample_rate"],
|
||||
"output_dim": speaker_model_config["decoder"]["emb_sizes"],
|
||||
"feature_normalize_type": preprocessor_config["normalize"],
|
||||
"window_size_ms": int(float(preprocessor_config["window_size"]) * 1000),
|
||||
"window_stride_ms": int(float(preprocessor_config["window_stride"]) * 1000),
|
||||
"window_type": preprocessor_config["window"], # e.g., hann
|
||||
"feat_dim": preprocessor_config["features"],
|
||||
}
|
||||
print(meta_data)
|
||||
add_meta_data(filename=filename, meta_data=meta_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user