Refactor rknn code (#2079)

This commit is contained in:
Fangjun Kuang
2025-04-01 16:54:53 +08:00
committed by GitHub
parent 8e51a97550
commit a11e359c11
6 changed files with 218 additions and 451 deletions

View File

@@ -4,12 +4,15 @@
#include "sherpa-onnx/csrc/rknn/utils.h"
#include <string.h>
#include <sstream>
#include <unordered_map>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/rknn/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
@@ -52,7 +55,7 @@ std::string ToString(const rknn_tensor_attr &attr) {
}
std::unordered_map<std::string, std::string> Parse(
const rknn_custom_string &custom_string) {
const rknn_custom_string &custom_string, bool debug /*= false*/) {
std::unordered_map<std::string, std::string> ans;
std::vector<std::string> fields;
SplitStringToVector(custom_string.string, ";", false, &fields);
@@ -68,7 +71,131 @@ std::unordered_map<std::string, std::string> Parse(
ans[std::move(tmp[0])] = std::move(tmp[1]);
}
if (debug) {
for (const auto &p : ans) {
SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
}
}
return ans;
}
void InitContext(void *model_data, size_t model_data_length, bool debug,
rknn_context *ctx) {
auto ret = rknn_init(ctx, model_data, model_data_length, 0, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init rknn");
if (debug) {
rknn_sdk_version v;
ret = rknn_query(*ctx, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");
SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version,
v.drv_version);
}
}
void InitInputOutputAttrs(rknn_context ctx, bool debug,
std::vector<rknn_tensor_attr> *input_attrs,
std::vector<rknn_tensor_attr> *output_attrs) {
rknn_input_output_num io_num;
auto ret = rknn_query(ctx, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model");
if (debug) {
SHERPA_ONNX_LOGE("model: %d inputs, %d outputs",
static_cast<int32_t>(io_num.n_input),
static_cast<int32_t>(io_num.n_output));
}
input_attrs->resize(io_num.n_input);
output_attrs->resize(io_num.n_output);
int32_t i = 0;
for (auto &attr : *input_attrs) {
memset(&attr, 0, sizeof(attr));
attr.index = i;
ret = rknn_query(ctx, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i);
i += 1;
}
if (debug) {
std::ostringstream os;
std::string sep;
for (auto &attr : *input_attrs) {
os << sep << ToString(attr);
sep = "\n";
}
SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s",
os.str().c_str());
}
i = 0;
for (auto &attr : *output_attrs) {
memset(&attr, 0, sizeof(attr));
attr.index = i;
ret = rknn_query(ctx, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i);
i += 1;
}
if (debug) {
std::ostringstream os;
std::string sep;
for (auto &attr : *output_attrs) {
os << sep << ToString(attr);
sep = "\n";
}
SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s",
os.str().c_str());
}
}
rknn_custom_string GetCustomString(rknn_context ctx, bool debug) {
rknn_custom_string custom_string;
auto ret = rknn_query(ctx, RKNN_QUERY_CUSTOM_STRING, &custom_string,
sizeof(custom_string));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model");
if (debug) {
SHERPA_ONNX_LOGE("customs string: %s", custom_string.string);
}
return custom_string;
}
void SetCoreMask(rknn_context ctx, int32_t num_threads) {
int32_t ret = RKNN_SUCC;
switch (num_threads) {
case 1:
ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_AUTO);
break;
case 0:
ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_0);
break;
case -1:
ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_1);
break;
case -2:
ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_2);
break;
case -3:
ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_0_1);
break;
case -4:
ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_0_1_2);
break;
default:
SHERPA_ONNX_LOGE(
"Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core "
"1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d",
num_threads);
break;
}
if (ret != RKNN_SUCC) {
SHERPA_ONNX_LOGE(
"Failed to select npu core to run the model (You can ignore it if "
"you are not using RK3588.");
}
}
} // namespace sherpa_onnx