Support RKNN for Zipformer CTC models. (#1948)
This commit is contained in:
@@ -155,7 +155,9 @@ if(SHERPA_ONNX_ENABLE_RKNN)
|
|||||||
list(APPEND sources
|
list(APPEND sources
|
||||||
./rknn/online-stream-rknn.cc
|
./rknn/online-stream-rknn.cc
|
||||||
./rknn/online-transducer-greedy-search-decoder-rknn.cc
|
./rknn/online-transducer-greedy-search-decoder-rknn.cc
|
||||||
|
./rknn/online-zipformer-ctc-model-rknn.cc
|
||||||
./rknn/online-zipformer-transducer-model-rknn.cc
|
./rknn/online-zipformer-transducer-model-rknn.cc
|
||||||
|
./rknn/utils.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
endif()
|
endif()
|
||||||
|
|||||||
@@ -43,12 +43,14 @@ class OnlineCtcDecoder {
|
|||||||
|
|
||||||
/** Run streaming CTC decoding given the output from the encoder model.
|
/** Run streaming CTC decoding given the output from the encoder model.
|
||||||
*
|
*
|
||||||
* @param log_probs A 3-D tensor of shape (N, T, vocab_size) containing
|
* @param log_probs A 3-D tensor of shape
|
||||||
* lob_probs.
|
* (batch_size, num_frames, vocab_size) containing
|
||||||
|
* lob_probs in row major.
|
||||||
*
|
*
|
||||||
* @param results Input & Output parameters..
|
* @param results Input & Output parameters..
|
||||||
*/
|
*/
|
||||||
virtual void Decode(Ort::Value log_probs,
|
virtual void Decode(const float *log_probs, int32_t batch_size,
|
||||||
|
int32_t num_frames, int32_t vocab_size,
|
||||||
std::vector<OnlineCtcDecoderResult> *results,
|
std::vector<OnlineCtcDecoderResult> *results,
|
||||||
OnlineStream **ss = nullptr, int32_t n = 0) = 0;
|
OnlineStream **ss = nullptr, int32_t n = 0) = 0;
|
||||||
|
|
||||||
|
|||||||
@@ -91,30 +91,23 @@ static void DecodeOne(const float *log_probs, int32_t num_rows,
|
|||||||
processed_frames += num_rows;
|
processed_frames += num_rows;
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnlineCtcFstDecoder::Decode(Ort::Value log_probs,
|
void OnlineCtcFstDecoder::Decode(const float *log_probs, int32_t batch_size,
|
||||||
|
int32_t num_frames, int32_t vocab_size,
|
||||||
std::vector<OnlineCtcDecoderResult> *results,
|
std::vector<OnlineCtcDecoderResult> *results,
|
||||||
OnlineStream **ss, int32_t n) {
|
OnlineStream **ss, int32_t n) {
|
||||||
std::vector<int64_t> log_probs_shape =
|
if (batch_size != results->size()) {
|
||||||
log_probs.GetTensorTypeAndShapeInfo().GetShape();
|
|
||||||
|
|
||||||
if (log_probs_shape[0] != results->size()) {
|
|
||||||
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d",
|
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d",
|
||||||
static_cast<int32_t>(log_probs_shape[0]),
|
batch_size, static_cast<int32_t>(results->size()));
|
||||||
static_cast<int32_t>(results->size()));
|
|
||||||
exit(-1);
|
exit(-1);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (log_probs_shape[0] != n) {
|
if (batch_size != n) {
|
||||||
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, n: %d",
|
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, n: %d", batch_size,
|
||||||
static_cast<int32_t>(log_probs_shape[0]), n);
|
n);
|
||||||
exit(-1);
|
exit(-1);
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t batch_size = static_cast<int32_t>(log_probs_shape[0]);
|
const float *p = log_probs;
|
||||||
int32_t num_frames = static_cast<int32_t>(log_probs_shape[1]);
|
|
||||||
int32_t vocab_size = static_cast<int32_t>(log_probs_shape[2]);
|
|
||||||
|
|
||||||
const float *p = log_probs.GetTensorData<float>();
|
|
||||||
|
|
||||||
for (int32_t i = 0; i != batch_size; ++i) {
|
for (int32_t i = 0; i != batch_size; ++i) {
|
||||||
DecodeOne(p + i * num_frames * vocab_size, num_frames, vocab_size,
|
DecodeOne(p + i * num_frames * vocab_size, num_frames, vocab_size,
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ class OnlineCtcFstDecoder : public OnlineCtcDecoder {
|
|||||||
OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config,
|
OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config,
|
||||||
int32_t blank_id);
|
int32_t blank_id);
|
||||||
|
|
||||||
void Decode(Ort::Value log_probs,
|
void Decode(const float *log_probs, int32_t batch_size, int32_t num_frames,
|
||||||
std::vector<OnlineCtcDecoderResult> *results,
|
int32_t vocab_size, std::vector<OnlineCtcDecoderResult> *results,
|
||||||
OnlineStream **ss = nullptr, int32_t n = 0) override;
|
OnlineStream **ss = nullptr, int32_t n = 0) override;
|
||||||
|
|
||||||
std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder()
|
std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder()
|
||||||
|
|||||||
@@ -13,23 +13,16 @@
|
|||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
void OnlineCtcGreedySearchDecoder::Decode(
|
void OnlineCtcGreedySearchDecoder::Decode(
|
||||||
Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results,
|
const float *log_probs, int32_t batch_size, int32_t num_frames,
|
||||||
|
int32_t vocab_size, std::vector<OnlineCtcDecoderResult> *results,
|
||||||
OnlineStream ** /*ss=nullptr*/, int32_t /*n = 0*/) {
|
OnlineStream ** /*ss=nullptr*/, int32_t /*n = 0*/) {
|
||||||
std::vector<int64_t> log_probs_shape =
|
if (batch_size != results->size()) {
|
||||||
log_probs.GetTensorTypeAndShapeInfo().GetShape();
|
|
||||||
|
|
||||||
if (log_probs_shape[0] != results->size()) {
|
|
||||||
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d",
|
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d",
|
||||||
static_cast<int32_t>(log_probs_shape[0]),
|
batch_size, static_cast<int32_t>(results->size()));
|
||||||
static_cast<int32_t>(results->size()));
|
|
||||||
exit(-1);
|
exit(-1);
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t batch_size = static_cast<int32_t>(log_probs_shape[0]);
|
const float *p = log_probs;
|
||||||
int32_t num_frames = static_cast<int32_t>(log_probs_shape[1]);
|
|
||||||
int32_t vocab_size = static_cast<int32_t>(log_probs_shape[2]);
|
|
||||||
|
|
||||||
const float *p = log_probs.GetTensorData<float>();
|
|
||||||
|
|
||||||
for (int32_t b = 0; b != batch_size; ++b) {
|
for (int32_t b = 0; b != batch_size; ++b) {
|
||||||
auto &r = (*results)[b];
|
auto &r = (*results)[b];
|
||||||
|
|||||||
@@ -16,8 +16,8 @@ class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder {
|
|||||||
explicit OnlineCtcGreedySearchDecoder(int32_t blank_id)
|
explicit OnlineCtcGreedySearchDecoder(int32_t blank_id)
|
||||||
: blank_id_(blank_id) {}
|
: blank_id_(blank_id) {}
|
||||||
|
|
||||||
void Decode(Ort::Value log_probs,
|
void Decode(const float *log_probs, int32_t batch_size, int32_t num_frames,
|
||||||
std::vector<OnlineCtcDecoderResult> *results,
|
int32_t vocab_size, std::vector<OnlineCtcDecoderResult> *results,
|
||||||
OnlineStream **ss = nullptr, int32_t n = 0) override;
|
OnlineStream **ss = nullptr, int32_t n = 0) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
@@ -76,6 +76,15 @@ bool OnlineModelConfig::Validate() const {
|
|||||||
transducer.decoder.c_str(), transducer.joiner.c_str());
|
transducer.decoder.c_str(), transducer.joiner.c_str());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!zipformer2_ctc.model.empty() &&
|
||||||
|
EndsWith(zipformer2_ctc.model, ".rknn")) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"--provider is %s, which is not rknn, but you pass rknn model "
|
||||||
|
"filename for zipformer2_ctc: '%s'",
|
||||||
|
provider_config.provider.c_str(), zipformer2_ctc.model.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (provider_config.provider == "rknn") {
|
if (provider_config.provider == "rknn") {
|
||||||
@@ -89,6 +98,15 @@ bool OnlineModelConfig::Validate() const {
|
|||||||
transducer.joiner.c_str());
|
transducer.joiner.c_str());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!zipformer2_ctc.model.empty() &&
|
||||||
|
EndsWith(zipformer2_ctc.model, ".onnx")) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"--provider rknn, but you pass onnx model filename for "
|
||||||
|
"zipformer2_ctc: '%s'",
|
||||||
|
zipformer2_ctc.model.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!tokens_buf.empty() && FileExists(tokens)) {
|
if (!tokens_buf.empty() && FileExists(tokens)) {
|
||||||
|
|||||||
@@ -24,12 +24,11 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
|
OnlineRecognizerResult ConvertCtc(const OnlineCtcDecoderResult &src,
|
||||||
const SymbolTable &sym_table,
|
const SymbolTable &sym_table,
|
||||||
float frame_shift_ms,
|
float frame_shift_ms,
|
||||||
int32_t subsampling_factor,
|
int32_t subsampling_factor, int32_t segment,
|
||||||
int32_t segment,
|
int32_t frames_since_start) {
|
||||||
int32_t frames_since_start) {
|
|
||||||
OnlineRecognizerResult r;
|
OnlineRecognizerResult r;
|
||||||
r.tokens.reserve(src.tokens.size());
|
r.tokens.reserve(src.tokens.size());
|
||||||
r.timestamps.reserve(src.tokens.size());
|
r.timestamps.reserve(src.tokens.size());
|
||||||
@@ -182,7 +181,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
|||||||
std::vector<std::vector<Ort::Value>> next_states =
|
std::vector<std::vector<Ort::Value>> next_states =
|
||||||
model_->UnStackStates(std::move(out_states));
|
model_->UnStackStates(std::move(out_states));
|
||||||
|
|
||||||
decoder_->Decode(std::move(out[0]), &results, ss, n);
|
std::vector<int64_t> log_probs_shape =
|
||||||
|
out[0].GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
decoder_->Decode(out[0].GetTensorData<float>(), log_probs_shape[0],
|
||||||
|
log_probs_shape[1], log_probs_shape[2], &results, ss, n);
|
||||||
|
|
||||||
for (int32_t k = 0; k != n; ++k) {
|
for (int32_t k = 0; k != n; ++k) {
|
||||||
ss[k]->SetCtcResult(results[k]);
|
ss[k]->SetCtcResult(results[k]);
|
||||||
@@ -196,8 +198,9 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
|||||||
// TODO(fangjun): Remember to change these constants if needed
|
// TODO(fangjun): Remember to change these constants if needed
|
||||||
int32_t frame_shift_ms = 10;
|
int32_t frame_shift_ms = 10;
|
||||||
int32_t subsampling_factor = 4;
|
int32_t subsampling_factor = 4;
|
||||||
auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
auto r =
|
||||||
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||||
|
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
||||||
r.text = ApplyInverseTextNormalization(r.text);
|
r.text = ApplyInverseTextNormalization(r.text);
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
@@ -306,7 +309,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
|||||||
std::vector<OnlineCtcDecoderResult> results(1);
|
std::vector<OnlineCtcDecoderResult> results(1);
|
||||||
results[0] = std::move(s->GetCtcResult());
|
results[0] = std::move(s->GetCtcResult());
|
||||||
|
|
||||||
decoder_->Decode(std::move(out[0]), &results, &s, 1);
|
std::vector<int64_t> log_probs_shape =
|
||||||
|
out[0].GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
decoder_->Decode(out[0].GetTensorData<float>(), log_probs_shape[0],
|
||||||
|
log_probs_shape[1], log_probs_shape[2], &results, &s, 1);
|
||||||
s->SetCtcResult(results[0]);
|
s->SetCtcResult(results[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,7 @@
|
|||||||
#include "sherpa-onnx/csrc/text-utils.h"
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
|
|
||||||
#if SHERPA_ONNX_ENABLE_RKNN
|
#if SHERPA_ONNX_ENABLE_RKNN
|
||||||
|
#include "sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h"
|
||||||
#include "sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h"
|
#include "sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@@ -37,12 +38,15 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
|
|||||||
if (config.model_config.provider_config.provider == "rknn") {
|
if (config.model_config.provider_config.provider == "rknn") {
|
||||||
#if SHERPA_ONNX_ENABLE_RKNN
|
#if SHERPA_ONNX_ENABLE_RKNN
|
||||||
// Currently, only zipformer v1 is suported for rknn
|
// Currently, only zipformer v1 is suported for rknn
|
||||||
if (config.model_config.transducer.encoder.empty()) {
|
if (config.model_config.transducer.encoder.empty() &&
|
||||||
|
config.model_config.zipformer2_ctc.model.empty()) {
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"Only Zipformer transducers are currently supported by rknn. "
|
"Only Zipformer transducers and CTC models are currently supported "
|
||||||
"Fallback to CPU");
|
"by rknn. Fallback to CPU");
|
||||||
} else {
|
} else if (!config.model_config.transducer.encoder.empty()) {
|
||||||
return std::make_unique<OnlineRecognizerTransducerRknnImpl>(config);
|
return std::make_unique<OnlineRecognizerTransducerRknnImpl>(config);
|
||||||
|
} else if (!config.model_config.zipformer2_ctc.model.empty()) {
|
||||||
|
return std::make_unique<OnlineRecognizerCtcRknnImpl>(config);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
|
|||||||
204
sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h
Normal file
204
sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
// sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <ios>
|
||||||
|
#include <memory>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
|
||||||
|
#include "sherpa-onnx/csrc/rknn/online-stream-rknn.h"
|
||||||
|
#include "sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h"
|
||||||
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// defined in ../online-recognizer-ctc-impl.h
|
||||||
|
OnlineRecognizerResult ConvertCtc(const OnlineCtcDecoderResult &src,
|
||||||
|
const SymbolTable &sym_table,
|
||||||
|
float frame_shift_ms,
|
||||||
|
int32_t subsampling_factor, int32_t segment,
|
||||||
|
int32_t frames_since_start);
|
||||||
|
|
||||||
|
class OnlineRecognizerCtcRknnImpl : public OnlineRecognizerImpl {
|
||||||
|
public:
|
||||||
|
explicit OnlineRecognizerCtcRknnImpl(const OnlineRecognizerConfig &config)
|
||||||
|
: OnlineRecognizerImpl(config),
|
||||||
|
config_(config),
|
||||||
|
model_(
|
||||||
|
std::make_unique<OnlineZipformerCtcModelRknn>(config.model_config)),
|
||||||
|
endpoint_(config_.endpoint_config) {
|
||||||
|
if (!config.model_config.tokens_buf.empty()) {
|
||||||
|
sym_ = SymbolTable(config.model_config.tokens_buf, false);
|
||||||
|
} else {
|
||||||
|
/// assuming tokens_buf and tokens are guaranteed not being both empty
|
||||||
|
sym_ = SymbolTable(config.model_config.tokens, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
InitDecoder();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Manager>
|
||||||
|
explicit OnlineRecognizerCtcRknnImpl(Manager *mgr,
|
||||||
|
const OnlineRecognizerConfig &config)
|
||||||
|
: OnlineRecognizerImpl(mgr, config),
|
||||||
|
config_(config),
|
||||||
|
model_(
|
||||||
|
std::make_unique<OnlineZipformerCtcModelRknn>(config.model_config)),
|
||||||
|
sym_(mgr, config.model_config.tokens),
|
||||||
|
endpoint_(config_.endpoint_config) {
|
||||||
|
InitDecoder();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OnlineStream> CreateStream() const override {
|
||||||
|
auto stream = std::make_unique<OnlineStreamRknn>(config_.feat_config);
|
||||||
|
stream->SetZipformerEncoderStates(model_->GetInitStates());
|
||||||
|
stream->SetFasterDecoder(decoder_->CreateFasterDecoder());
|
||||||
|
return stream;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsReady(OnlineStream *s) const override {
|
||||||
|
return s->GetNumProcessedFrames() + model_->ChunkSize() <
|
||||||
|
s->NumFramesReady();
|
||||||
|
}
|
||||||
|
|
||||||
|
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
DecodeStream(reinterpret_cast<OnlineStreamRknn *>(ss[i]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OnlineRecognizerResult GetResult(OnlineStream *s) const override {
|
||||||
|
OnlineCtcDecoderResult decoder_result = s->GetCtcResult();
|
||||||
|
|
||||||
|
// TODO(fangjun): Remember to change these constants if needed
|
||||||
|
int32_t frame_shift_ms = 10;
|
||||||
|
int32_t subsampling_factor = 4;
|
||||||
|
auto r =
|
||||||
|
ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||||
|
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
||||||
|
r.text = ApplyInverseTextNormalization(r.text);
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsEndpoint(OnlineStream *s) const override {
|
||||||
|
if (!config_.enable_endpoint) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t num_processed_frames = s->GetNumProcessedFrames();
|
||||||
|
|
||||||
|
// frame shift is 10 milliseconds
|
||||||
|
float frame_shift_in_seconds = 0.01;
|
||||||
|
|
||||||
|
// subsampling factor is 4
|
||||||
|
int32_t trailing_silence_frames = s->GetCtcResult().num_trailing_blanks * 4;
|
||||||
|
|
||||||
|
return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
|
||||||
|
frame_shift_in_seconds);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Reset(OnlineStream *s) const override {
|
||||||
|
// segment is incremented only when the last
|
||||||
|
// result is not empty
|
||||||
|
const auto &r = s->GetCtcResult();
|
||||||
|
if (!r.tokens.empty()) {
|
||||||
|
s->GetCurrentSegment() += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear result
|
||||||
|
s->SetCtcResult({});
|
||||||
|
|
||||||
|
// clear states
|
||||||
|
reinterpret_cast<OnlineStreamRknn *>(s)->SetZipformerEncoderStates(
|
||||||
|
model_->GetInitStates());
|
||||||
|
|
||||||
|
s->GetFasterDecoderProcessedFrames() = 0;
|
||||||
|
|
||||||
|
// Note: We only update counters. The underlying audio samples
|
||||||
|
// are not discarded.
|
||||||
|
s->Reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void InitDecoder() {
|
||||||
|
if (!sym_.Contains("<blk>") && !sym_.Contains("<eps>") &&
|
||||||
|
!sym_.Contains("<blank>")) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"We expect that tokens.txt contains "
|
||||||
|
"the symbol <blk> or <eps> or <blank> and its ID.");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t blank_id = 0;
|
||||||
|
if (sym_.Contains("<blk>")) {
|
||||||
|
blank_id = sym_["<blk>"];
|
||||||
|
} else if (sym_.Contains("<eps>")) {
|
||||||
|
// for tdnn models of the yesno recipe from icefall
|
||||||
|
blank_id = sym_["<eps>"];
|
||||||
|
} else if (sym_.Contains("<blank>")) {
|
||||||
|
// for WeNet CTC models
|
||||||
|
blank_id = sym_["<blank>"];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!config_.ctc_fst_decoder_config.graph.empty()) {
|
||||||
|
decoder_ = std::make_unique<OnlineCtcFstDecoder>(
|
||||||
|
config_.ctc_fst_decoder_config, blank_id);
|
||||||
|
} else if (config_.decoding_method == "greedy_search") {
|
||||||
|
decoder_ = std::make_unique<OnlineCtcGreedySearchDecoder>(blank_id);
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Unsupported decoding method: %s for streaming CTC models",
|
||||||
|
config_.decoding_method.c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void DecodeStream(OnlineStreamRknn *s) const {
|
||||||
|
int32_t chunk_size = model_->ChunkSize();
|
||||||
|
int32_t chunk_shift = model_->ChunkShift();
|
||||||
|
|
||||||
|
int32_t feat_dim = s->FeatureDim();
|
||||||
|
|
||||||
|
const auto num_processed_frames = s->GetNumProcessedFrames();
|
||||||
|
std::vector<float> features =
|
||||||
|
s->GetFrames(num_processed_frames, chunk_size);
|
||||||
|
s->GetNumProcessedFrames() += chunk_shift;
|
||||||
|
|
||||||
|
auto &states = s->GetZipformerEncoderStates();
|
||||||
|
auto p = model_->Run(features, std::move(states));
|
||||||
|
states = std::move(p.second);
|
||||||
|
|
||||||
|
std::vector<OnlineCtcDecoderResult> results(1);
|
||||||
|
results[0] = std::move(s->GetCtcResult());
|
||||||
|
|
||||||
|
auto attr = model_->GetOutAttr();
|
||||||
|
|
||||||
|
decoder_->Decode(p.first.data(), attr.dims[0], attr.dims[1], attr.dims[2],
|
||||||
|
&results, reinterpret_cast<OnlineStream **>(&s), 1);
|
||||||
|
s->SetCtcResult(results[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
OnlineRecognizerConfig config_;
|
||||||
|
std::unique_ptr<OnlineZipformerCtcModelRknn> model_;
|
||||||
|
std::unique_ptr<OnlineCtcDecoder> decoder_;
|
||||||
|
SymbolTable sym_;
|
||||||
|
Endpoint endpoint_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_
|
||||||
390
sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc
Normal file
390
sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc
Normal file
@@ -0,0 +1,390 @@
|
|||||||
|
// sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if __OHOS__
|
||||||
|
#include "rawfile/raw_file_manager.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/rknn/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/rknn/utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OnlineZipformerCtcModelRknn::Impl {
|
||||||
|
public:
|
||||||
|
~Impl() {
|
||||||
|
auto ret = rknn_destroy(ctx_);
|
||||||
|
if (ret != RKNN_SUCC) {
|
||||||
|
SHERPA_ONNX_LOGE("Failed to destroy the context");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
explicit Impl(const OnlineModelConfig &config) : config_(config) {
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(config.zipformer2_ctc.model);
|
||||||
|
Init(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t ret = RKNN_SUCC;
|
||||||
|
switch (config_.num_threads) {
|
||||||
|
case 1:
|
||||||
|
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_AUTO);
|
||||||
|
break;
|
||||||
|
case 0:
|
||||||
|
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0);
|
||||||
|
break;
|
||||||
|
case -1:
|
||||||
|
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_1);
|
||||||
|
break;
|
||||||
|
case -2:
|
||||||
|
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_2);
|
||||||
|
break;
|
||||||
|
case -3:
|
||||||
|
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1);
|
||||||
|
break;
|
||||||
|
case -4:
|
||||||
|
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1_2);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core "
|
||||||
|
"1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d",
|
||||||
|
config_.num_threads);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (ret != RKNN_SUCC) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Failed to select npu core to run the model (You can ignore it if "
|
||||||
|
"you "
|
||||||
|
"are not using RK3588.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(fangjun): Support Android
|
||||||
|
|
||||||
|
std::vector<std::vector<uint8_t>> GetInitStates() const {
|
||||||
|
// input_attrs_[0] is for the feature
|
||||||
|
// input_attrs_[1:] is for states
|
||||||
|
// so we use -1 here
|
||||||
|
std::vector<std::vector<uint8_t>> states(input_attrs_.size() - 1);
|
||||||
|
|
||||||
|
int32_t i = -1;
|
||||||
|
for (auto &attr : input_attrs_) {
|
||||||
|
i += 1;
|
||||||
|
if (i == 0) {
|
||||||
|
// skip processing the attr for features.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (attr.type == RKNN_TENSOR_FLOAT16) {
|
||||||
|
states[i - 1].resize(attr.n_elems * sizeof(float));
|
||||||
|
} else if (attr.type == RKNN_TENSOR_INT64) {
|
||||||
|
states[i - 1].resize(attr.n_elems * sizeof(int64_t));
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Unsupported tensor type: %d, %s", attr.type,
|
||||||
|
get_type_string(attr.type));
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return states;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> Run(
|
||||||
|
std::vector<float> features,
|
||||||
|
std::vector<std::vector<uint8_t>> states) const {
|
||||||
|
std::vector<rknn_input> inputs(input_attrs_.size());
|
||||||
|
|
||||||
|
for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) {
|
||||||
|
auto &input = inputs[i];
|
||||||
|
auto &attr = input_attrs_[i];
|
||||||
|
input.index = attr.index;
|
||||||
|
|
||||||
|
if (attr.type == RKNN_TENSOR_FLOAT16) {
|
||||||
|
input.type = RKNN_TENSOR_FLOAT32;
|
||||||
|
} else if (attr.type == RKNN_TENSOR_INT64) {
|
||||||
|
input.type = RKNN_TENSOR_INT64;
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type,
|
||||||
|
get_type_string(attr.type));
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
input.fmt = attr.fmt;
|
||||||
|
if (i == 0) {
|
||||||
|
input.buf = reinterpret_cast<void *>(features.data());
|
||||||
|
input.size = features.size() * sizeof(float);
|
||||||
|
} else {
|
||||||
|
input.buf = reinterpret_cast<void *>(states[i - 1].data());
|
||||||
|
input.size = states[i - 1].size();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> out(output_attrs_[0].n_elems);
|
||||||
|
|
||||||
|
// Note(fangjun): We can reuse the memory from input argument `states`
|
||||||
|
// auto next_states = GetInitStates();
|
||||||
|
auto &next_states = states;
|
||||||
|
|
||||||
|
std::vector<rknn_output> outputs(output_attrs_.size());
|
||||||
|
for (int32_t i = 0; i < outputs.size(); ++i) {
|
||||||
|
auto &output = outputs[i];
|
||||||
|
auto &attr = output_attrs_[i];
|
||||||
|
output.index = attr.index;
|
||||||
|
output.is_prealloc = 1;
|
||||||
|
|
||||||
|
if (attr.type == RKNN_TENSOR_FLOAT16) {
|
||||||
|
output.want_float = 1;
|
||||||
|
} else if (attr.type == RKNN_TENSOR_INT64) {
|
||||||
|
output.want_float = 0;
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type,
|
||||||
|
get_type_string(attr.type));
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (i == 0) {
|
||||||
|
output.size = out.size() * sizeof(float);
|
||||||
|
output.buf = reinterpret_cast<void *>(out.data());
|
||||||
|
} else {
|
||||||
|
output.size = next_states[i - 1].size();
|
||||||
|
output.buf = reinterpret_cast<void *>(next_states[i - 1].data());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ret = rknn_inputs_set(ctx_, inputs.size(), inputs.data());
|
||||||
|
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs");
|
||||||
|
|
||||||
|
ret = rknn_run(ctx_, nullptr);
|
||||||
|
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model");
|
||||||
|
|
||||||
|
ret = rknn_outputs_get(ctx_, outputs.size(), outputs.data(), nullptr);
|
||||||
|
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output");
|
||||||
|
|
||||||
|
for (int32_t i = 0; i < next_states.size(); ++i) {
|
||||||
|
const auto &attr = input_attrs_[i + 1];
|
||||||
|
if (attr.n_dims == 4) {
|
||||||
|
// TODO(fangjun): The transpose is copied from
|
||||||
|
// https://github.com/airockchip/rknn_model_zoo/blob/main/examples/zipformer/cpp/process.cc#L22
|
||||||
|
// I don't understand why we need to do that.
|
||||||
|
std::vector<uint8_t> dst(next_states[i].size());
|
||||||
|
int32_t n = attr.dims[0];
|
||||||
|
int32_t h = attr.dims[1];
|
||||||
|
int32_t w = attr.dims[2];
|
||||||
|
int32_t c = attr.dims[3];
|
||||||
|
ConvertNCHWtoNHWC(
|
||||||
|
reinterpret_cast<const float *>(next_states[i].data()), n, c, h, w,
|
||||||
|
reinterpret_cast<float *>(dst.data()));
|
||||||
|
next_states[i] = std::move(dst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {std::move(out), std::move(next_states)};
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t ChunkSize() const { return T_; }
|
||||||
|
|
||||||
|
int32_t ChunkShift() const { return decode_chunk_len_; }
|
||||||
|
|
||||||
|
int32_t VocabSize() const { return vocab_size_; }
|
||||||
|
|
||||||
|
rknn_tensor_attr GetOutAttr() const { return output_attrs_[0]; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
void Init(void *model_data, size_t model_data_length) {
|
||||||
|
auto ret = rknn_init(&ctx_, model_data, model_data_length, 0, nullptr);
|
||||||
|
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init model '%s'",
|
||||||
|
config_.zipformer2_ctc.model.c_str());
|
||||||
|
|
||||||
|
if (config_.debug) {
|
||||||
|
rknn_sdk_version v;
|
||||||
|
ret = rknn_query(ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));
|
||||||
|
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");
|
||||||
|
|
||||||
|
SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version,
|
||||||
|
v.drv_version);
|
||||||
|
}
|
||||||
|
|
||||||
|
rknn_input_output_num io_num;
|
||||||
|
ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
|
||||||
|
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model");
|
||||||
|
|
||||||
|
if (config_.debug) {
|
||||||
|
SHERPA_ONNX_LOGE("model: %d inputs, %d outputs",
|
||||||
|
static_cast<int32_t>(io_num.n_input),
|
||||||
|
static_cast<int32_t>(io_num.n_output));
|
||||||
|
}
|
||||||
|
|
||||||
|
input_attrs_.resize(io_num.n_input);
|
||||||
|
output_attrs_.resize(io_num.n_output);
|
||||||
|
|
||||||
|
int32_t i = 0;
|
||||||
|
for (auto &attr : input_attrs_) {
|
||||||
|
memset(&attr, 0, sizeof(attr));
|
||||||
|
attr.index = i;
|
||||||
|
ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
|
||||||
|
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (config_.debug) {
|
||||||
|
std::ostringstream os;
|
||||||
|
std::string sep;
|
||||||
|
for (auto &attr : input_attrs_) {
|
||||||
|
os << sep << ToString(attr);
|
||||||
|
sep = "\n";
|
||||||
|
}
|
||||||
|
SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s",
|
||||||
|
os.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
i = 0;
|
||||||
|
for (auto &attr : output_attrs_) {
|
||||||
|
memset(&attr, 0, sizeof(attr));
|
||||||
|
attr.index = i;
|
||||||
|
ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
|
||||||
|
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (config_.debug) {
|
||||||
|
std::ostringstream os;
|
||||||
|
std::string sep;
|
||||||
|
for (auto &attr : output_attrs_) {
|
||||||
|
os << sep << ToString(attr);
|
||||||
|
sep = "\n";
|
||||||
|
}
|
||||||
|
SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s",
|
||||||
|
os.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
rknn_custom_string custom_string;
|
||||||
|
ret = rknn_query(ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string,
|
||||||
|
sizeof(custom_string));
|
||||||
|
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model");
|
||||||
|
if (config_.debug) {
|
||||||
|
SHERPA_ONNX_LOGE("customs string: %s", custom_string.string);
|
||||||
|
}
|
||||||
|
auto meta = Parse(custom_string);
|
||||||
|
|
||||||
|
if (config_.debug) {
|
||||||
|
for (const auto &p : meta) {
|
||||||
|
SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (meta.count("T")) {
|
||||||
|
T_ = atoi(meta.at("T").c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (meta.count("decode_chunk_len")) {
|
||||||
|
decode_chunk_len_ = atoi(meta.at("decode_chunk_len").c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
vocab_size_ = output_attrs_[0].dims[2];
|
||||||
|
|
||||||
|
if (config_.debug) {
|
||||||
|
#if __OHOS__
|
||||||
|
SHERPA_ONNX_LOGE("T: %{public}d", T_);
|
||||||
|
SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_);
|
||||||
|
SHERPA_ONNX_LOGE("vocab_size: %{public}d", vocab_size);
|
||||||
|
#else
|
||||||
|
SHERPA_ONNX_LOGE("T: %d", T_);
|
||||||
|
SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_);
|
||||||
|
SHERPA_ONNX_LOGE("vocab_size: %d", vocab_size_);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
if (T_ == 0) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Invalid T. Please use the script from icefall to export your model");
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (decode_chunk_len_ == 0) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Invalid decode_chunk_len. Please use the script from icefall to "
|
||||||
|
"export your model");
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
OnlineModelConfig config_;
|
||||||
|
rknn_context ctx_ = 0;
|
||||||
|
|
||||||
|
std::vector<rknn_tensor_attr> input_attrs_;
|
||||||
|
std::vector<rknn_tensor_attr> output_attrs_;
|
||||||
|
|
||||||
|
int32_t T_ = 0;
|
||||||
|
int32_t decode_chunk_len_ = 0;
|
||||||
|
int32_t vocab_size_ = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
OnlineZipformerCtcModelRknn::~OnlineZipformerCtcModelRknn() = default;
|
||||||
|
|
||||||
|
OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn(
|
||||||
|
const OnlineModelConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
|
||||||
|
template <typename Manager>
|
||||||
|
OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn(
|
||||||
|
Manager *mgr, const OnlineModelConfig &config)
|
||||||
|
: impl_(std::make_unique<OnlineZipformerCtcModelRknn>(mgr, config)) {}
|
||||||
|
|
||||||
|
std::vector<std::vector<uint8_t>> OnlineZipformerCtcModelRknn::GetInitStates()
|
||||||
|
const {
|
||||||
|
return impl_->GetInitStates();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>>
|
||||||
|
OnlineZipformerCtcModelRknn::Run(
|
||||||
|
std::vector<float> features,
|
||||||
|
std::vector<std::vector<uint8_t>> states) const {
|
||||||
|
return impl_->Run(std::move(features), std::move(states));
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t OnlineZipformerCtcModelRknn::ChunkSize() const {
|
||||||
|
return impl_->ChunkSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t OnlineZipformerCtcModelRknn::ChunkShift() const {
|
||||||
|
return impl_->ChunkShift();
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t OnlineZipformerCtcModelRknn::VocabSize() const {
|
||||||
|
return impl_->VocabSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
rknn_tensor_attr OnlineZipformerCtcModelRknn::GetOutAttr() const {
|
||||||
|
return impl_->GetOutAttr();
|
||||||
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
template OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn(
|
||||||
|
AAssetManager *mgr, const OnlineModelConfig &config);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if __OHOS__
|
||||||
|
template OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn(
|
||||||
|
NativeResourceManager *mgr, const OnlineModelConfig &config);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
46
sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h
Normal file
46
sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
// sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "rknn_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OnlineZipformerCtcModelRknn {
|
||||||
|
public:
|
||||||
|
~OnlineZipformerCtcModelRknn();
|
||||||
|
|
||||||
|
explicit OnlineZipformerCtcModelRknn(const OnlineModelConfig &config);
|
||||||
|
|
||||||
|
template <typename Manager>
|
||||||
|
OnlineZipformerCtcModelRknn(Manager *mgr, const OnlineModelConfig &config);
|
||||||
|
|
||||||
|
std::vector<std::vector<uint8_t>> GetInitStates() const;
|
||||||
|
|
||||||
|
std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> Run(
|
||||||
|
std::vector<float> features,
|
||||||
|
std::vector<std::vector<uint8_t>> states) const;
|
||||||
|
|
||||||
|
int32_t ChunkSize() const;
|
||||||
|
|
||||||
|
int32_t ChunkShift() const;
|
||||||
|
|
||||||
|
int32_t VocabSize() const;
|
||||||
|
|
||||||
|
rknn_tensor_attr GetOutAttr() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Impl;
|
||||||
|
std::unique_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
// sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc
|
// sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc
|
||||||
//
|
//
|
||||||
// Copyright (c) 2023 Xiaomi Corporation
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h"
|
#include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h"
|
||||||
|
|
||||||
@@ -22,68 +22,11 @@
|
|||||||
|
|
||||||
#include "sherpa-onnx/csrc/file-utils.h"
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
#include "sherpa-onnx/csrc/rknn/macros.h"
|
#include "sherpa-onnx/csrc/rknn/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/rknn/utils.h"
|
||||||
#include "sherpa-onnx/csrc/text-utils.h"
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
// chw -> hwc
|
|
||||||
static void Transpose(const float *src, int32_t n, int32_t channel,
|
|
||||||
int32_t height, int32_t width, float *dst) {
|
|
||||||
for (int32_t i = 0; i < n; ++i) {
|
|
||||||
for (int32_t h = 0; h < height; ++h) {
|
|
||||||
for (int32_t w = 0; w < width; ++w) {
|
|
||||||
for (int32_t c = 0; c < channel; ++c) {
|
|
||||||
// dst[h, w, c] = src[c, h, w]
|
|
||||||
dst[i * height * width * channel + h * width * channel + w * channel +
|
|
||||||
c] = src[i * height * width * channel + c * height * width +
|
|
||||||
h * width + w];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::string ToString(const rknn_tensor_attr &attr) {
|
|
||||||
std::ostringstream os;
|
|
||||||
os << "{";
|
|
||||||
os << attr.index;
|
|
||||||
os << ", name: " << attr.name;
|
|
||||||
os << ", shape: (";
|
|
||||||
std::string sep;
|
|
||||||
for (int32_t i = 0; i < static_cast<int32_t>(attr.n_dims); ++i) {
|
|
||||||
os << sep << attr.dims[i];
|
|
||||||
sep = ",";
|
|
||||||
}
|
|
||||||
os << ")";
|
|
||||||
os << ", n_elems: " << attr.n_elems;
|
|
||||||
os << ", size: " << attr.size;
|
|
||||||
os << ", fmt: " << get_format_string(attr.fmt);
|
|
||||||
os << ", type: " << get_type_string(attr.type);
|
|
||||||
os << ", pass_through: " << (attr.pass_through ? "true" : "false");
|
|
||||||
os << "}";
|
|
||||||
return os.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::unordered_map<std::string, std::string> Parse(
|
|
||||||
const rknn_custom_string &custom_string) {
|
|
||||||
std::unordered_map<std::string, std::string> ans;
|
|
||||||
std::vector<std::string> fields;
|
|
||||||
SplitStringToVector(custom_string.string, ";", false, &fields);
|
|
||||||
|
|
||||||
std::vector<std::string> tmp;
|
|
||||||
for (const auto &f : fields) {
|
|
||||||
SplitStringToVector(f, "=", false, &tmp);
|
|
||||||
if (tmp.size() != 2) {
|
|
||||||
SHERPA_ONNX_LOGE("Invalid custom string %s for %s", custom_string.string,
|
|
||||||
f.c_str());
|
|
||||||
SHERPA_ONNX_EXIT(-1);
|
|
||||||
}
|
|
||||||
ans[std::move(tmp[0])] = std::move(tmp[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
return ans;
|
|
||||||
}
|
|
||||||
|
|
||||||
class OnlineZipformerTransducerModelRknn::Impl {
|
class OnlineZipformerTransducerModelRknn::Impl {
|
||||||
public:
|
public:
|
||||||
~Impl() {
|
~Impl() {
|
||||||
@@ -285,7 +228,7 @@ class OnlineZipformerTransducerModelRknn::Impl {
|
|||||||
for (int32_t i = 0; i < next_states.size(); ++i) {
|
for (int32_t i = 0; i < next_states.size(); ++i) {
|
||||||
const auto &attr = encoder_input_attrs_[i + 1];
|
const auto &attr = encoder_input_attrs_[i + 1];
|
||||||
if (attr.n_dims == 4) {
|
if (attr.n_dims == 4) {
|
||||||
// TODO(fangjun): The transpose is copied from
|
// TODO(fangjun): The ConvertNCHWtoNHWC is copied from
|
||||||
// https://github.com/airockchip/rknn_model_zoo/blob/main/examples/zipformer/cpp/process.cc#L22
|
// https://github.com/airockchip/rknn_model_zoo/blob/main/examples/zipformer/cpp/process.cc#L22
|
||||||
// I don't understand why we need to do that.
|
// I don't understand why we need to do that.
|
||||||
std::vector<uint8_t> dst(next_states[i].size());
|
std::vector<uint8_t> dst(next_states[i].size());
|
||||||
@@ -293,8 +236,9 @@ class OnlineZipformerTransducerModelRknn::Impl {
|
|||||||
int32_t h = attr.dims[1];
|
int32_t h = attr.dims[1];
|
||||||
int32_t w = attr.dims[2];
|
int32_t w = attr.dims[2];
|
||||||
int32_t c = attr.dims[3];
|
int32_t c = attr.dims[3];
|
||||||
Transpose(reinterpret_cast<const float *>(next_states[i].data()), n, c,
|
ConvertNCHWtoNHWC(
|
||||||
h, w, reinterpret_cast<float *>(dst.data()));
|
reinterpret_cast<const float *>(next_states[i].data()), n, c, h, w,
|
||||||
|
reinterpret_cast<float *>(dst.data()));
|
||||||
next_states[i] = std::move(dst);
|
next_states[i] = std::move(dst);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -527,11 +471,9 @@ class OnlineZipformerTransducerModelRknn::Impl {
|
|||||||
#if __OHOS__
|
#if __OHOS__
|
||||||
SHERPA_ONNX_LOGE("T: %{public}d", T_);
|
SHERPA_ONNX_LOGE("T: %{public}d", T_);
|
||||||
SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_);
|
SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_);
|
||||||
SHERPA_ONNX_LOGE("context_size: %{public}d", context_size_);
|
|
||||||
#else
|
#else
|
||||||
SHERPA_ONNX_LOGE("T: %d", T_);
|
SHERPA_ONNX_LOGE("T: %d", T_);
|
||||||
SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_);
|
SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_);
|
||||||
SHERPA_ONNX_LOGE("context_size: %d", context_size_);
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -597,6 +539,11 @@ class OnlineZipformerTransducerModelRknn::Impl {
|
|||||||
SHERPA_ONNX_EXIT(-1);
|
SHERPA_ONNX_EXIT(-1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
context_size_ = decoder_input_attrs_[0].dims[1];
|
||||||
|
if (config_.debug) {
|
||||||
|
SHERPA_ONNX_LOGE("context_size: %d", context_size_);
|
||||||
|
}
|
||||||
|
|
||||||
i = 0;
|
i = 0;
|
||||||
for (auto &attr : decoder_output_attrs_) {
|
for (auto &attr : decoder_output_attrs_) {
|
||||||
memset(&attr, 0, sizeof(attr));
|
memset(&attr, 0, sizeof(attr));
|
||||||
|
|||||||
@@ -14,8 +14,11 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
// this is for zipformer v1, i.e., the folder
|
// this is for zipformer v1 and v2, i.e., the folder
|
||||||
// pruned_transducer_statelss7_streaming from icefall
|
// pruned_transducer_statelss7_streaming
|
||||||
|
// and
|
||||||
|
// zipformer
|
||||||
|
// from icefall
|
||||||
class OnlineZipformerTransducerModelRknn {
|
class OnlineZipformerTransducerModelRknn {
|
||||||
public:
|
public:
|
||||||
~OnlineZipformerTransducerModelRknn();
|
~OnlineZipformerTransducerModelRknn();
|
||||||
|
|||||||
73
sherpa-onnx/csrc/rknn/utils.cc
Normal file
73
sherpa-onnx/csrc/rknn/utils.cc
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
// sherpa-onnx/csrc/utils.cc
|
||||||
|
//
|
||||||
|
// Copyright 2025 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/rknn/utils.h"
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void ConvertNCHWtoNHWC(const float *src, int32_t n, int32_t channel,
|
||||||
|
int32_t height, int32_t width, float *dst) {
|
||||||
|
for (int32_t i = 0; i < n; ++i) {
|
||||||
|
for (int32_t h = 0; h < height; ++h) {
|
||||||
|
for (int32_t w = 0; w < width; ++w) {
|
||||||
|
for (int32_t c = 0; c < channel; ++c) {
|
||||||
|
// dst[h, w, c] = src[c, h, w]
|
||||||
|
dst[i * height * width * channel + h * width * channel + w * channel +
|
||||||
|
c] = src[i * height * width * channel + c * height * width +
|
||||||
|
h * width + w];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string ToString(const rknn_tensor_attr &attr) {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "{";
|
||||||
|
os << attr.index;
|
||||||
|
os << ", name: " << attr.name;
|
||||||
|
os << ", shape: (";
|
||||||
|
std::string sep;
|
||||||
|
for (int32_t i = 0; i < static_cast<int32_t>(attr.n_dims); ++i) {
|
||||||
|
os << sep << attr.dims[i];
|
||||||
|
sep = ",";
|
||||||
|
}
|
||||||
|
os << ")";
|
||||||
|
os << ", n_elems: " << attr.n_elems;
|
||||||
|
os << ", size: " << attr.size;
|
||||||
|
os << ", fmt: " << get_format_string(attr.fmt);
|
||||||
|
os << ", type: " << get_type_string(attr.type);
|
||||||
|
os << ", pass_through: " << (attr.pass_through ? "true" : "false");
|
||||||
|
os << "}";
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unordered_map<std::string, std::string> Parse(
|
||||||
|
const rknn_custom_string &custom_string) {
|
||||||
|
std::unordered_map<std::string, std::string> ans;
|
||||||
|
std::vector<std::string> fields;
|
||||||
|
SplitStringToVector(custom_string.string, ";", false, &fields);
|
||||||
|
|
||||||
|
std::vector<std::string> tmp;
|
||||||
|
for (const auto &f : fields) {
|
||||||
|
SplitStringToVector(f, "=", false, &tmp);
|
||||||
|
if (tmp.size() != 2) {
|
||||||
|
SHERPA_ONNX_LOGE("Invalid custom string %s for %s", custom_string.string,
|
||||||
|
f.c_str());
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
ans[std::move(tmp[0])] = std::move(tmp[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
23
sherpa-onnx/csrc/rknn/utils.h
Normal file
23
sherpa-onnx/csrc/rknn/utils.h
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
// sherpa-onnx/csrc/utils.h
|
||||||
|
//
|
||||||
|
// Copyright 2025 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_RKNN_UTILS_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_RKNN_UTILS_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "rknn_api.h" // NOLINT
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
void ConvertNCHWtoNHWC(const float *src, int32_t n, int32_t channel,
|
||||||
|
int32_t height, int32_t width, float *dst);
|
||||||
|
|
||||||
|
std::string ToString(const rknn_tensor_attr &attr);
|
||||||
|
|
||||||
|
std::unordered_map<std::string, std::string> Parse(
|
||||||
|
const rknn_custom_string &custom_string);
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_RKNN_UTILS_H_
|
||||||
@@ -83,6 +83,7 @@ for a list of pre-trained models to download.
|
|||||||
po.Read(argc, argv);
|
po.Read(argc, argv);
|
||||||
if (po.NumArgs() < 1) {
|
if (po.NumArgs() < 1) {
|
||||||
po.PrintUsage();
|
po.PrintUsage();
|
||||||
|
fprintf(stderr, "Error! Please provide at lease 1 wav file\n");
|
||||||
exit(EXIT_FAILURE);
|
exit(EXIT_FAILURE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user