diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 382541f9..80804a6f 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -155,6 +155,7 @@ if(SHERPA_ONNX_ENABLE_RKNN) list(APPEND sources ./rknn/online-stream-rknn.cc ./rknn/online-transducer-greedy-search-decoder-rknn.cc + ./rknn/online-transducer-modified-beam-search-decoder-rknn.cc ./rknn/online-zipformer-ctc-model-rknn.cc ./rknn/online-zipformer-transducer-model-rknn.cc ./rknn/utils.cc diff --git a/sherpa-onnx/csrc/hypothesis.h b/sherpa-onnx/csrc/hypothesis.h index 6a49bad3..428a74fa 100644 --- a/sherpa-onnx/csrc/hypothesis.h +++ b/sherpa-onnx/csrc/hypothesis.h @@ -142,7 +142,6 @@ class Hypotheses { void Clear() { hyps_dict_.clear(); } - private: // Return a list of hyps contained in this object. std::vector<Hypothesis> Vec() const { std::vector<Hypothesis> ans; diff --git a/sherpa-onnx/csrc/math.h b/sherpa-onnx/csrc/math.h index 121a05ae..21fd3803 100644 --- a/sherpa-onnx/csrc/math.h +++ b/sherpa-onnx/csrc/math.h @@ -119,5 +119,17 @@ std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) { return {vec_index.begin(), vec_index.begin() + k_num}; } +template <typename T> +std::vector<int32_t> TopkIndex(const std::vector<std::vector<T>> &vec, + int32_t topk) { + std::vector<T> flatten; + flatten.reserve(vec.size() * vec[0].size()); + for (const auto &v : vec) { + flatten.insert(flatten.end(), v.begin(), v.end()); + } + + return TopkIndex(flatten.data(), flatten.size(), topk); +} + } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_MATH_H_ diff --git a/sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h b/sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h index 9cc29507..c3df5b80 100644 --- a/sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h +++ b/sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h @@ -16,7 +16,9 @@ #include "sherpa-onnx/csrc/online-recognizer-impl.h" #include "sherpa-onnx/csrc/online-recognizer.h" #include 
"sherpa-onnx/csrc/rknn/online-stream-rknn.h" +#include "sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h" #include "sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h" +#include "sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.h" #include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h" #include "sherpa-onnx/csrc/symbol-table.h" @@ -87,8 +89,20 @@ class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl { unk_id_ = sym_["<unk>"]; } - decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoderRknn>( - model_.get(), unk_id_); + if (config.decoding_method == "greedy_search") { + decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoderRknn>( + model_.get(), unk_id_); + } else if (config.decoding_method == "modified_beam_search") { + decoder_ = + std::make_unique<OnlineTransducerModifiedBeamSearchDecoderRknn>( + model_.get(), config.max_active_paths, unk_id_); + } else { + SHERPA_ONNX_LOGE( + "Invalid decoding method: '%s'. Support only greedy_search and " + "modified_beam_search.", + config.decoding_method.c_str()); + SHERPA_ONNX_EXIT(-1); + } } template <typename Manager> @@ -223,7 +237,7 @@ class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl { Endpoint endpoint_; int32_t unk_id_ = -1; std::unique_ptr<OnlineZipformerTransducerModelRknn> model_; - std::unique_ptr<OnlineTransducerGreedySearchDecoderRknn> decoder_; + std::unique_ptr<OnlineTransducerDecoderRknn> decoder_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/rknn/online-stream-rknn.h b/sherpa-onnx/csrc/rknn/online-stream-rknn.h index fe249d5b..780b71ec 100644 --- a/sherpa-onnx/csrc/rknn/online-stream-rknn.h +++ b/sherpa-onnx/csrc/rknn/online-stream-rknn.h @@ -8,7 +8,7 @@ #include "rknn_api.h" // NOLINT #include "sherpa-onnx/csrc/online-stream.h" -#include "sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h" +#include "sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h" namespace sherpa_onnx { diff --git a/sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h b/sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h new file mode 100644 index 00000000..81284f0e --- /dev/null +++ 
b/sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h @@ -0,0 +1,63 @@ +// sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_DECODER_RKNN_H_ +#define SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_DECODER_RKNN_H_ + +#include <vector> + +#include "sherpa-onnx/csrc/hypothesis.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +struct OnlineTransducerDecoderResultRknn { + /// Number of frames after subsampling we have decoded so far + int32_t frame_offset = 0; + + /// The decoded token IDs so far + std::vector<int64_t> tokens; + + /// number of trailing blank frames decoded so far + int32_t num_trailing_blanks = 0; + + /// timestamps[i] contains the output frame index where tokens[i] is decoded. + std::vector<int32_t> timestamps; + + // used only by greedy_search + std::vector<float> previous_decoder_out; + + // used only by modified_beam_search + Hypotheses hyps; + + // used only by modified_beam_search + std::vector<std::vector<float>> previous_decoder_out2; +}; + +class OnlineTransducerDecoderRknn { + public: + virtual ~OnlineTransducerDecoderRknn() = default; + + /* Return an empty result. + * + * To simplify the decoding code, we add `context_size` blanks + * to the beginning of the decoding result, which will be + * stripped by calling `StripLeadingBlanks()`. + */ + virtual OnlineTransducerDecoderResultRknn GetEmptyResult() const = 0; + + /** Strip blanks added by `GetEmptyResult()`. + * + * @param r It is changed in-place. 
+ */ + virtual void StripLeadingBlanks( + OnlineTransducerDecoderResultRknn * /*r*/) const {} + + virtual void Decode(std::vector<float> encoder_out, + OnlineTransducerDecoderResultRknn *result) const = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_DECODER_RKNN_H_ diff --git a/sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h b/sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h index 6699b75f..b3d12a60 100644 --- a/sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h +++ b/sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h @@ -7,39 +7,25 @@ #include <vector> +#include "sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h" #include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h" namespace sherpa_onnx { -struct OnlineTransducerDecoderResultRknn { - /// Number of frames after subsampling we have decoded so far - int32_t frame_offset = 0; - - /// The decoded token IDs so far - std::vector<int64_t> tokens; - - /// number of trailing blank frames decoded so far - int32_t num_trailing_blanks = 0; - - /// timestamps[i] contains the output frame index where tokens[i] is decoded. 
- std::vector<int32_t> timestamps; - - std::vector<float> previous_decoder_out; -}; - -class OnlineTransducerGreedySearchDecoderRknn { +class OnlineTransducerGreedySearchDecoderRknn + : public OnlineTransducerDecoderRknn { public: explicit OnlineTransducerGreedySearchDecoderRknn( OnlineZipformerTransducerModelRknn *model, int32_t unk_id = 2, float blank_penalty = 0.0) : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {} - OnlineTransducerDecoderResultRknn GetEmptyResult() const; + OnlineTransducerDecoderResultRknn GetEmptyResult() const override; - void StripLeadingBlanks(OnlineTransducerDecoderResultRknn *r) const; + void StripLeadingBlanks(OnlineTransducerDecoderResultRknn *r) const override; void Decode(std::vector<float> encoder_out, - OnlineTransducerDecoderResultRknn *result) const; + OnlineTransducerDecoderResultRknn *result) const override; private: OnlineZipformerTransducerModelRknn *model_; // Not owned diff --git a/sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.cc b/sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.cc new file mode 100644 index 00000000..cb4456b2 --- /dev/null +++ b/sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.cc @@ -0,0 +1,146 @@ +// sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.h" + +#include <algorithm> +#include <utility> +#include <vector> + +#include "sherpa-onnx/csrc/hypothesis.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/math.h" + +namespace sherpa_onnx { + +OnlineTransducerDecoderResultRknn +OnlineTransducerModifiedBeamSearchDecoderRknn::GetEmptyResult() const { + int32_t context_size = model_->ContextSize(); + int32_t blank_id = 0; // always 0 + OnlineTransducerDecoderResultRknn r; + + std::vector<int64_t> blanks(context_size, -1); + blanks.back() = blank_id; + + Hypotheses blank_hyp({{blanks, 0}}); + r.hyps = 
std::move(blank_hyp); + r.tokens = std::move(blanks); + + return r; +} + +void OnlineTransducerModifiedBeamSearchDecoderRknn::StripLeadingBlanks( + OnlineTransducerDecoderResultRknn *r) const { + int32_t context_size = model_->ContextSize(); + auto hyp = r->hyps.GetMostProbable(true); + + std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end()); + r->tokens = std::move(tokens); + r->timestamps = std::move(hyp.timestamps); + + r->num_trailing_blanks = hyp.num_trailing_blanks; +} + +static std::vector<std::vector<float>> GetDecoderOut( + OnlineZipformerTransducerModelRknn *model, const Hypotheses &hyp_vec) { + std::vector<std::vector<float>> ans; + ans.reserve(hyp_vec.Size()); + + int32_t context_size = model->ContextSize(); + for (const auto &p : hyp_vec) { + const auto &hyp = p.second; + auto start = hyp.ys.begin() + (hyp.ys.size() - context_size); + auto end = hyp.ys.end(); + auto tokens = std::vector<int64_t>(start, end); + auto decoder_out = model->RunDecoder(std::move(tokens)); + + ans.push_back(std::move(decoder_out)); + } + + return ans; +} + +static std::vector<std::vector<float>> GetJoinerOutLogSoftmax( + OnlineZipformerTransducerModelRknn *model, const float *p_encoder_out, + const std::vector<std::vector<float>> &decoder_out) { + std::vector<std::vector<float>> ans; + ans.reserve(decoder_out.size()); + + for (const auto &d : decoder_out) { + auto joiner_out = model->RunJoiner(p_encoder_out, d.data()); + + LogSoftmax(joiner_out.data(), joiner_out.size()); + + ans.push_back(std::move(joiner_out)); + } + return ans; +} + +void OnlineTransducerModifiedBeamSearchDecoderRknn::Decode( + std::vector<float> encoder_out, + OnlineTransducerDecoderResultRknn *result) const { + auto &r = result[0]; + auto attr = model_->GetEncoderOutAttr(); + int32_t num_frames = attr.dims[1]; + int32_t encoder_out_dim = attr.dims[2]; + + int32_t vocab_size = model_->VocabSize(); + int32_t context_size = model_->ContextSize(); + + Hypotheses cur = std::move(result->hyps); + std::vector<Hypothesis> prev; + + auto decoder_out = std::move(result->previous_decoder_out2); + if (decoder_out.empty()) { + 
decoder_out = GetDecoderOut(model_, cur); + } + + const float *p_encoder_out = encoder_out.data(); + + int32_t frame_offset = result->frame_offset; + + for (int32_t t = 0; t != num_frames; ++t) { + prev = cur.Vec(); + cur.Clear(); + + auto log_probs = GetJoinerOutLogSoftmax(model_, p_encoder_out, decoder_out); + p_encoder_out += encoder_out_dim; + + for (int32_t i = 0; i != prev.size(); ++i) { + auto log_prob = prev[i].log_prob; + for (auto &p : log_probs[i]) { + p += log_prob; + } + } + + auto topk = TopkIndex(log_probs, max_active_paths_); + for (auto k : topk) { + int32_t hyp_index = k / vocab_size; + int32_t new_token = k % vocab_size; + + Hypothesis new_hyp = prev[hyp_index]; + new_hyp.log_prob = log_probs[hyp_index][new_token]; + + // blank is hardcoded to 0 + // also, it treats unk as blank + if (new_token != 0 && new_token != unk_id_) { + new_hyp.ys.push_back(new_token); + new_hyp.timestamps.push_back(t + frame_offset); + new_hyp.num_trailing_blanks = 0; + + } else { + ++new_hyp.num_trailing_blanks; + } + cur.Add(std::move(new_hyp)); + } + + decoder_out = GetDecoderOut(model_, cur); + } + + result->hyps = std::move(cur); + result->frame_offset += num_frames; + result->previous_decoder_out2 = std::move(decoder_out); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.h b/sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.h new file mode 100644 index 00000000..31f8907a --- /dev/null +++ b/sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.h @@ -0,0 +1,42 @@ +// sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_RKNN_H_ +#define SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_RKNN_H_ + +#include <vector> + +#include "sherpa-onnx/csrc/rknn/online-transducer-decoder-rknn.h" +#include 
"sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h" + +namespace sherpa_onnx { + +class OnlineTransducerModifiedBeamSearchDecoderRknn + : public OnlineTransducerDecoderRknn { + public: + explicit OnlineTransducerModifiedBeamSearchDecoderRknn( + OnlineZipformerTransducerModelRknn *model, int32_t max_active_paths, + int32_t unk_id = 2, float blank_penalty = 0.0) + : model_(model), + max_active_paths_(max_active_paths), + unk_id_(unk_id), + blank_penalty_(blank_penalty) {} + + OnlineTransducerDecoderResultRknn GetEmptyResult() const override; + + void StripLeadingBlanks(OnlineTransducerDecoderResultRknn *r) const override; + + void Decode(std::vector<float> encoder_out, + OnlineTransducerDecoderResultRknn *result) const override; + + private: + OnlineZipformerTransducerModelRknn *model_; // Not owned + int32_t max_active_paths_; + int32_t unk_id_; + float blank_penalty_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_RKNN_H_ diff --git a/sherpa-onnx/csrc/rknn/utils.cc b/sherpa-onnx/csrc/rknn/utils.cc index e1581ed4..165bf096 100644 --- a/sherpa-onnx/csrc/rknn/utils.cc +++ b/sherpa-onnx/csrc/rknn/utils.cc @@ -6,6 +6,7 @@ #include #include +#include <utility> #include "sherpa-onnx/csrc/macros.h"