Add scripts to export ASR models from wenet to ONNX (#425)
See https://user-images.githubusercontent.com/5284924/282995968-f6d39118-8008-4ce7-9d7c-d1d6387ac183.png
This commit is contained in:
203
scripts/wenet/export-onnx-streaming.py
Executable file
203
scripts/wenet/export-onnx-streaming.py
Executable file
@@ -0,0 +1,203 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
# pip install git+https://github.com/wenet-e2e/wenet.git
|
||||
# pip install onnxruntime onnx pyyaml
|
||||
# cp -a ~/open-source/wenet/wenet/transducer/search .
|
||||
# cp -a ~/open-source//wenet/wenet/e_branchformer .
|
||||
# cp -a ~/open-source/wenet/wenet/ctl_model .
|
||||
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
import yaml
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
|
||||
from wenet.utils.init_model import init_model
|
||||
|
||||
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||
"""Add meta data to an ONNX model. It is changed in-place.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename of the ONNX model to be changed.
|
||||
meta_data:
|
||||
Key-value pairs.
|
||||
"""
|
||||
model = onnx.load(filename)
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = str(value)
|
||||
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
class OnnxModel(torch.nn.Module):
|
||||
def __init__(self, encoder: torch.nn.Module, ctc: torch.nn.Module):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.ctc = ctc
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
required_cache_size: torch.Tensor,
|
||||
attn_cache: torch.Tensor,
|
||||
conv_cache: torch.Tensor,
|
||||
attn_mask: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D float32 tensor of shape (N, T, C). It supports only N == 1.
|
||||
offset:
|
||||
A scalar of dtype torch.int64.
|
||||
required_cache_size:
|
||||
A scalar of dtype torch.int64.
|
||||
attn_cache:
|
||||
A 4-D float32 tensor of shape (num_blocks, head, required_cache_size, encoder_output_size / head /2).
|
||||
conv_cache:
|
||||
A 4-D float32 tensor of shape (num_blocks, N, encoder_output_size, cnn_module_kernel - 1).
|
||||
attn_mask:
|
||||
A 3-D bool tensor of shape (N, 1, required_cache_size + chunk_size)
|
||||
Returns:
|
||||
Return a tuple of 3 tensors:
|
||||
- A 3-D float32 tensor of shape (N, T, C) containing log_probs
|
||||
- next_attn_cache
|
||||
- next_conv_cache
|
||||
"""
|
||||
encoder_out, next_att_cache, next_conv_cache = self.encoder.forward_chunk(
|
||||
xs=x,
|
||||
offset=offset,
|
||||
required_cache_size=required_cache_size,
|
||||
att_cache=attn_cache,
|
||||
cnn_cache=conv_cache,
|
||||
att_mask=attn_mask,
|
||||
)
|
||||
log_probs = self.ctc.log_softmax(encoder_out)
|
||||
|
||||
return log_probs, next_att_cache, next_conv_cache
|
||||
|
||||
|
||||
class Foo:
|
||||
pass
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = Foo()
|
||||
args.checkpoint = "./final.pt"
|
||||
config_file = "./train.yaml"
|
||||
|
||||
with open(config_file, "r") as fin:
|
||||
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
||||
torch_model, configs = init_model(args, configs)
|
||||
torch_model.eval()
|
||||
|
||||
head = configs["encoder_conf"]["attention_heads"]
|
||||
num_blocks = configs["encoder_conf"]["num_blocks"]
|
||||
output_size = configs["encoder_conf"]["output_size"]
|
||||
cnn_module_kernel = configs["encoder_conf"].get("cnn_module_kernel", 1)
|
||||
|
||||
right_context = torch_model.right_context()
|
||||
subsampling_factor = torch_model.encoder.embed.subsampling_rate
|
||||
chunk_size = 16
|
||||
left_chunks = 4
|
||||
|
||||
decoding_window = (chunk_size - 1) * subsampling_factor + right_context + 1
|
||||
|
||||
required_cache_size = chunk_size * left_chunks
|
||||
|
||||
offset = required_cache_size
|
||||
|
||||
attn_cache = torch.zeros(
|
||||
num_blocks,
|
||||
head,
|
||||
required_cache_size,
|
||||
output_size // head * 2,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
attn_mask = torch.ones(1, 1, required_cache_size + chunk_size, dtype=torch.bool)
|
||||
attn_mask[:, :, :required_cache_size] = 0
|
||||
|
||||
conv_cache = torch.zeros(
|
||||
num_blocks, 1, output_size, cnn_module_kernel - 1, dtype=torch.float32
|
||||
)
|
||||
|
||||
sos = torch_model.sos_symbol()
|
||||
eos = torch_model.eos_symbol()
|
||||
|
||||
onnx_model = OnnxModel(
|
||||
encoder=torch_model.encoder,
|
||||
ctc=torch_model.ctc,
|
||||
)
|
||||
filename = "model-streaming.onnx"
|
||||
|
||||
N = 1
|
||||
T = decoding_window
|
||||
C = 80
|
||||
x = torch.rand(N, T, C, dtype=torch.float32)
|
||||
offset = torch.tensor([offset], dtype=torch.int64)
|
||||
required_cache_size = torch.tensor([required_cache_size], dtype=torch.int64)
|
||||
|
||||
opset_version = 13
|
||||
torch.onnx.export(
|
||||
onnx_model,
|
||||
(x, offset, required_cache_size, attn_cache, conv_cache, attn_mask),
|
||||
filename,
|
||||
opset_version=opset_version,
|
||||
input_names=[
|
||||
"x",
|
||||
"offset",
|
||||
"required_cache_size",
|
||||
"attn_cache",
|
||||
"conv_cache",
|
||||
"attn_mask",
|
||||
],
|
||||
output_names=["log_probs", "next_att_cache", "next_conv_cache"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "T"},
|
||||
"attn_cache": {2: "T"},
|
||||
"log_probs": {0: "N"},
|
||||
"new_attn_cache": {2: "T"},
|
||||
},
|
||||
)
|
||||
|
||||
# https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
|
||||
url = os.environ.get("WENET_URL", "")
|
||||
meta_data = {
|
||||
"model_type": "wenet-ctc",
|
||||
"version": "1",
|
||||
"model_author": "wenet",
|
||||
"comment": "streaming",
|
||||
"url": "https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz",
|
||||
"chunk_size": chunk_size,
|
||||
"left_chunks": left_chunks,
|
||||
"head": head,
|
||||
"num_blocks": num_blocks,
|
||||
"output_size": output_size,
|
||||
"cnn_module_kernel": cnn_module_kernel,
|
||||
"right_context": right_context,
|
||||
"subsampling_factor": subsampling_factor,
|
||||
}
|
||||
add_meta_data(filename=filename, meta_data=meta_data)
|
||||
|
||||
print("Generate int8 quantization models")
|
||||
|
||||
filename_int8 = f"model-streaming.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=filename,
|
||||
model_output=filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user