133 lines
3.5 KiB
Python
Executable File
133 lines
3.5 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
|
|
# pip install git+https://github.com/wenet-e2e/wenet.git
|
|
# pip install onnxruntime onnx pyyaml
|
|
# cp -a ~/open-source/wenet/wenet/transducer/search .
|
|
# cp -a ~/open-source/wenet/wenet/e_branchformer .
|
|
# cp -a ~/open-source/wenet/wenet/ctl_model .
|
|
|
|
import os
|
|
from typing import Dict
|
|
|
|
import onnx
|
|
import torch
|
|
import yaml
|
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
|
|
|
from wenet.utils.init_model import init_model
|
|
|
|
|
|
class Foo:
    """Empty attribute container used as a stand-in for parsed arguments.

    ``main`` creates an instance, sets ``checkpoint`` on it, and passes it to
    ``init_model``; presumably this replaces an ``argparse.Namespace``.
    """
|
|
|
|
|
|
def add_meta_data(filename: str, meta_data: Dict[str, object]) -> None:
    """Add meta data to an ONNX model. It is changed in-place.

    Every value is converted with ``str()`` before it is stored, so
    non-string values such as integers are accepted. If the model already
    has a metadata entry with the same key, that entry's value is
    overwritten. No duplicate key is appended, because ONNX expects
    metadata keys to be distinct.

    Args:
      filename:
        Filename of the ONNX model to be changed. The model is loaded from
        this path, modified, and saved back to the same path.
      meta_data:
        Key-value pairs to store in the model's ``metadata_props``.
    """
    model = onnx.load(filename)

    # Index existing entries so that re-running this on the same file
    # updates values instead of accumulating duplicate keys.
    props_by_key = {prop.key: prop for prop in model.metadata_props}
    for key, value in meta_data.items():
        prop = props_by_key.get(key)
        if prop is None:
            prop = model.metadata_props.add()
            prop.key = key
            props_by_key[key] = prop
        prop.value = str(value)

    onnx.save(model, filename)
|
|
|
|
|
|
class OnnxModel(torch.nn.Module):
    """Bundle an encoder and a CTC head into a single exportable module.

    The forward pass runs ``encoder`` and then ``ctc.log_softmax`` on its
    output. It returns the log-probabilities together with per-utterance
    output lengths derived from the encoder's mask.
    """

    def __init__(self, encoder: torch.nn.Module, ctc: torch.nn.Module):
        super().__init__()
        self.encoder = encoder
        self.ctc = ctc

    def forward(self, x, x_lens):
        """Encode a padded batch and apply the CTC log-softmax.

        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,) with the number of valid frames of
            each utterance in x before padding. Its dtype is torch.int64.

        Returns:
          A tuple (log_probs, log_probs_lens):
            - log_probs: ``self.ctc.log_softmax`` applied to the encoder output.
            - log_probs_lens: for each utterance, the count of set positions
              in the mask returned by the encoder.
        """
        # Both chunk settings are -1; presumably this selects full-context
        # (non-streaming) encoding. Confirm against the WeNet encoder.
        hidden, hidden_mask = self.encoder(
            x, x_lens, decoding_chunk_size=-1, num_decoding_left_chunks=-1
        )
        # NOTE(review): squeeze(1) assumes the mask is shaped (N, 1, T');
        # confirm against the encoder implementation.
        valid_frames = hidden_mask.int().squeeze(1).sum(1)
        return self.ctc.log_softmax(hidden), valid_frames
|
|
|
|
|
|
@torch.no_grad()
def main():
    """Export a WeNet CTC model to ONNX and write an int8-quantized copy.

    Inputs, read from the current directory:
      - ``./final.pt``: checkpoint path handed to ``init_model``.
      - ``./train.yaml``: config handed to ``init_model``.

    Outputs:
      - ``model.onnx``: encoder + CTC log-softmax, with metadata attached.
      - ``model.int8.onnx``: dynamically quantized copy (MatMul ops, int8
        weights).

    If the ``WENET_URL`` environment variable is set, it is stored as the
    ``url`` metadata entry; otherwise an empty string is stored.
    """
    # init_model() reads ``args.checkpoint``; an attribute bag stands in for
    # parsed command-line arguments.
    args = Foo()
    args.checkpoint = "./final.pt"
    config_file = "./train.yaml"

    # NOTE(review): yaml.FullLoader is acceptable for a trusted, local
    # train.yaml; prefer yaml.safe_load if the config may be untrusted.
    with open(config_file, "r", encoding="utf-8") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    torch_model, configs = init_model(args, configs)
    torch_model.eval()

    onnx_model = OnnxModel(encoder=torch_model.encoder, ctc=torch_model.ctc)
    filename = "model.onnx"

    # Example input used for export. Batch (N) and time (T) are declared
    # dynamic below, but the feature dimension C is fixed in the graph.
    # NOTE(review): C = 80 must match the model's input feature dimension;
    # confirm against train.yaml.
    N = 1
    T = 1000
    C = 80
    x = torch.rand(N, T, C, dtype=torch.float)
    x_lens = torch.full((N,), fill_value=T, dtype=torch.int64)

    opset_version = 13
    onnx_model = torch.jit.script(onnx_model)
    torch.onnx.export(
        onnx_model,
        (x, x_lens),
        filename,
        opset_version=opset_version,
        input_names=["x", "x_lens"],
        output_names=["log_probs", "log_probs_lens"],
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "x_lens": {0: "N"},
            # The encoder subsamples in time (see "subsampling_factor"
            # below), so the output time axis differs in length from the
            # input's. It gets its own symbol, because reusing "T" would
            # declare the two dimensions equal.
            "log_probs": {0: "N", 1: "T_out"},
            "log_probs_lens": {0: "N"},
        },
    )

    # https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
    url = os.environ.get("WENET_URL", "")
    meta_data = {
        "model_type": "wenet_ctc",
        "version": "1",
        "model_author": "wenet",
        "comment": "non-streaming",
        "subsampling_factor": torch_model.encoder.embed.subsampling_rate,
        "vocab_size": torch_model.ctc.ctc_lo.weight.shape[0],
        "url": url,
    }
    add_meta_data(filename=filename, meta_data=meta_data)

    print("Generate int8 quantization models")

    filename_int8 = "model.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )


if __name__ == "__main__":
    main()
|