Support RKNN for Zipformer CTC models. (#1948)
This commit is contained in:
@@ -155,7 +155,9 @@ if(SHERPA_ONNX_ENABLE_RKNN)
|
||||
list(APPEND sources
|
||||
./rknn/online-stream-rknn.cc
|
||||
./rknn/online-transducer-greedy-search-decoder-rknn.cc
|
||||
./rknn/online-zipformer-ctc-model-rknn.cc
|
||||
./rknn/online-zipformer-transducer-model-rknn.cc
|
||||
./rknn/utils.cc
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
@@ -43,12 +43,14 @@ class OnlineCtcDecoder {
|
||||
|
||||
/** Run streaming CTC decoding given the output from the encoder model.
|
||||
*
|
||||
* @param log_probs A 3-D tensor of shape (N, T, vocab_size) containing
|
||||
* lob_probs.
|
||||
* @param log_probs A 3-D tensor of shape
|
||||
* (batch_size, num_frames, vocab_size) containing
|
||||
* lob_probs in row major.
|
||||
*
|
||||
* @param results Input & Output parameters..
|
||||
*/
|
||||
virtual void Decode(Ort::Value log_probs,
|
||||
virtual void Decode(const float *log_probs, int32_t batch_size,
|
||||
int32_t num_frames, int32_t vocab_size,
|
||||
std::vector<OnlineCtcDecoderResult> *results,
|
||||
OnlineStream **ss = nullptr, int32_t n = 0) = 0;
|
||||
|
||||
|
||||
@@ -91,30 +91,23 @@ static void DecodeOne(const float *log_probs, int32_t num_rows,
|
||||
processed_frames += num_rows;
|
||||
}
|
||||
|
||||
void OnlineCtcFstDecoder::Decode(Ort::Value log_probs,
|
||||
void OnlineCtcFstDecoder::Decode(const float *log_probs, int32_t batch_size,
|
||||
int32_t num_frames, int32_t vocab_size,
|
||||
std::vector<OnlineCtcDecoderResult> *results,
|
||||
OnlineStream **ss, int32_t n) {
|
||||
std::vector<int64_t> log_probs_shape =
|
||||
log_probs.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
if (log_probs_shape[0] != results->size()) {
|
||||
if (batch_size != results->size()) {
|
||||
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d",
|
||||
static_cast<int32_t>(log_probs_shape[0]),
|
||||
static_cast<int32_t>(results->size()));
|
||||
batch_size, static_cast<int32_t>(results->size()));
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (log_probs_shape[0] != n) {
|
||||
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, n: %d",
|
||||
static_cast<int32_t>(log_probs_shape[0]), n);
|
||||
if (batch_size != n) {
|
||||
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, n: %d", batch_size,
|
||||
n);
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t batch_size = static_cast<int32_t>(log_probs_shape[0]);
|
||||
int32_t num_frames = static_cast<int32_t>(log_probs_shape[1]);
|
||||
int32_t vocab_size = static_cast<int32_t>(log_probs_shape[2]);
|
||||
|
||||
const float *p = log_probs.GetTensorData<float>();
|
||||
const float *p = log_probs;
|
||||
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
DecodeOne(p + i * num_frames * vocab_size, num_frames, vocab_size,
|
||||
|
||||
@@ -19,8 +19,8 @@ class OnlineCtcFstDecoder : public OnlineCtcDecoder {
|
||||
OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config,
|
||||
int32_t blank_id);
|
||||
|
||||
void Decode(Ort::Value log_probs,
|
||||
std::vector<OnlineCtcDecoderResult> *results,
|
||||
void Decode(const float *log_probs, int32_t batch_size, int32_t num_frames,
|
||||
int32_t vocab_size, std::vector<OnlineCtcDecoderResult> *results,
|
||||
OnlineStream **ss = nullptr, int32_t n = 0) override;
|
||||
|
||||
std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder()
|
||||
|
||||
@@ -13,23 +13,16 @@
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OnlineCtcGreedySearchDecoder::Decode(
|
||||
Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results,
|
||||
const float *log_probs, int32_t batch_size, int32_t num_frames,
|
||||
int32_t vocab_size, std::vector<OnlineCtcDecoderResult> *results,
|
||||
OnlineStream ** /*ss=nullptr*/, int32_t /*n = 0*/) {
|
||||
std::vector<int64_t> log_probs_shape =
|
||||
log_probs.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
if (log_probs_shape[0] != results->size()) {
|
||||
if (batch_size != results->size()) {
|
||||
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d",
|
||||
static_cast<int32_t>(log_probs_shape[0]),
|
||||
static_cast<int32_t>(results->size()));
|
||||
batch_size, static_cast<int32_t>(results->size()));
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t batch_size = static_cast<int32_t>(log_probs_shape[0]);
|
||||
int32_t num_frames = static_cast<int32_t>(log_probs_shape[1]);
|
||||
int32_t vocab_size = static_cast<int32_t>(log_probs_shape[2]);
|
||||
|
||||
const float *p = log_probs.GetTensorData<float>();
|
||||
const float *p = log_probs;
|
||||
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
auto &r = (*results)[b];
|
||||
|
||||
@@ -16,8 +16,8 @@ class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder {
|
||||
explicit OnlineCtcGreedySearchDecoder(int32_t blank_id)
|
||||
: blank_id_(blank_id) {}
|
||||
|
||||
void Decode(Ort::Value log_probs,
|
||||
std::vector<OnlineCtcDecoderResult> *results,
|
||||
void Decode(const float *log_probs, int32_t batch_size, int32_t num_frames,
|
||||
int32_t vocab_size, std::vector<OnlineCtcDecoderResult> *results,
|
||||
OnlineStream **ss = nullptr, int32_t n = 0) override;
|
||||
|
||||
private:
|
||||
|
||||
@@ -76,6 +76,15 @@ bool OnlineModelConfig::Validate() const {
|
||||
transducer.decoder.c_str(), transducer.joiner.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!zipformer2_ctc.model.empty() &&
|
||||
EndsWith(zipformer2_ctc.model, ".rknn")) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"--provider is %s, which is not rknn, but you pass rknn model "
|
||||
"filename for zipformer2_ctc: '%s'",
|
||||
provider_config.provider.c_str(), zipformer2_ctc.model.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (provider_config.provider == "rknn") {
|
||||
@@ -89,6 +98,15 @@ bool OnlineModelConfig::Validate() const {
|
||||
transducer.joiner.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!zipformer2_ctc.model.empty() &&
|
||||
EndsWith(zipformer2_ctc.model, ".onnx")) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"--provider rknn, but you pass onnx model filename for "
|
||||
"zipformer2_ctc: '%s'",
|
||||
zipformer2_ctc.model.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!tokens_buf.empty() && FileExists(tokens)) {
|
||||
|
||||
@@ -24,12 +24,11 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
|
||||
const SymbolTable &sym_table,
|
||||
float frame_shift_ms,
|
||||
int32_t subsampling_factor,
|
||||
int32_t segment,
|
||||
int32_t frames_since_start) {
|
||||
OnlineRecognizerResult ConvertCtc(const OnlineCtcDecoderResult &src,
|
||||
const SymbolTable &sym_table,
|
||||
float frame_shift_ms,
|
||||
int32_t subsampling_factor, int32_t segment,
|
||||
int32_t frames_since_start) {
|
||||
OnlineRecognizerResult r;
|
||||
r.tokens.reserve(src.tokens.size());
|
||||
r.timestamps.reserve(src.tokens.size());
|
||||
@@ -182,7 +181,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
std::vector<std::vector<Ort::Value>> next_states =
|
||||
model_->UnStackStates(std::move(out_states));
|
||||
|
||||
decoder_->Decode(std::move(out[0]), &results, ss, n);
|
||||
std::vector<int64_t> log_probs_shape =
|
||||
out[0].GetTensorTypeAndShapeInfo().GetShape();
|
||||
decoder_->Decode(out[0].GetTensorData<float>(), log_probs_shape[0],
|
||||
log_probs_shape[1], log_probs_shape[2], &results, ss, n);
|
||||
|
||||
for (int32_t k = 0; k != n; ++k) {
|
||||
ss[k]->SetCtcResult(results[k]);
|
||||
@@ -196,8 +198,9 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
// TODO(fangjun): Remember to change these constants if needed
|
||||
int32_t frame_shift_ms = 10;
|
||||
int32_t subsampling_factor = 4;
|
||||
auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
||||
auto r =
|
||||
ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
||||
r.text = ApplyInverseTextNormalization(r.text);
|
||||
return r;
|
||||
}
|
||||
@@ -306,7 +309,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
std::vector<OnlineCtcDecoderResult> results(1);
|
||||
results[0] = std::move(s->GetCtcResult());
|
||||
|
||||
decoder_->Decode(std::move(out[0]), &results, &s, 1);
|
||||
std::vector<int64_t> log_probs_shape =
|
||||
out[0].GetTensorTypeAndShapeInfo().GetShape();
|
||||
decoder_->Decode(out[0].GetTensorData<float>(), log_probs_shape[0],
|
||||
log_probs_shape[1], log_probs_shape[2], &results, &s, 1);
|
||||
s->SetCtcResult(results[0]);
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
#if SHERPA_ONNX_ENABLE_RKNN
|
||||
#include "sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h"
|
||||
#include "sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h"
|
||||
#endif
|
||||
|
||||
@@ -37,12 +38,15 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
|
||||
if (config.model_config.provider_config.provider == "rknn") {
|
||||
#if SHERPA_ONNX_ENABLE_RKNN
|
||||
// Currently, only zipformer v1 is suported for rknn
|
||||
if (config.model_config.transducer.encoder.empty()) {
|
||||
if (config.model_config.transducer.encoder.empty() &&
|
||||
config.model_config.zipformer2_ctc.model.empty()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Only Zipformer transducers are currently supported by rknn. "
|
||||
"Fallback to CPU");
|
||||
} else {
|
||||
"Only Zipformer transducers and CTC models are currently supported "
|
||||
"by rknn. Fallback to CPU");
|
||||
} else if (!config.model_config.transducer.encoder.empty()) {
|
||||
return std::make_unique<OnlineRecognizerTransducerRknnImpl>(config);
|
||||
} else if (!config.model_config.zipformer2_ctc.model.empty()) {
|
||||
return std::make_unique<OnlineRecognizerCtcRknnImpl>(config);
|
||||
}
|
||||
#else
|
||||
SHERPA_ONNX_LOGE(
|
||||
|
||||
204
sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h
Normal file
204
sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h
Normal file
@@ -0,0 +1,204 @@
|
||||
// sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <ios>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
|
||||
#include "sherpa-onnx/csrc/rknn/online-stream-rknn.h"
|
||||
#include "sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// defined in ../online-recognizer-ctc-impl.h
|
||||
OnlineRecognizerResult ConvertCtc(const OnlineCtcDecoderResult &src,
|
||||
const SymbolTable &sym_table,
|
||||
float frame_shift_ms,
|
||||
int32_t subsampling_factor, int32_t segment,
|
||||
int32_t frames_since_start);
|
||||
|
||||
class OnlineRecognizerCtcRknnImpl : public OnlineRecognizerImpl {
|
||||
public:
|
||||
explicit OnlineRecognizerCtcRknnImpl(const OnlineRecognizerConfig &config)
|
||||
: OnlineRecognizerImpl(config),
|
||||
config_(config),
|
||||
model_(
|
||||
std::make_unique<OnlineZipformerCtcModelRknn>(config.model_config)),
|
||||
endpoint_(config_.endpoint_config) {
|
||||
if (!config.model_config.tokens_buf.empty()) {
|
||||
sym_ = SymbolTable(config.model_config.tokens_buf, false);
|
||||
} else {
|
||||
/// assuming tokens_buf and tokens are guaranteed not being both empty
|
||||
sym_ = SymbolTable(config.model_config.tokens, true);
|
||||
}
|
||||
|
||||
InitDecoder();
|
||||
}
|
||||
|
||||
template <typename Manager>
|
||||
explicit OnlineRecognizerCtcRknnImpl(Manager *mgr,
|
||||
const OnlineRecognizerConfig &config)
|
||||
: OnlineRecognizerImpl(mgr, config),
|
||||
config_(config),
|
||||
model_(
|
||||
std::make_unique<OnlineZipformerCtcModelRknn>(config.model_config)),
|
||||
sym_(mgr, config.model_config.tokens),
|
||||
endpoint_(config_.endpoint_config) {
|
||||
InitDecoder();
|
||||
}
|
||||
|
||||
std::unique_ptr<OnlineStream> CreateStream() const override {
|
||||
auto stream = std::make_unique<OnlineStreamRknn>(config_.feat_config);
|
||||
stream->SetZipformerEncoderStates(model_->GetInitStates());
|
||||
stream->SetFasterDecoder(decoder_->CreateFasterDecoder());
|
||||
return stream;
|
||||
}
|
||||
|
||||
bool IsReady(OnlineStream *s) const override {
|
||||
return s->GetNumProcessedFrames() + model_->ChunkSize() <
|
||||
s->NumFramesReady();
|
||||
}
|
||||
|
||||
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
DecodeStream(reinterpret_cast<OnlineStreamRknn *>(ss[i]));
|
||||
}
|
||||
}
|
||||
|
||||
OnlineRecognizerResult GetResult(OnlineStream *s) const override {
|
||||
OnlineCtcDecoderResult decoder_result = s->GetCtcResult();
|
||||
|
||||
// TODO(fangjun): Remember to change these constants if needed
|
||||
int32_t frame_shift_ms = 10;
|
||||
int32_t subsampling_factor = 4;
|
||||
auto r =
|
||||
ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
||||
r.text = ApplyInverseTextNormalization(r.text);
|
||||
return r;
|
||||
}
|
||||
|
||||
bool IsEndpoint(OnlineStream *s) const override {
|
||||
if (!config_.enable_endpoint) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t num_processed_frames = s->GetNumProcessedFrames();
|
||||
|
||||
// frame shift is 10 milliseconds
|
||||
float frame_shift_in_seconds = 0.01;
|
||||
|
||||
// subsampling factor is 4
|
||||
int32_t trailing_silence_frames = s->GetCtcResult().num_trailing_blanks * 4;
|
||||
|
||||
return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
|
||||
frame_shift_in_seconds);
|
||||
}
|
||||
|
||||
void Reset(OnlineStream *s) const override {
|
||||
// segment is incremented only when the last
|
||||
// result is not empty
|
||||
const auto &r = s->GetCtcResult();
|
||||
if (!r.tokens.empty()) {
|
||||
s->GetCurrentSegment() += 1;
|
||||
}
|
||||
|
||||
// clear result
|
||||
s->SetCtcResult({});
|
||||
|
||||
// clear states
|
||||
reinterpret_cast<OnlineStreamRknn *>(s)->SetZipformerEncoderStates(
|
||||
model_->GetInitStates());
|
||||
|
||||
s->GetFasterDecoderProcessedFrames() = 0;
|
||||
|
||||
// Note: We only update counters. The underlying audio samples
|
||||
// are not discarded.
|
||||
s->Reset();
|
||||
}
|
||||
|
||||
private:
|
||||
void InitDecoder() {
|
||||
if (!sym_.Contains("<blk>") && !sym_.Contains("<eps>") &&
|
||||
!sym_.Contains("<blank>")) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"We expect that tokens.txt contains "
|
||||
"the symbol <blk> or <eps> or <blank> and its ID.");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t blank_id = 0;
|
||||
if (sym_.Contains("<blk>")) {
|
||||
blank_id = sym_["<blk>"];
|
||||
} else if (sym_.Contains("<eps>")) {
|
||||
// for tdnn models of the yesno recipe from icefall
|
||||
blank_id = sym_["<eps>"];
|
||||
} else if (sym_.Contains("<blank>")) {
|
||||
// for WeNet CTC models
|
||||
blank_id = sym_["<blank>"];
|
||||
}
|
||||
|
||||
if (!config_.ctc_fst_decoder_config.graph.empty()) {
|
||||
decoder_ = std::make_unique<OnlineCtcFstDecoder>(
|
||||
config_.ctc_fst_decoder_config, blank_id);
|
||||
} else if (config_.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineCtcGreedySearchDecoder>(blank_id);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Unsupported decoding method: %s for streaming CTC models",
|
||||
config_.decoding_method.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
void DecodeStream(OnlineStreamRknn *s) const {
|
||||
int32_t chunk_size = model_->ChunkSize();
|
||||
int32_t chunk_shift = model_->ChunkShift();
|
||||
|
||||
int32_t feat_dim = s->FeatureDim();
|
||||
|
||||
const auto num_processed_frames = s->GetNumProcessedFrames();
|
||||
std::vector<float> features =
|
||||
s->GetFrames(num_processed_frames, chunk_size);
|
||||
s->GetNumProcessedFrames() += chunk_shift;
|
||||
|
||||
auto &states = s->GetZipformerEncoderStates();
|
||||
auto p = model_->Run(features, std::move(states));
|
||||
states = std::move(p.second);
|
||||
|
||||
std::vector<OnlineCtcDecoderResult> results(1);
|
||||
results[0] = std::move(s->GetCtcResult());
|
||||
|
||||
auto attr = model_->GetOutAttr();
|
||||
|
||||
decoder_->Decode(p.first.data(), attr.dims[0], attr.dims[1], attr.dims[2],
|
||||
&results, reinterpret_cast<OnlineStream **>(&s), 1);
|
||||
s->SetCtcResult(results[0]);
|
||||
}
|
||||
|
||||
private:
|
||||
OnlineRecognizerConfig config_;
|
||||
std::unique_ptr<OnlineZipformerCtcModelRknn> model_;
|
||||
std::unique_ptr<OnlineCtcDecoder> decoder_;
|
||||
SymbolTable sym_;
|
||||
Endpoint endpoint_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_
|
||||
390
sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc
Normal file
390
sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc
Normal file
@@ -0,0 +1,390 @@
|
||||
// sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h"
|
||||
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#if __OHOS__
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/rknn/macros.h"
|
||||
#include "sherpa-onnx/csrc/rknn/utils.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OnlineZipformerCtcModelRknn::Impl {
|
||||
public:
|
||||
~Impl() {
|
||||
auto ret = rknn_destroy(ctx_);
|
||||
if (ret != RKNN_SUCC) {
|
||||
SHERPA_ONNX_LOGE("Failed to destroy the context");
|
||||
}
|
||||
}
|
||||
|
||||
explicit Impl(const OnlineModelConfig &config) : config_(config) {
|
||||
{
|
||||
auto buf = ReadFile(config.zipformer2_ctc.model);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
int32_t ret = RKNN_SUCC;
|
||||
switch (config_.num_threads) {
|
||||
case 1:
|
||||
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_AUTO);
|
||||
break;
|
||||
case 0:
|
||||
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0);
|
||||
break;
|
||||
case -1:
|
||||
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_1);
|
||||
break;
|
||||
case -2:
|
||||
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_2);
|
||||
break;
|
||||
case -3:
|
||||
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1);
|
||||
break;
|
||||
case -4:
|
||||
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1_2);
|
||||
break;
|
||||
default:
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core "
|
||||
"1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d",
|
||||
config_.num_threads);
|
||||
break;
|
||||
}
|
||||
if (ret != RKNN_SUCC) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Failed to select npu core to run the model (You can ignore it if "
|
||||
"you "
|
||||
"are not using RK3588.");
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(fangjun): Support Android
|
||||
|
||||
std::vector<std::vector<uint8_t>> GetInitStates() const {
|
||||
// input_attrs_[0] is for the feature
|
||||
// input_attrs_[1:] is for states
|
||||
// so we use -1 here
|
||||
std::vector<std::vector<uint8_t>> states(input_attrs_.size() - 1);
|
||||
|
||||
int32_t i = -1;
|
||||
for (auto &attr : input_attrs_) {
|
||||
i += 1;
|
||||
if (i == 0) {
|
||||
// skip processing the attr for features.
|
||||
continue;
|
||||
}
|
||||
|
||||
if (attr.type == RKNN_TENSOR_FLOAT16) {
|
||||
states[i - 1].resize(attr.n_elems * sizeof(float));
|
||||
} else if (attr.type == RKNN_TENSOR_INT64) {
|
||||
states[i - 1].resize(attr.n_elems * sizeof(int64_t));
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported tensor type: %d, %s", attr.type,
|
||||
get_type_string(attr.type));
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
}
|
||||
|
||||
return states;
|
||||
}
|
||||
|
||||
std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> Run(
|
||||
std::vector<float> features,
|
||||
std::vector<std::vector<uint8_t>> states) const {
|
||||
std::vector<rknn_input> inputs(input_attrs_.size());
|
||||
|
||||
for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) {
|
||||
auto &input = inputs[i];
|
||||
auto &attr = input_attrs_[i];
|
||||
input.index = attr.index;
|
||||
|
||||
if (attr.type == RKNN_TENSOR_FLOAT16) {
|
||||
input.type = RKNN_TENSOR_FLOAT32;
|
||||
} else if (attr.type == RKNN_TENSOR_INT64) {
|
||||
input.type = RKNN_TENSOR_INT64;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type,
|
||||
get_type_string(attr.type));
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
|
||||
input.fmt = attr.fmt;
|
||||
if (i == 0) {
|
||||
input.buf = reinterpret_cast<void *>(features.data());
|
||||
input.size = features.size() * sizeof(float);
|
||||
} else {
|
||||
input.buf = reinterpret_cast<void *>(states[i - 1].data());
|
||||
input.size = states[i - 1].size();
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<float> out(output_attrs_[0].n_elems);
|
||||
|
||||
// Note(fangjun): We can reuse the memory from input argument `states`
|
||||
// auto next_states = GetInitStates();
|
||||
auto &next_states = states;
|
||||
|
||||
std::vector<rknn_output> outputs(output_attrs_.size());
|
||||
for (int32_t i = 0; i < outputs.size(); ++i) {
|
||||
auto &output = outputs[i];
|
||||
auto &attr = output_attrs_[i];
|
||||
output.index = attr.index;
|
||||
output.is_prealloc = 1;
|
||||
|
||||
if (attr.type == RKNN_TENSOR_FLOAT16) {
|
||||
output.want_float = 1;
|
||||
} else if (attr.type == RKNN_TENSOR_INT64) {
|
||||
output.want_float = 0;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type,
|
||||
get_type_string(attr.type));
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
|
||||
if (i == 0) {
|
||||
output.size = out.size() * sizeof(float);
|
||||
output.buf = reinterpret_cast<void *>(out.data());
|
||||
} else {
|
||||
output.size = next_states[i - 1].size();
|
||||
output.buf = reinterpret_cast<void *>(next_states[i - 1].data());
|
||||
}
|
||||
}
|
||||
|
||||
auto ret = rknn_inputs_set(ctx_, inputs.size(), inputs.data());
|
||||
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs");
|
||||
|
||||
ret = rknn_run(ctx_, nullptr);
|
||||
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model");
|
||||
|
||||
ret = rknn_outputs_get(ctx_, outputs.size(), outputs.data(), nullptr);
|
||||
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output");
|
||||
|
||||
for (int32_t i = 0; i < next_states.size(); ++i) {
|
||||
const auto &attr = input_attrs_[i + 1];
|
||||
if (attr.n_dims == 4) {
|
||||
// TODO(fangjun): The transpose is copied from
|
||||
// https://github.com/airockchip/rknn_model_zoo/blob/main/examples/zipformer/cpp/process.cc#L22
|
||||
// I don't understand why we need to do that.
|
||||
std::vector<uint8_t> dst(next_states[i].size());
|
||||
int32_t n = attr.dims[0];
|
||||
int32_t h = attr.dims[1];
|
||||
int32_t w = attr.dims[2];
|
||||
int32_t c = attr.dims[3];
|
||||
ConvertNCHWtoNHWC(
|
||||
reinterpret_cast<const float *>(next_states[i].data()), n, c, h, w,
|
||||
reinterpret_cast<float *>(dst.data()));
|
||||
next_states[i] = std::move(dst);
|
||||
}
|
||||
}
|
||||
|
||||
return {std::move(out), std::move(next_states)};
|
||||
}
|
||||
|
||||
int32_t ChunkSize() const { return T_; }
|
||||
|
||||
int32_t ChunkShift() const { return decode_chunk_len_; }
|
||||
|
||||
int32_t VocabSize() const { return vocab_size_; }
|
||||
|
||||
rknn_tensor_attr GetOutAttr() const { return output_attrs_[0]; }
|
||||
|
||||
private:
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
auto ret = rknn_init(&ctx_, model_data, model_data_length, 0, nullptr);
|
||||
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init model '%s'",
|
||||
config_.zipformer2_ctc.model.c_str());
|
||||
|
||||
if (config_.debug) {
|
||||
rknn_sdk_version v;
|
||||
ret = rknn_query(ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));
|
||||
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");
|
||||
|
||||
SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version,
|
||||
v.drv_version);
|
||||
}
|
||||
|
||||
rknn_input_output_num io_num;
|
||||
ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
|
||||
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model");
|
||||
|
||||
if (config_.debug) {
|
||||
SHERPA_ONNX_LOGE("model: %d inputs, %d outputs",
|
||||
static_cast<int32_t>(io_num.n_input),
|
||||
static_cast<int32_t>(io_num.n_output));
|
||||
}
|
||||
|
||||
input_attrs_.resize(io_num.n_input);
|
||||
output_attrs_.resize(io_num.n_output);
|
||||
|
||||
int32_t i = 0;
|
||||
for (auto &attr : input_attrs_) {
|
||||
memset(&attr, 0, sizeof(attr));
|
||||
attr.index = i;
|
||||
ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
|
||||
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i);
|
||||
i += 1;
|
||||
}
|
||||
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
std::string sep;
|
||||
for (auto &attr : input_attrs_) {
|
||||
os << sep << ToString(attr);
|
||||
sep = "\n";
|
||||
}
|
||||
SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s",
|
||||
os.str().c_str());
|
||||
}
|
||||
|
||||
i = 0;
|
||||
for (auto &attr : output_attrs_) {
|
||||
memset(&attr, 0, sizeof(attr));
|
||||
attr.index = i;
|
||||
ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
|
||||
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i);
|
||||
i += 1;
|
||||
}
|
||||
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
std::string sep;
|
||||
for (auto &attr : output_attrs_) {
|
||||
os << sep << ToString(attr);
|
||||
sep = "\n";
|
||||
}
|
||||
SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s",
|
||||
os.str().c_str());
|
||||
}
|
||||
|
||||
rknn_custom_string custom_string;
|
||||
ret = rknn_query(ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string,
|
||||
sizeof(custom_string));
|
||||
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model");
|
||||
if (config_.debug) {
|
||||
SHERPA_ONNX_LOGE("customs string: %s", custom_string.string);
|
||||
}
|
||||
auto meta = Parse(custom_string);
|
||||
|
||||
if (config_.debug) {
|
||||
for (const auto &p : meta) {
|
||||
SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (meta.count("T")) {
|
||||
T_ = atoi(meta.at("T").c_str());
|
||||
}
|
||||
|
||||
if (meta.count("decode_chunk_len")) {
|
||||
decode_chunk_len_ = atoi(meta.at("decode_chunk_len").c_str());
|
||||
}
|
||||
|
||||
vocab_size_ = output_attrs_[0].dims[2];
|
||||
|
||||
if (config_.debug) {
|
||||
#if __OHOS__
|
||||
SHERPA_ONNX_LOGE("T: %{public}d", T_);
|
||||
SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_);
|
||||
SHERPA_ONNX_LOGE("vocab_size: %{public}d", vocab_size);
|
||||
#else
|
||||
SHERPA_ONNX_LOGE("T: %d", T_);
|
||||
SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_);
|
||||
SHERPA_ONNX_LOGE("vocab_size: %d", vocab_size_);
|
||||
#endif
|
||||
}
|
||||
|
||||
if (T_ == 0) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Invalid T. Please use the script from icefall to export your model");
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
|
||||
if (decode_chunk_len_ == 0) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Invalid decode_chunk_len. Please use the script from icefall to "
|
||||
"export your model");
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
OnlineModelConfig config_;
|
||||
rknn_context ctx_ = 0;
|
||||
|
||||
std::vector<rknn_tensor_attr> input_attrs_;
|
||||
std::vector<rknn_tensor_attr> output_attrs_;
|
||||
|
||||
int32_t T_ = 0;
|
||||
int32_t decode_chunk_len_ = 0;
|
||||
int32_t vocab_size_ = 0;
|
||||
};
|
||||
|
||||
OnlineZipformerCtcModelRknn::~OnlineZipformerCtcModelRknn() = default;
|
||||
|
||||
OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn(
|
||||
const OnlineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
template <typename Manager>
|
||||
OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn(
|
||||
Manager *mgr, const OnlineModelConfig &config)
|
||||
: impl_(std::make_unique<OnlineZipformerCtcModelRknn>(mgr, config)) {}
|
||||
|
||||
std::vector<std::vector<uint8_t>> OnlineZipformerCtcModelRknn::GetInitStates()
|
||||
const {
|
||||
return impl_->GetInitStates();
|
||||
}
|
||||
|
||||
std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>>
|
||||
OnlineZipformerCtcModelRknn::Run(
|
||||
std::vector<float> features,
|
||||
std::vector<std::vector<uint8_t>> states) const {
|
||||
return impl_->Run(std::move(features), std::move(states));
|
||||
}
|
||||
|
||||
int32_t OnlineZipformerCtcModelRknn::ChunkSize() const {
|
||||
return impl_->ChunkSize();
|
||||
}
|
||||
|
||||
int32_t OnlineZipformerCtcModelRknn::ChunkShift() const {
|
||||
return impl_->ChunkShift();
|
||||
}
|
||||
|
||||
int32_t OnlineZipformerCtcModelRknn::VocabSize() const {
|
||||
return impl_->VocabSize();
|
||||
}
|
||||
|
||||
rknn_tensor_attr OnlineZipformerCtcModelRknn::GetOutAttr() const {
|
||||
return impl_->GetOutAttr();
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
template OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn(
|
||||
AAssetManager *mgr, const OnlineModelConfig &config);
|
||||
#endif
|
||||
|
||||
#if __OHOS__
|
||||
template OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn(
|
||||
NativeResourceManager *mgr, const OnlineModelConfig &config);
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
46
sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h
Normal file
46
sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h
Normal file
@@ -0,0 +1,46 @@
|
||||
// sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_
|
||||
#define SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "rknn_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OnlineZipformerCtcModelRknn {
|
||||
public:
|
||||
~OnlineZipformerCtcModelRknn();
|
||||
|
||||
explicit OnlineZipformerCtcModelRknn(const OnlineModelConfig &config);
|
||||
|
||||
template <typename Manager>
|
||||
OnlineZipformerCtcModelRknn(Manager *mgr, const OnlineModelConfig &config);
|
||||
|
||||
std::vector<std::vector<uint8_t>> GetInitStates() const;
|
||||
|
||||
std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> Run(
|
||||
std::vector<float> features,
|
||||
std::vector<std::vector<uint8_t>> states) const;
|
||||
|
||||
int32_t ChunkSize() const;
|
||||
|
||||
int32_t ChunkShift() const;
|
||||
|
||||
int32_t VocabSize() const;
|
||||
|
||||
rknn_tensor_attr GetOutAttr() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_
|
||||
@@ -1,6 +1,6 @@
|
||||
// sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h"
|
||||
|
||||
@@ -22,68 +22,11 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/rknn/macros.h"
|
||||
#include "sherpa-onnx/csrc/rknn/utils.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// chw -> hwc
|
||||
static void Transpose(const float *src, int32_t n, int32_t channel,
|
||||
int32_t height, int32_t width, float *dst) {
|
||||
for (int32_t i = 0; i < n; ++i) {
|
||||
for (int32_t h = 0; h < height; ++h) {
|
||||
for (int32_t w = 0; w < width; ++w) {
|
||||
for (int32_t c = 0; c < channel; ++c) {
|
||||
// dst[h, w, c] = src[c, h, w]
|
||||
dst[i * height * width * channel + h * width * channel + w * channel +
|
||||
c] = src[i * height * width * channel + c * height * width +
|
||||
h * width + w];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static std::string ToString(const rknn_tensor_attr &attr) {
|
||||
std::ostringstream os;
|
||||
os << "{";
|
||||
os << attr.index;
|
||||
os << ", name: " << attr.name;
|
||||
os << ", shape: (";
|
||||
std::string sep;
|
||||
for (int32_t i = 0; i < static_cast<int32_t>(attr.n_dims); ++i) {
|
||||
os << sep << attr.dims[i];
|
||||
sep = ",";
|
||||
}
|
||||
os << ")";
|
||||
os << ", n_elems: " << attr.n_elems;
|
||||
os << ", size: " << attr.size;
|
||||
os << ", fmt: " << get_format_string(attr.fmt);
|
||||
os << ", type: " << get_type_string(attr.type);
|
||||
os << ", pass_through: " << (attr.pass_through ? "true" : "false");
|
||||
os << "}";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
static std::unordered_map<std::string, std::string> Parse(
|
||||
const rknn_custom_string &custom_string) {
|
||||
std::unordered_map<std::string, std::string> ans;
|
||||
std::vector<std::string> fields;
|
||||
SplitStringToVector(custom_string.string, ";", false, &fields);
|
||||
|
||||
std::vector<std::string> tmp;
|
||||
for (const auto &f : fields) {
|
||||
SplitStringToVector(f, "=", false, &tmp);
|
||||
if (tmp.size() != 2) {
|
||||
SHERPA_ONNX_LOGE("Invalid custom string %s for %s", custom_string.string,
|
||||
f.c_str());
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
ans[std::move(tmp[0])] = std::move(tmp[1]);
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
class OnlineZipformerTransducerModelRknn::Impl {
|
||||
public:
|
||||
~Impl() {
|
||||
@@ -285,7 +228,7 @@ class OnlineZipformerTransducerModelRknn::Impl {
|
||||
for (int32_t i = 0; i < next_states.size(); ++i) {
|
||||
const auto &attr = encoder_input_attrs_[i + 1];
|
||||
if (attr.n_dims == 4) {
|
||||
// TODO(fangjun): The transpose is copied from
|
||||
// TODO(fangjun): The ConvertNCHWtoNHWC is copied from
|
||||
// https://github.com/airockchip/rknn_model_zoo/blob/main/examples/zipformer/cpp/process.cc#L22
|
||||
// I don't understand why we need to do that.
|
||||
std::vector<uint8_t> dst(next_states[i].size());
|
||||
@@ -293,8 +236,9 @@ class OnlineZipformerTransducerModelRknn::Impl {
|
||||
int32_t h = attr.dims[1];
|
||||
int32_t w = attr.dims[2];
|
||||
int32_t c = attr.dims[3];
|
||||
Transpose(reinterpret_cast<const float *>(next_states[i].data()), n, c,
|
||||
h, w, reinterpret_cast<float *>(dst.data()));
|
||||
ConvertNCHWtoNHWC(
|
||||
reinterpret_cast<const float *>(next_states[i].data()), n, c, h, w,
|
||||
reinterpret_cast<float *>(dst.data()));
|
||||
next_states[i] = std::move(dst);
|
||||
}
|
||||
}
|
||||
@@ -527,11 +471,9 @@ class OnlineZipformerTransducerModelRknn::Impl {
|
||||
#if __OHOS__
|
||||
SHERPA_ONNX_LOGE("T: %{public}d", T_);
|
||||
SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_);
|
||||
SHERPA_ONNX_LOGE("context_size: %{public}d", context_size_);
|
||||
#else
|
||||
SHERPA_ONNX_LOGE("T: %d", T_);
|
||||
SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_);
|
||||
SHERPA_ONNX_LOGE("context_size: %d", context_size_);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -597,6 +539,11 @@ class OnlineZipformerTransducerModelRknn::Impl {
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
|
||||
context_size_ = decoder_input_attrs_[0].dims[1];
|
||||
if (config_.debug) {
|
||||
SHERPA_ONNX_LOGE("context_size: %d", context_size_);
|
||||
}
|
||||
|
||||
i = 0;
|
||||
for (auto &attr : decoder_output_attrs_) {
|
||||
memset(&attr, 0, sizeof(attr));
|
||||
|
||||
@@ -14,8 +14,11 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// this is for zipformer v1, i.e., the folder
|
||||
// pruned_transducer_statelss7_streaming from icefall
|
||||
// this is for zipformer v1 and v2, i.e., the folder
|
||||
// pruned_transducer_statelss7_streaming
|
||||
// and
|
||||
// zipformer
|
||||
// from icefall
|
||||
class OnlineZipformerTransducerModelRknn {
|
||||
public:
|
||||
~OnlineZipformerTransducerModelRknn();
|
||||
|
||||
73
sherpa-onnx/csrc/rknn/utils.cc
Normal file
73
sherpa-onnx/csrc/rknn/utils.cc
Normal file
@@ -0,0 +1,73 @@
|
||||
// sherpa-onnx/csrc/utils.cc
|
||||
//
|
||||
// Copyright 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/rknn/utils.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void ConvertNCHWtoNHWC(const float *src, int32_t n, int32_t channel,
|
||||
int32_t height, int32_t width, float *dst) {
|
||||
for (int32_t i = 0; i < n; ++i) {
|
||||
for (int32_t h = 0; h < height; ++h) {
|
||||
for (int32_t w = 0; w < width; ++w) {
|
||||
for (int32_t c = 0; c < channel; ++c) {
|
||||
// dst[h, w, c] = src[c, h, w]
|
||||
dst[i * height * width * channel + h * width * channel + w * channel +
|
||||
c] = src[i * height * width * channel + c * height * width +
|
||||
h * width + w];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string ToString(const rknn_tensor_attr &attr) {
|
||||
std::ostringstream os;
|
||||
os << "{";
|
||||
os << attr.index;
|
||||
os << ", name: " << attr.name;
|
||||
os << ", shape: (";
|
||||
std::string sep;
|
||||
for (int32_t i = 0; i < static_cast<int32_t>(attr.n_dims); ++i) {
|
||||
os << sep << attr.dims[i];
|
||||
sep = ",";
|
||||
}
|
||||
os << ")";
|
||||
os << ", n_elems: " << attr.n_elems;
|
||||
os << ", size: " << attr.size;
|
||||
os << ", fmt: " << get_format_string(attr.fmt);
|
||||
os << ", type: " << get_type_string(attr.type);
|
||||
os << ", pass_through: " << (attr.pass_through ? "true" : "false");
|
||||
os << "}";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, std::string> Parse(
|
||||
const rknn_custom_string &custom_string) {
|
||||
std::unordered_map<std::string, std::string> ans;
|
||||
std::vector<std::string> fields;
|
||||
SplitStringToVector(custom_string.string, ";", false, &fields);
|
||||
|
||||
std::vector<std::string> tmp;
|
||||
for (const auto &f : fields) {
|
||||
SplitStringToVector(f, "=", false, &tmp);
|
||||
if (tmp.size() != 2) {
|
||||
SHERPA_ONNX_LOGE("Invalid custom string %s for %s", custom_string.string,
|
||||
f.c_str());
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
ans[std::move(tmp[0])] = std::move(tmp[1]);
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
23
sherpa-onnx/csrc/rknn/utils.h
Normal file
23
sherpa-onnx/csrc/rknn/utils.h
Normal file
@@ -0,0 +1,23 @@
|
||||
// sherpa-onnx/csrc/utils.h
|
||||
//
|
||||
// Copyright 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_RKNN_UTILS_H_
|
||||
#define SHERPA_ONNX_CSRC_RKNN_UTILS_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "rknn_api.h" // NOLINT
|
||||
|
||||
namespace sherpa_onnx {
|
||||
void ConvertNCHWtoNHWC(const float *src, int32_t n, int32_t channel,
|
||||
int32_t height, int32_t width, float *dst);
|
||||
|
||||
std::string ToString(const rknn_tensor_attr &attr);
|
||||
|
||||
std::unordered_map<std::string, std::string> Parse(
|
||||
const rknn_custom_string &custom_string);
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_RKNN_UTILS_H_
|
||||
@@ -83,6 +83,7 @@ for a list of pre-trained models to download.
|
||||
po.Read(argc, argv);
|
||||
if (po.NumArgs() < 1) {
|
||||
po.PrintUsage();
|
||||
fprintf(stderr, "Error! Please provide at lease 1 wav file\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user