Fix rknn for multi-threads (#2274)

This commit is contained in:
Fangjun Kuang
2025-06-03 20:28:57 +08:00
committed by GitHub
parent 818b3f6d6c
commit 1fabc6c79a
2 changed files with 43 additions and 18 deletions

View File

@@ -86,8 +86,7 @@ class OnlineZipformerCtcModelRknn::Impl {
} }
std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> Run( std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> Run(
std::vector<float> features, std::vector<float> features, std::vector<std::vector<uint8_t>> states) {
std::vector<std::vector<uint8_t>> states) const {
std::vector<rknn_input> inputs(input_attrs_.size()); std::vector<rknn_input> inputs(input_attrs_.size());
for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) { for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) {
@@ -147,13 +146,17 @@ class OnlineZipformerCtcModelRknn::Impl {
} }
} }
auto ret = rknn_inputs_set(ctx_, inputs.size(), inputs.data()); rknn_context ctx = 0;
auto ret = rknn_dup_context(&ctx_, &ctx);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the ctx");
ret = rknn_inputs_set(ctx, inputs.size(), inputs.data());
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs"); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs");
ret = rknn_run(ctx_, nullptr); ret = rknn_run(ctx, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model"); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model");
ret = rknn_outputs_get(ctx_, outputs.size(), outputs.data(), nullptr); ret = rknn_outputs_get(ctx, outputs.size(), outputs.data(), nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output"); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output");
for (int32_t i = 0; i < next_states.size(); ++i) { for (int32_t i = 0; i < next_states.size(); ++i) {
@@ -174,6 +177,8 @@ class OnlineZipformerCtcModelRknn::Impl {
} }
} }
rknn_destroy(ctx);
return {std::move(out), std::move(next_states)}; return {std::move(out), std::move(next_states)};
} }

View File

@@ -120,8 +120,7 @@ class OnlineZipformerTransducerModelRknn::Impl {
} }
std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> RunEncoder( std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> RunEncoder(
std::vector<float> features, std::vector<float> features, std::vector<std::vector<uint8_t>> states) {
std::vector<std::vector<uint8_t>> states) const {
std::vector<rknn_input> inputs(encoder_input_attrs_.size()); std::vector<rknn_input> inputs(encoder_input_attrs_.size());
for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) { for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) {
@@ -181,14 +180,21 @@ class OnlineZipformerTransducerModelRknn::Impl {
} }
} }
auto ret = rknn_inputs_set(encoder_ctx_, inputs.size(), inputs.data()); rknn_context encoder_ctx = 0;
// https://github.com/rockchip-linux/rknpu2/blob/master/runtime/RK3588/Linux/librknn_api/include/rknn_api.h#L444C1-L444C75
// rknn_dup_context(rknn_context* context_in, rknn_context* context_out);
auto ret = rknn_dup_context(&encoder_ctx_, &encoder_ctx);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the encoder ctx");
ret = rknn_inputs_set(encoder_ctx, inputs.size(), inputs.data());
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set encoder inputs"); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set encoder inputs");
ret = rknn_run(encoder_ctx_, nullptr); ret = rknn_run(encoder_ctx, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run encoder"); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run encoder");
ret = ret =
rknn_outputs_get(encoder_ctx_, outputs.size(), outputs.data(), nullptr); rknn_outputs_get(encoder_ctx, outputs.size(), outputs.data(), nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get encoder output"); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get encoder output");
for (int32_t i = 0; i < next_states.size(); ++i) { for (int32_t i = 0; i < next_states.size(); ++i) {
@@ -209,10 +215,12 @@ class OnlineZipformerTransducerModelRknn::Impl {
} }
} }
rknn_destroy(encoder_ctx);
return {std::move(encoder_out), std::move(next_states)}; return {std::move(encoder_out), std::move(next_states)};
} }
std::vector<float> RunDecoder(std::vector<int64_t> decoder_input) const { std::vector<float> RunDecoder(std::vector<int64_t> decoder_input) {
auto &attr = decoder_input_attrs_[0]; auto &attr = decoder_input_attrs_[0];
rknn_input input; rknn_input input;
@@ -230,20 +238,26 @@ class OnlineZipformerTransducerModelRknn::Impl {
output.size = decoder_out.size() * sizeof(float); output.size = decoder_out.size() * sizeof(float);
output.buf = decoder_out.data(); output.buf = decoder_out.data();
auto ret = rknn_inputs_set(decoder_ctx_, 1, &input); rknn_context decoder_ctx = 0;
auto ret = rknn_dup_context(&decoder_ctx_, &decoder_ctx);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the decoder ctx");
ret = rknn_inputs_set(decoder_ctx, 1, &input);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set decoder inputs"); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set decoder inputs");
ret = rknn_run(decoder_ctx_, nullptr); ret = rknn_run(decoder_ctx, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run decoder"); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run decoder");
ret = rknn_outputs_get(decoder_ctx_, 1, &output, nullptr); ret = rknn_outputs_get(decoder_ctx, 1, &output, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get decoder output"); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get decoder output");
rknn_destroy(decoder_ctx);
return decoder_out; return decoder_out;
} }
std::vector<float> RunJoiner(const float *encoder_out, std::vector<float> RunJoiner(const float *encoder_out,
const float *decoder_out) const { const float *decoder_out) {
std::vector<rknn_input> inputs(2); std::vector<rknn_input> inputs(2);
inputs[0].index = 0; inputs[0].index = 0;
inputs[0].type = RKNN_TENSOR_FLOAT32; inputs[0].type = RKNN_TENSOR_FLOAT32;
@@ -265,15 +279,21 @@ class OnlineZipformerTransducerModelRknn::Impl {
output.size = joiner_out.size() * sizeof(float); output.size = joiner_out.size() * sizeof(float);
output.buf = joiner_out.data(); output.buf = joiner_out.data();
auto ret = rknn_inputs_set(joiner_ctx_, inputs.size(), inputs.data()); rknn_context joiner_ctx = 0;
auto ret = rknn_dup_context(&joiner_ctx_, &joiner_ctx);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the joiner ctx");
ret = rknn_inputs_set(joiner_ctx, inputs.size(), inputs.data());
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set joiner inputs"); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set joiner inputs");
ret = rknn_run(joiner_ctx_, nullptr); ret = rknn_run(joiner_ctx, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run joiner"); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run joiner");
ret = rknn_outputs_get(joiner_ctx_, 1, &output, nullptr); ret = rknn_outputs_get(joiner_ctx, 1, &output, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get joiner output"); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get joiner output");
rknn_destroy(joiner_ctx);
return joiner_out; return joiner_out;
} }