73 lines
2.0 KiB
Python
Executable File
73 lines
2.0 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
|
|
"""
Add meta data to the GTCRN ONNX model file gtcrn_simple.onnx.

The inputs and outputs of the model, in the format printed by show():

NodeArg(name='mix', type='tensor(float)', shape=[1, 257, 1, 2])
NodeArg(name='conv_cache', type='tensor(float)', shape=[2, 1, 16, 16, 33])
NodeArg(name='tra_cache', type='tensor(float)', shape=[2, 3, 1, 1, 16])
NodeArg(name='inter_cache', type='tensor(float)', shape=[2, 1, 33, 16])
-----
NodeArg(name='enh', type='tensor(float)', shape=[1, 257, 1, 2])
NodeArg(name='conv_cache_out', type='tensor(float)', shape=[2, 1, 16, 16, 33])
NodeArg(name='tra_cache_out', type='tensor(float)', shape=[2, 3, 1, 1, 16])
NodeArg(name='inter_cache_out', type='tensor(float)', shape=[2, 1, 33, 16])
"""
|
|
|
|
import onnx
|
|
import onnxruntime as ort
|
|
|
|
|
|
def show(filename):
    """Print the input and output nodes of an ONNX model.

    Every input node is printed first, then a ``-----`` separator line,
    then every output node.

    Args:
      filename:
        Path to the ONNX model file to open with onnxruntime.
    """
    options = ort.SessionOptions()
    # Severity 3 is the "error" level in onnxruntime, so only errors and
    # above are logged while the session is created.
    options.log_severity_level = 3

    session = ort.InferenceSession(filename, options)

    for node in session.get_inputs():
        print(node)

    print("-----")

    for node in session.get_outputs():
        print(node)
|
|
|
|
|
|
def main():
    """Replace the meta data of ``./gtcrn_simple.onnx`` and save it in place.

    Steps:
      1. Print the model's inputs and outputs via ``show()``.
      2. Load the model and print its current ``metadata_props``.
      3. Remove every existing metadata entry.
      4. Add one entry per item of the table below; each value is stored
         as ``str(value)``.
      5. Print the new ``metadata_props`` and write the model back to the
         same file, overwriting it.
    """
    # Entries are added to the model in exactly this order.
    meta_data = {
        "model_type": "gtcrn",
        "comment": "gtcrn_simple",
        "version": 1,
        "sample_rate": 16000,
        "model_url": "https://github.com/Xiaobin-Rong/gtcrn/blob/main/stream/onnx_models/gtcrn_simple.onnx",
        "maintainer": "k2-fsa",
        "comment2": "Please see also https://github.com/Xiaobin-Rong/gtcrn",
        "conv_cache_shape": "2,1,16,16,33",
        "tra_cache_shape": "2,3,1,1,16",
        "inter_cache_shape": "2,1,33,16",
        "n_fft": 512,
        "hop_length": 256,
        "window_length": 512,
        "window_type": "hann_sqrt",
    }

    model_path = "./gtcrn_simple.onnx"
    show(model_path)

    model = onnx.load(model_path)

    # Dump the metadata the model currently carries, before replacing it.
    print(model.metadata_props)

    # Drop all existing entries so only the ones from meta_data remain.
    del model.metadata_props[:]

    for name, raw_value in meta_data.items():
        entry = model.metadata_props.add()
        entry.key = name
        # The value is assigned as a string, so ints are converted here.
        entry.value = str(raw_value)
    print("--------------------")

    print(model.metadata_props)

    # Overwrites the input file.
    onnx.save(model, model_path)
|
|
|
|
|
|
# Run main() only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|