Support RKNN for Zipformer CTC models. (#1948)

This commit is contained in:
Fangjun Kuang
2025-03-02 21:40:13 +08:00
committed by GitHub
parent dfcbc8d40b
commit d5e7b51af5
17 changed files with 819 additions and 114 deletions

View File

@@ -24,12 +24,11 @@
namespace sherpa_onnx {
static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
const SymbolTable &sym_table,
float frame_shift_ms,
int32_t subsampling_factor,
int32_t segment,
int32_t frames_since_start) {
OnlineRecognizerResult ConvertCtc(const OnlineCtcDecoderResult &src,
const SymbolTable &sym_table,
float frame_shift_ms,
int32_t subsampling_factor, int32_t segment,
int32_t frames_since_start) {
OnlineRecognizerResult r;
r.tokens.reserve(src.tokens.size());
r.timestamps.reserve(src.tokens.size());
@@ -182,7 +181,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
std::vector<std::vector<Ort::Value>> next_states =
model_->UnStackStates(std::move(out_states));
decoder_->Decode(std::move(out[0]), &results, ss, n);
std::vector<int64_t> log_probs_shape =
out[0].GetTensorTypeAndShapeInfo().GetShape();
decoder_->Decode(out[0].GetTensorData<float>(), log_probs_shape[0],
log_probs_shape[1], log_probs_shape[2], &results, ss, n);
for (int32_t k = 0; k != n; ++k) {
ss[k]->SetCtcResult(results[k]);
@@ -196,8 +198,9 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
// TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10;
int32_t subsampling_factor = 4;
auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
auto r =
ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
r.text = ApplyInverseTextNormalization(r.text);
return r;
}
@@ -306,7 +309,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
std::vector<OnlineCtcDecoderResult> results(1);
results[0] = std::move(s->GetCtcResult());
decoder_->Decode(std::move(out[0]), &results, &s, 1);
std::vector<int64_t> log_probs_shape =
out[0].GetTensorTypeAndShapeInfo().GetShape();
decoder_->Decode(out[0].GetTensorData<float>(), log_probs_shape[0],
log_probs_shape[1], log_probs_shape[2], &results, &s, 1);
s->SetCtcResult(results[0]);
}