155 lines
3.0 KiB
Python
Executable File
155 lines
3.0 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang)
|
|
|
|
import argparse
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
from rknn.api import RKNN
|
|
|
|
logging.basicConfig(level=logging.WARNING)

# Rockchip SoCs listed in the --target-platform help text.
# The rv1103, rv1103b, rv1106 and rk2118 targets are currently commented out
# and therefore not advertised.
g_platforms = ["rk3562", "rk3566", "rk3568", "rk3576", "rk3588"]
|
|
|
|
|
|
def get_parser():
    """Build the command-line parser for the ONNX -> RKNN converter.

    All three options are mandatory string arguments:
      --target-platform  SoC to compile for (help text lists g_platforms)
      --in-model         path of the source ONNX file
      --out-model        path where the .rknn file is written
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # (flag, help text) pairs; every option shares type=str and required=True.
    option_specs = (
        ("--target-platform", f"Supported values are: {','.join(g_platforms)}"),
        ("--in-model", "Path to the input onnx model"),
        ("--out-model", "Path to the output rknn model"),
    )
    for flag, description in option_specs:
        cli.add_argument(flag, type=str, required=True, help=description)

    return cli
|
|
|
|
|
|
def get_meta_data(model: str):
    """Print an ONNX model's I/O description and return its custom metadata.

    The model's inputs, a "-----" separator and its outputs are printed to
    stdout, followed by a blank line. The model's custom metadata map is then
    flattened into one string of the form "key1=value1;key2=value2", in the
    iteration order of ``custom_metadata_map``. An empty map yields "".

    Args:
        model: Path to the ONNX model file.

    Returns:
        The semicolon-separated ``key=value`` metadata string.

    Raises:
        ValueError: If the flattened string is 1024 characters or longer
            (presumably the size limit for RKNN's ``custom_string`` -- confirm
            against the RKNN-Toolkit2 docs).
    """
    # Imported lazily: onnxruntime is only needed to read the metadata.
    import onnxruntime

    # No inference is run here, so a minimal single-threaded CPU session is
    # enough to read the graph's I/O description and metadata.
    session_opts = onnxruntime.SessionOptions()
    session_opts.inter_op_num_threads = 1
    session_opts.intra_op_num_threads = 1

    session = onnxruntime.InferenceSession(
        model,
        sess_options=session_opts,
        providers=["CPUExecutionProvider"],
    )

    for node in session.get_inputs():
        print(node)

    print("-----")

    for node in session.get_outputs():
        print(node)
    print()

    meta = session.get_modelmeta().custom_metadata_map
    s = ";".join(f"{key}={value}" for key, value in meta.items())

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(s) >= 1024:
        raise ValueError(
            f"Model metadata string is {len(s)} characters; it must be "
            f"shorter than 1024"
        )

    return s
|
|
|
|
|
|
def export_rknn(rknn, filename):
    """Write a built RKNN model to ``filename``.

    Args:
        rknn: RKNN toolkit object whose ``build`` has already succeeded
            (e.g. the object returned by ``init_model``).
        filename: Destination path of the ``.rknn`` file.

    Raises:
        SystemExit: If ``rknn.export_rknn`` returns a non-zero status; the
            exit message names the destination path.
    """
    ret = rknn.export_rknn(filename)
    if ret != 0:
        # `raise SystemExit(msg)` matches the builtin `exit(msg)` behaviour
        # without relying on the `site` module providing `exit`.
        raise SystemExit(f"Export rknn model to {filename} failed!")
|
|
|
|
|
|
def init_model(filename: str, target_platform: str, custom_string=None):
    """Load an ONNX file and build it into an (unquantized) RKNN model.

    Args:
        filename: Path to the input ONNX model.
        target_platform: Rockchip SoC to compile for (e.g. "rk3588").
        custom_string: Optional string passed to ``RKNN.config`` as
            ``custom_string`` (this script passes the ONNX metadata string).

    Returns:
        The built ``RKNN`` object. The caller owns it and must call
        ``release()`` on it.

    Raises:
        SystemExit: If ``filename`` is not a file, or if loading or building
            the model returns a non-zero status.
    """
    # Validate the path before constructing the toolkit object, so a bad
    # --in-model fails fast without initialising RKNN at all.
    if not Path(filename).is_file():
        raise SystemExit(f"{filename} does not exist")

    rknn = RKNN(verbose=False)

    rknn.config(
        optimization_level=0,
        target_platform=target_platform,
        custom_string=custom_string,
    )

    ret = rknn.load_onnx(model=filename)
    if ret != 0:
        rknn.release()  # free toolkit resources before bailing out
        raise SystemExit(f"Load model {filename} failed!")

    ret = rknn.build(do_quantization=False)
    if ret != 0:
        rknn.release()
        raise SystemExit(f"Build model {filename} failed!")

    return rknn
|
|
|
|
|
|
class RKNNModel:
    """An RKNN model built from an ONNX file and tagged with its metadata.

    On construction the ONNX file is inspected, its custom metadata string is
    printed, and the model is compiled for ``target_platform`` with that string
    attached as the RKNN ``custom_string``.
    """

    def __init__(self, model: str, target_platform: str):
        """
        Args:
            model: Path to the source ONNX model.
            target_platform: Rockchip SoC to build for.
        """
        custom = get_meta_data(model)
        print(custom)
        self.model = init_model(
            model, target_platform=target_platform, custom_string=custom
        )

    def export_rknn(self, model):
        """Save the built model to the path ``model`` (delegates to the
        module-level ``export_rknn``)."""
        export_rknn(self.model, model)

    def release(self):
        """Release the underlying RKNN toolkit object."""
        self.model.release()
|
|
|
|
|
|
def main():
    """Command-line entry point: convert --in-model (ONNX) to --out-model (RKNN)."""
    args = get_parser().parse_args()
    print(vars(args))

    model = RKNNModel(
        model=args.in_model,
        target_platform=args.target_platform,
    )

    try:
        model.export_rknn(
            model=args.out_model,
        )
    finally:
        # Release the toolkit object even if export raises or exits.
        model.release()


if __name__ == "__main__":
    main()
|