diff --git a/sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc b/sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc index 4a89a618..adfce0c9 100644 --- a/sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc +++ b/sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc @@ -86,8 +86,7 @@ class OnlineZipformerCtcModelRknn::Impl { } std::pair, std::vector>> Run( - std::vector features, - std::vector> states) const { + std::vector features, std::vector> states) { std::vector inputs(input_attrs_.size()); for (int32_t i = 0; i < static_cast(inputs.size()); ++i) { @@ -147,13 +146,17 @@ class OnlineZipformerCtcModelRknn::Impl { } } - auto ret = rknn_inputs_set(ctx_, inputs.size(), inputs.data()); + rknn_context ctx = 0; + auto ret = rknn_dup_context(&ctx_, &ctx); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the ctx"); + + ret = rknn_inputs_set(ctx, inputs.size(), inputs.data()); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs"); - ret = rknn_run(ctx_, nullptr); + ret = rknn_run(ctx, nullptr); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model"); - ret = rknn_outputs_get(ctx_, outputs.size(), outputs.data(), nullptr); + ret = rknn_outputs_get(ctx, outputs.size(), outputs.data(), nullptr); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output"); for (int32_t i = 0; i < next_states.size(); ++i) { @@ -174,6 +177,8 @@ class OnlineZipformerCtcModelRknn::Impl { } } + rknn_destroy(ctx); + return {std::move(out), std::move(next_states)}; } diff --git a/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc b/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc index 69199944..5e2fdf5a 100644 --- a/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc +++ b/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc @@ -120,8 +120,7 @@ class OnlineZipformerTransducerModelRknn::Impl { } std::pair, std::vector>> RunEncoder( - std::vector features, - std::vector> states) const { + std::vector features, std::vector> states) { std::vector inputs(encoder_input_attrs_.size()); for (int32_t i = 0; i < static_cast(inputs.size()); ++i) { @@ -181,14 +180,21 @@ class OnlineZipformerTransducerModelRknn::Impl { } } - auto ret = rknn_inputs_set(encoder_ctx_, inputs.size(), inputs.data()); + rknn_context encoder_ctx = 0; + + // https://github.com/rockchip-linux/rknpu2/blob/master/runtime/RK3588/Linux/librknn_api/include/rknn_api.h#L444C1-L444C75 + // rknn_dup_context(rknn_context* context_in, rknn_context* context_out); + auto ret = rknn_dup_context(&encoder_ctx_, &encoder_ctx); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the encoder ctx"); + + ret = rknn_inputs_set(encoder_ctx, inputs.size(), inputs.data()); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set encoder inputs"); - ret = rknn_run(encoder_ctx_, nullptr); + ret = rknn_run(encoder_ctx, nullptr); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run encoder"); ret = - rknn_outputs_get(encoder_ctx_, outputs.size(), outputs.data(), nullptr); + rknn_outputs_get(encoder_ctx, outputs.size(), outputs.data(), nullptr); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get encoder output"); for (int32_t i = 0; i < next_states.size(); ++i) { @@ -209,10 +215,12 @@ class OnlineZipformerTransducerModelRknn::Impl { } } + rknn_destroy(encoder_ctx); + return {std::move(encoder_out), std::move(next_states)}; } - std::vector RunDecoder(std::vector decoder_input) const { + std::vector RunDecoder(std::vector decoder_input) { auto &attr = decoder_input_attrs_[0]; rknn_input input; @@ -230,20 +238,26 @@ class OnlineZipformerTransducerModelRknn::Impl { output.size = decoder_out.size() * sizeof(float); output.buf = decoder_out.data(); - auto ret = rknn_inputs_set(decoder_ctx_, 1, &input); + rknn_context decoder_ctx = 0; + auto ret = rknn_dup_context(&decoder_ctx_, &decoder_ctx); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the decoder ctx"); + + ret = rknn_inputs_set(decoder_ctx, 1, &input); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set decoder inputs"); - ret = rknn_run(decoder_ctx_, nullptr); + ret = rknn_run(decoder_ctx, nullptr); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run decoder"); - ret = rknn_outputs_get(decoder_ctx_, 1, &output, nullptr); + ret = rknn_outputs_get(decoder_ctx, 1, &output, nullptr); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get decoder output"); + rknn_destroy(decoder_ctx); + return decoder_out; } std::vector RunJoiner(const float *encoder_out, - const float *decoder_out) const { + const float *decoder_out) { std::vector inputs(2); inputs[0].index = 0; inputs[0].type = RKNN_TENSOR_FLOAT32; @@ -265,15 +279,21 @@ class OnlineZipformerTransducerModelRknn::Impl { output.size = joiner_out.size() * sizeof(float); output.buf = joiner_out.data(); - auto ret = rknn_inputs_set(joiner_ctx_, inputs.size(), inputs.data()); + rknn_context joiner_ctx = 0; + auto ret = rknn_dup_context(&joiner_ctx_, &joiner_ctx); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the joiner ctx"); + + ret = rknn_inputs_set(joiner_ctx, inputs.size(), inputs.data()); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set joiner inputs"); - ret = rknn_run(joiner_ctx_, nullptr); + ret = rknn_run(joiner_ctx, nullptr); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run joiner"); - ret = rknn_outputs_get(joiner_ctx_, 1, &output, nullptr); + ret = rknn_outputs_get(joiner_ctx, 1, &output, nullptr); SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get joiner output"); + rknn_destroy(joiner_ctx); + return joiner_out; }