#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
#
# Export a non-streaming WeNet CTC model (encoder + CTC head) to ONNX
# (model.onnx) and produce a dynamically int8-quantized copy
# (model.int8.onnx). Reads ./final.pt and ./train.yaml from the current
# directory.
#
# Prerequisites:
#
# pip install git+https://github.com/wenet-e2e/wenet.git
# pip install onnxruntime onnx pyyaml
# cp -a ~/open-source/wenet/wenet/transducer/search .
# cp -a ~/open-source/wenet/wenet/e_branchformer .
# cp -a ~/open-source/wenet/wenet/ctl_model .

import os
from typing import Any, Dict, Tuple

import onnx
import torch
import yaml
from onnxruntime.quantization import QuantType, quantize_dynamic
from wenet.utils.init_model import init_model


class Foo:
    """Bare attribute container used in place of an argparse-style ``args``.

    ``main`` sets ``checkpoint`` on an instance and passes it to
    ``init_model``.
    """


def add_meta_data(filename: str, meta_data: Dict[str, Any]) -> None:
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs. Each value is converted with ``str()`` before it is
        stored, so non-string values (e.g. ints) are accepted.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


class OnnxModel(torch.nn.Module):
    """Wrap a WeNet encoder and its CTC head into one exportable module."""

    def __init__(self, encoder: torch.nn.Module, ctc: torch.nn.Module):
        super().__init__()
        self.encoder = encoder
        self.ctc = ctc

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C)
          x_lens:
            A 1-D tensor of shape (N,) containing valid lengths in x before
            padding. Its type is torch.int64
        Returns:
          A tuple (log_probs, log_probs_lens): the output of
          ``self.ctc.log_softmax`` applied to the encoder output, and the
          per-utterance count of valid encoder output frames (int32).
        """
        # -1 / -1 presumably selects full-context (non-streaming) encoding,
        # matching the "non-streaming" comment written to the metadata.
        encoder_out, encoder_out_mask = self.encoder(
            x,
            x_lens,
            decoding_chunk_size=-1,
            num_decoding_left_chunks=-1,
        )
        log_probs = self.ctc.log_softmax(encoder_out)
        # NOTE(review): assumes encoder_out_mask is (N, 1, T') -- squeezing
        # dim 1 and summing over time yields the valid length per utterance.
        log_probs_lens = encoder_out_mask.int().squeeze(1).sum(1)

        return log_probs, log_probs_lens


@torch.no_grad()
def main():
    """Export ./final.pt (configured by ./train.yaml) to model.onnx, attach
    metadata, then write a dynamically int8-quantized model.int8.onnx."""
    args = Foo()
    args.checkpoint = "./final.pt"

    config_file = "./train.yaml"
    # train.yaml is expected to be a trusted, locally produced config;
    # do not point this at untrusted YAML.
    with open(config_file, "r", encoding="utf-8") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    torch_model, configs = init_model(args, configs)
    torch_model.eval()

    onnx_model = OnnxModel(encoder=torch_model.encoder, ctc=torch_model.ctc)
    filename = "model.onnx"

    # Example input used for tracing the export; N and T are dynamic axes
    # below, so these sizes do not constrain the exported model.
    N = 1
    T = 1000
    C = 80
    x = torch.rand(N, T, C, dtype=torch.float)
    x_lens = torch.full((N,), fill_value=T, dtype=torch.int64)

    opset_version = 13
    onnx_model = torch.jit.script(onnx_model)
    torch.onnx.export(
        onnx_model,
        (x, x_lens),
        filename,
        opset_version=opset_version,
        input_names=["x", "x_lens"],
        output_names=["log_probs", "log_probs_lens"],
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "x_lens": {0: "N"},
            # The encoder subsamples in time, so the output time axis gets
            # its own symbolic name rather than reusing the input's "T".
            "log_probs": {0: "N", 1: "T_out"},
            "log_probs_lens": {0: "N"},
        },
    )

    # Optional download URL recorded in the metadata, taken from $WENET_URL
    # (empty if unset), e.g.
    # https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
    url = os.environ.get("WENET_URL", "")
    meta_data = {
        "model_type": "wenet_ctc",
        "version": "1",
        "model_author": "wenet",
        "comment": "non-streaming",
        "subsampling_factor": torch_model.encoder.embed.subsampling_rate,
        # First dim of the CTC projection weight (assumes ctc_lo is a Linear
        # layer, whose weight is (out_features, in_features)).
        "vocab_size": torch_model.ctc.ctc_lo.weight.shape[0],
        "url": url,
    }
    add_meta_data(filename=filename, meta_data=meta_data)

    print("Generate int8 quantization models")

    filename_int8 = "model.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )


if __name__ == "__main__":
    main()