From 5326d0f81f82ba7785f75bc3866c26b97611dba3 Mon Sep 17 00:00:00 2001 From: PF Luo Date: Wed, 1 Mar 2023 15:32:54 +0800 Subject: [PATCH] add modified beam search (#69) --- .gitignore | 1 + sherpa-onnx/csrc/CMakeLists.txt | 2 + sherpa-onnx/csrc/hypothesis.cc | 65 ++++++++ sherpa-onnx/csrc/hypothesis.h | 117 +++++++++++++ sherpa-onnx/csrc/math.h | 107 ++++++++++++ .../csrc/online-lstm-transducer-model.cc | 18 -- .../csrc/online-lstm-transducer-model.h | 3 - sherpa-onnx/csrc/online-recognizer.cc | 37 ++++- sherpa-onnx/csrc/online-recognizer.h | 6 +- sherpa-onnx/csrc/online-transducer-decoder.h | 4 + ...online-transducer-greedy-search-decoder.cc | 38 +---- sherpa-onnx/csrc/online-transducer-model.cc | 36 ++++ sherpa-onnx/csrc/online-transducer-model.h | 11 +- ...transducer-modified-beam-search-decoder.cc | 154 ++++++++++++++++++ ...-transducer-modified-beam-search-decoder.h | 37 +++++ .../csrc/online-zipformer-transducer-model.cc | 18 -- .../csrc/online-zipformer-transducer-model.h | 3 - sherpa-onnx/csrc/onnx-utils.cc | 32 ++++ sherpa-onnx/csrc/onnx-utils.h | 12 ++ 19 files changed, 614 insertions(+), 87 deletions(-) create mode 100644 sherpa-onnx/csrc/hypothesis.cc create mode 100644 sherpa-onnx/csrc/hypothesis.h create mode 100644 sherpa-onnx/csrc/math.h create mode 100644 sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc create mode 100644 sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h diff --git a/.gitignore b/.gitignore index b3ea7a9a..1cdb6323 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,4 @@ decode-file tokens.txt *.onnx log.txt +tags diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 23044355..b3ff7189 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -5,11 +5,13 @@ set(sources endpoint.cc features.cc file-utils.cc + hypothesis.cc online-lstm-transducer-model.cc online-recognizer.cc online-stream.cc online-transducer-greedy-search-decoder.cc 
online-transducer-model-config.cc + online-transducer-modified-beam-search-decoder.cc online-transducer-model.cc online-zipformer-transducer-model.cc onnx-utils.cc diff --git a/sherpa-onnx/csrc/hypothesis.cc b/sherpa-onnx/csrc/hypothesis.cc new file mode 100644 index 00000000..174e2880 --- /dev/null +++ b/sherpa-onnx/csrc/hypothesis.cc @@ -0,0 +1,65 @@ +/** + * Copyright (c) 2023 Xiaomi Corporation + * + */ + +#include "sherpa-onnx/csrc/hypothesis.h" + +#include +#include + +namespace sherpa_onnx { + +void Hypotheses::Add(Hypothesis hyp) { + auto key = hyp.Key(); + auto it = hyps_dict_.find(key); + if (it == hyps_dict_.end()) { + hyps_dict_[key] = std::move(hyp); + } else { + it->second.log_prob = LogAdd()(it->second.log_prob, hyp.log_prob); + } +} + +Hypothesis Hypotheses::GetMostProbable(bool length_norm) const { + if (length_norm == false) { + return std::max_element(hyps_dict_.begin(), hyps_dict_.end(), + [](const auto &left, auto &right) -> bool { + return left.second.log_prob < + right.second.log_prob; + }) + ->second; + } else { + // for length_norm is true + return std::max_element( + hyps_dict_.begin(), hyps_dict_.end(), + [](const auto &left, const auto &right) -> bool { + return left.second.log_prob / left.second.ys.size() < + right.second.log_prob / right.second.ys.size(); + }) + ->second; + } +} + +std::vector Hypotheses::GetTopK(int32_t k, bool length_norm) const { + k = std::max(k, 1); + k = std::min(k, Size()); + + std::vector all_hyps = Vec(); + + if (length_norm == false) { + std::partial_sort( + all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(), + [](const auto &a, const auto &b) { return a.log_prob > b.log_prob; }); + } else { + // for length_norm is true + std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(), + [](const auto &a, const auto &b) { + return a.log_prob / a.ys.size() > + b.log_prob / b.ys.size(); + }); + } + + return {all_hyps.begin(), all_hyps.begin() + k}; +} + +} // namespace sherpa_onnx diff --git 
a/sherpa-onnx/csrc/hypothesis.h b/sherpa-onnx/csrc/hypothesis.h new file mode 100644 index 00000000..6023af8b --- /dev/null +++ b/sherpa-onnx/csrc/hypothesis.h @@ -0,0 +1,117 @@ +/** + * Copyright (c) 2023 Xiaomi Corporation + * + */ + +#ifndef SHERPA_ONNX_CSRC_HYPOTHESIS_H_ +#define SHERPA_ONNX_CSRC_HYPOTHESIS_H_ + +#include +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/math.h" + +namespace sherpa_onnx { + +struct Hypothesis { + // The predicted tokens so far. Newly predicated tokens are appended. + std::vector ys; + + // timestamps[i] contains the frame number after subsampling + // on which ys[i] is decoded. + std::vector timestamps; + + // The total score of ys in log space. + double log_prob = 0; + + int32_t num_trailing_blanks = 0; + + Hypothesis() = default; + Hypothesis(const std::vector &ys, double log_prob) + : ys(ys), log_prob(log_prob) {} + + // If two Hypotheses have the same `Key`, then they contain + // the same token sequence. + std::string Key() const { + // TODO(fangjun): Use a hash function? + std::ostringstream os; + std::string sep = "-"; + for (auto i : ys) { + os << i << sep; + sep = "-"; + } + return os.str(); + } + + // For debugging + std::string ToString() const { + std::ostringstream os; + os << "(" << Key() << ", " << log_prob << ")"; + return os.str(); + } +}; + +class Hypotheses { + public: + Hypotheses() = default; + + explicit Hypotheses(std::vector hyps) { + for (auto &h : hyps) { + hyps_dict_[h.Key()] = std::move(h); + } + } + + explicit Hypotheses(std::unordered_map hyps_dict) + : hyps_dict_(std::move(hyps_dict)) {} + + // Add hyp to this object. If it already exists, its log_prob + // is updated with the given hyp using log-sum-exp. + void Add(Hypothesis hyp); + + // Get the hyp that has the largest log_prob. + // If length_norm is true, hyp's log_prob is divided by + // len(hyp.ys) before comparison. 
+ Hypothesis GetMostProbable(bool length_norm) const; + + // Get the k hyps that have the largest log_prob. + // If length_norm is true, hyp's log_prob is divided by + // len(hyp.ys) before comparison. + std::vector GetTopK(int32_t k, bool length_norm) const; + + int32_t Size() const { return hyps_dict_.size(); } + + std::string ToString() const { + std::ostringstream os; + for (const auto &p : hyps_dict_) { + os << p.second.ToString() << "\n"; + } + return os.str(); + } + + const auto begin() const { return hyps_dict_.begin(); } + const auto end() const { return hyps_dict_.end(); } + + void Clear() { hyps_dict_.clear(); } + + private: + // Return a list of hyps contained in this object. + std::vector Vec() const { + std::vector ans; + ans.reserve(hyps_dict_.size()); + for (const auto &p : hyps_dict_) { + ans.push_back(p.second); + } + return ans; + } + + private: + using Map = std ::unordered_map; + Map hyps_dict_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_HYPOTHESIS_H_ diff --git a/sherpa-onnx/csrc/math.h b/sherpa-onnx/csrc/math.h new file mode 100644 index 00000000..c3345657 --- /dev/null +++ b/sherpa-onnx/csrc/math.h @@ -0,0 +1,107 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Daniel Povey) + * Copyright (c) 2023 (Pingfeng Luo) + * + */ +// This file is copied from k2/csrc/utils.h +#ifndef SHERPA_ONNX_CSRC_MATH_H_ +#define SHERPA_ONNX_CSRC_MATH_H_ + +#include +#include +#include +#include +#include + +namespace sherpa_onnx { + +// logf(FLT_EPSILON) +#define SHERPA_ONNX_MIN_LOG_DIFF_FLOAT -15.9423847198486328125f + +// log(DBL_EPSILON) +#define SHERPA_ONNX_MIN_LOG_DIFF_DOUBLE \ + -36.0436533891171535515240975655615329742431640625 + +template +struct LogAdd; + +template <> +struct LogAdd { + double operator()(double x, double y) const { + double diff; + + if (x < y) { + diff = x - y; + x = y; + } else { + diff = y - x; + } + // diff is negative. x is now the larger one. 
+ + if (diff >= SHERPA_ONNX_MIN_LOG_DIFF_DOUBLE) { + double res; + res = x + log1p(exp(diff)); + return res; + } + + return x; // return the larger one. + } +}; + +template <> +struct LogAdd { + float operator()(float x, float y) const { + float diff; + + if (x < y) { + diff = x - y; + x = y; + } else { + diff = y - x; + } + // diff is negative. x is now the larger one. + + if (diff >= SHERPA_ONNX_MIN_LOG_DIFF_DOUBLE) { + float res; + res = x + log1pf(expf(diff)); + return res; + } + + return x; // return the larger one. + } +}; + +template +void LogSoftmax(T *input, int32_t input_len) { + assert(input); + + T m = *std::max_element(input, input + input_len); + + T sum = 0.0; + for (int32_t i = 0; i < input_len; i++) { + sum += exp(input[i] - m); + } + + T offset = m + log(sum); + for (int32_t i = 0; i < input_len; i++) { + input[i] -= offset; + } +} + +template +std::vector TopkIndex(const T *vec, int32_t size, int32_t topk) { + std::vector vec_index(size); + std::iota(vec_index.begin(), vec_index.end(), 0); + + std::sort(vec_index.begin(), vec_index.end(), + [vec](int32_t index_1, int32_t index_2) { + return vec[index_1] > vec[index_2]; + }); + + int32_t k_num = std::min(size, topk); + std::vector index(vec_index.begin(), vec_index.begin() + k_num); + return index; +} + +} // namespace sherpa_onnx +#endif // SHERPA_ONNX_CSRC_MATH_H_ diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.cc b/sherpa-onnx/csrc/online-lstm-transducer-model.cc index 7d32efed..29972b19 100644 --- a/sherpa-onnx/csrc/online-lstm-transducer-model.cc +++ b/sherpa-onnx/csrc/online-lstm-transducer-model.cc @@ -247,24 +247,6 @@ OnlineLstmTransducerModel::RunEncoder(Ort::Value features, return {std::move(encoder_out[0]), std::move(next_states)}; } -Ort::Value OnlineLstmTransducerModel::BuildDecoderInput( - const std::vector &results) { - int32_t batch_size = static_cast(results.size()); - std::array shape{batch_size, context_size_}; - Ort::Value decoder_input = - 
Ort::Value::CreateTensor(allocator_, shape.data(), shape.size()); - int64_t *p = decoder_input.GetTensorMutableData(); - - for (const auto &r : results) { - const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size_; - const int64_t *end = r.tokens.data() + r.tokens.size(); - std::copy(begin, end, p); - p += context_size_; - } - - return decoder_input; -} - Ort::Value OnlineLstmTransducerModel::RunDecoder(Ort::Value decoder_input) { auto decoder_out = decoder_sess_->Run( {}, decoder_input_names_ptr_.data(), &decoder_input, 1, diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.h b/sherpa-onnx/csrc/online-lstm-transducer-model.h index a73912a7..5b6ad282 100644 --- a/sherpa-onnx/csrc/online-lstm-transducer-model.h +++ b/sherpa-onnx/csrc/online-lstm-transducer-model.h @@ -40,9 +40,6 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel { std::pair> RunEncoder( Ort::Value features, std::vector states) override; - Ort::Value BuildDecoderInput( - const std::vector &results) override; - Ort::Value RunDecoder(Ort::Value decoder_input) override; Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override; diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index 7c54b6cd..e85f4574 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -1,6 +1,7 @@ // sherpa-onnx/csrc/online-recognizer.cc // // Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2023 Pingfeng Luo #include "sherpa-onnx/csrc/online-recognizer.h" @@ -16,6 +17,7 @@ #include "sherpa-onnx/csrc/online-transducer-decoder.h" #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" #include "sherpa-onnx/csrc/online-transducer-model.h" +#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h" #include "sherpa-onnx/csrc/symbol-table.h" namespace sherpa_onnx { @@ -39,6 +41,11 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { 
po->Register("enable-endpoint", &enable_endpoint, "True to enable endpoint detection. False to disable it."); + po->Register("max-active-paths", &max_active_paths, "beam size used in modified beam search."); + po->Register("decoding-method", &decoding_method, "decoding method, " "now supported: greedy_search and modified_beam_search."); } bool OnlineRecognizerConfig::Validate() const { @@ -52,7 +59,9 @@ std::string OnlineRecognizerConfig::ToString() const { os << "feat_config=" << feat_config.ToString() << ", "; os << "model_config=" << model_config.ToString() << ", "; os << "endpoint_config=" << endpoint_config.ToString() << ", "; - os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ")"; + os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ","; + os << "max_active_paths=" << max_active_paths << ","; + os << "decoding_method=\"" << decoding_method << "\")"; return os.str(); } @@ -64,8 +73,17 @@ class OnlineRecognizer::Impl { model_(OnlineTransducerModel::Create(config.model_config)), sym_(config.model_config.tokens), endpoint_(config_.endpoint_config) { - decoder_ = - std::make_unique(model_.get()); + if (config.decoding_method == "modified_beam_search") { + decoder_ = std::make_unique( + model_.get(), config_.max_active_paths); + } else if (config.decoding_method == "greedy_search") { + decoder_ = + std::make_unique(model_.get()); + } else { + fprintf(stderr, "Unsupported decoding method: %s\n", + config.decoding_method.c_str()); + exit(-1); + } } #if __ANDROID_API__ >= 9 @@ -74,8 +92,17 @@ class OnlineRecognizer::Impl { model_(OnlineTransducerModel::Create(mgr, config.model_config)), sym_(mgr, config.model_config.tokens), endpoint_(config_.endpoint_config) { - decoder_ = - std::make_unique(model_.get()); + if (config.decoding_method == "modified_beam_search") { + decoder_ = std::make_unique( + model_.get(), config_.max_active_paths); + } else if (config.decoding_method == "greedy_search") { + decoder_ = +
std::make_unique(model_.get()); + } else { + fprintf(stderr, "Unsupported decoding method: %s\n", + config.decoding_method.c_str()); + exit(-1); + } } #endif diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index cceadca6..d03b1795 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -32,7 +32,11 @@ struct OnlineRecognizerConfig { FeatureExtractorConfig feat_config; OnlineTransducerModelConfig model_config; EndpointConfig endpoint_config; - bool enable_endpoint; + bool enable_endpoint = true; + int32_t max_active_paths = 4; + + std::string decoding_method = "modified_beam_search"; + // now support modified_beam_search and greedy_search OnlineRecognizerConfig() = default; diff --git a/sherpa-onnx/csrc/online-transducer-decoder.h b/sherpa-onnx/csrc/online-transducer-decoder.h index 6b8eb4cc..c70afc30 100644 --- a/sherpa-onnx/csrc/online-transducer-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-decoder.h @@ -8,6 +8,7 @@ #include #include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/hypothesis.h" namespace sherpa_onnx { @@ -17,6 +18,9 @@ struct OnlineTransducerDecoderResult { /// number of trailing blank frames decoded so far int32_t num_trailing_blanks = 0; + + // used only in modified beam_search + Hypotheses hyps; }; class OnlineTransducerDecoder { diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc index 007c596d..5e194f3d 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc @@ -4,8 +4,6 @@ #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" -#include - #include #include #include @@ -15,39 +13,6 @@ namespace sherpa_onnx { -static Ort::Value GetFrame(OrtAllocator *allocator, Ort::Value *encoder_out, - int32_t t) { - std::vector encoder_out_shape = - 
encoder_out->GetTensorTypeAndShapeInfo().GetShape(); - - auto batch_size = encoder_out_shape[0]; - auto num_frames = encoder_out_shape[1]; - assert(t < num_frames); - - auto encoder_out_dim = encoder_out_shape[2]; - - auto offset = num_frames * encoder_out_dim; - - auto memory_info = - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - - std::array shape{batch_size, encoder_out_dim}; - - Ort::Value ans = - Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); - - float *dst = ans.GetTensorMutableData(); - const float *src = encoder_out->GetTensorData(); - - for (int32_t i = 0; i != batch_size; ++i) { - std::copy(src + t * encoder_out_dim, src + (t + 1) * encoder_out_dim, dst); - src += offset; - dst += encoder_out_dim; - } - - return ans; -} - OnlineTransducerDecoderResult OnlineTransducerGreedySearchDecoder::GetEmptyResult() const { int32_t context_size = model_->ContextSize(); @@ -90,7 +55,8 @@ void OnlineTransducerGreedySearchDecoder::Decode( Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); for (int32_t t = 0; t != num_frames; ++t) { - Ort::Value cur_encoder_out = GetFrame(model_->Allocator(), &encoder_out, t); + Ort::Value cur_encoder_out = + GetEncoderOutFrame(model_->Allocator(), &encoder_out, t); Ort::Value logit = model_->RunJoiner( std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); diff --git a/sherpa-onnx/csrc/online-transducer-model.cc b/sherpa-onnx/csrc/online-transducer-model.cc index effbf558..3a40e9c6 100644 --- a/sherpa-onnx/csrc/online-transducer-model.cc +++ b/sherpa-onnx/csrc/online-transducer-model.cc @@ -1,6 +1,7 @@ // sherpa-onnx/csrc/online-transducer-model.cc // // Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2023 Pingfeng Luo #include "sherpa-onnx/csrc/online-transducer-model.h" #if __ANDROID_API__ >= 9 @@ -8,6 +9,7 @@ #include "android/asset_manager_jni.h" #endif +#include #include #include #include @@ -75,6 +77,40 @@ std::unique_ptr 
OnlineTransducerModel::Create( return nullptr; } +Ort::Value OnlineTransducerModel::BuildDecoderInput( + const std::vector &results) { + int32_t batch_size = static_cast(results.size()); + int32_t context_size = ContextSize(); + std::array shape{batch_size, context_size}; + Ort::Value decoder_input = Ort::Value::CreateTensor( + Allocator(), shape.data(), shape.size()); + int64_t *p = decoder_input.GetTensorMutableData(); + + for (const auto &r : results) { + const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size; + const int64_t *end = r.tokens.data() + r.tokens.size(); + std::copy(begin, end, p); + p += context_size; + } + return decoder_input; +} + +Ort::Value OnlineTransducerModel::BuildDecoderInput( + const std::vector &hyps) { + int32_t batch_size = static_cast(hyps.size()); + int32_t context_size = ContextSize(); + std::array shape{batch_size, context_size}; + Ort::Value decoder_input = Ort::Value::CreateTensor( + Allocator(), shape.data(), shape.size()); + int64_t *p = decoder_input.GetTensorMutableData(); + + for (const auto &h : hyps) { + std::copy(h.ys.end() - context_size, h.ys.end(), p); + p += context_size; + } + return decoder_input; +} + #if __ANDROID_API__ >= 9 std::unique_ptr OnlineTransducerModel::Create( AAssetManager *mgr, const OnlineTransducerModelConfig &config) { diff --git a/sherpa-onnx/csrc/online-transducer-model.h b/sherpa-onnx/csrc/online-transducer-model.h index 2757024c..42e7948b 100644 --- a/sherpa-onnx/csrc/online-transducer-model.h +++ b/sherpa-onnx/csrc/online-transducer-model.h @@ -14,6 +14,8 @@ #endif #include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/hypothesis.h" +#include "sherpa-onnx/csrc/online-transducer-decoder.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h" namespace sherpa_onnx { @@ -71,9 +73,6 @@ class OnlineTransducerModel { Ort::Value features, std::vector states) = 0; // NOLINT - virtual Ort::Value BuildDecoderInput( - const std::vector &results) = 0; - /** Run 
the decoder network. * * Caution: We assume there are no recurrent connections in the decoder and @@ -125,7 +124,13 @@ class OnlineTransducerModel { virtual int32_t VocabSize() const = 0; virtual int32_t SubsamplingFactor() const { return 4; } + virtual OrtAllocator *Allocator() = 0; + + Ort::Value BuildDecoderInput( + const std::vector &results); + + Ort::Value BuildDecoderInput(const std::vector &hyps); }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc new file mode 100644 index 00000000..eab279c5 --- /dev/null +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc @@ -0,0 +1,154 @@ +// sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc +// +// Copyright (c) 2023 Pingfeng Luo +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/onnx-utils.h" + +namespace sherpa_onnx { + +static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, + const std::vector &hyps_num_split) { + std::vector cur_encoder_out_shape = + cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape(); + + std::array ans_shape{hyps_num_split.back(), + cur_encoder_out_shape[1]}; + + Ort::Value ans = Ort::Value::CreateTensor(allocator, ans_shape.data(), + ans_shape.size()); + + const float *src = cur_encoder_out->GetTensorData(); + float *dst = ans.GetTensorMutableData(); + int32_t batch_size = static_cast(hyps_num_split.size()) - 1; + for (int32_t b = 0; b != batch_size; ++b) { + int32_t cur_stream_hyps_num = hyps_num_split[b + 1] - hyps_num_split[b]; + for (int32_t i = 0; i != cur_stream_hyps_num; ++i) { + std::copy(src, src + cur_encoder_out_shape[1], dst); + dst += cur_encoder_out_shape[1]; + } + src += cur_encoder_out_shape[1]; + } + return ans; +} + +static void LogSoftmax(float *in, int32_t 
w, int32_t h) { + for (int32_t i = 0; i != h; ++i) { + LogSoftmax(in, w); + in += w; + } +} + +OnlineTransducerDecoderResult +OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const { + int32_t context_size = model_->ContextSize(); + int32_t blank_id = 0; // always 0 + OnlineTransducerDecoderResult r; + std::vector blanks(context_size, blank_id); + Hypotheses blank_hyp({{blanks, 0}}); + r.hyps = std::move(blank_hyp); + return r; +} + +void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks( + OnlineTransducerDecoderResult *r) const { + int32_t context_size = model_->ContextSize(); + auto hyp = r->hyps.GetMostProbable(true); + + std::vector tokens(hyp.ys.begin() + context_size, hyp.ys.end()); + r->tokens = std::move(tokens); + r->num_trailing_blanks = hyp.num_trailing_blanks; +} + +void OnlineTransducerModifiedBeamSearchDecoder::Decode( + Ort::Value encoder_out, + std::vector *result) { + std::vector encoder_out_shape = + encoder_out.GetTensorTypeAndShapeInfo().GetShape(); + + if (encoder_out_shape[0] != result->size()) { + fprintf(stderr, + "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n", + static_cast(encoder_out_shape[0]), + static_cast(result->size())); + exit(-1); + } + + int32_t batch_size = static_cast(encoder_out_shape[0]); + int32_t num_frames = static_cast(encoder_out_shape[1]); + int32_t vocab_size = model_->VocabSize(); + + std::vector cur; + for (auto &r : *result) { + cur.push_back(std::move(r.hyps)); + } + std::vector prev; + + for (int32_t t = 0; t != num_frames; ++t) { + // Due to merging paths with identical token sequences, + // not all utterances have "num_active_paths" paths. 
+ int32_t hyps_num_acc = 0; + std::vector hyps_num_split; + hyps_num_split.push_back(0); + + prev.clear(); + for (auto &hyps : cur) { + for (auto &h : hyps) { + prev.push_back(std::move(h.second)); + hyps_num_acc++; + } + hyps_num_split.push_back(hyps_num_acc); + } + cur.clear(); + cur.reserve(batch_size); + + Ort::Value decoder_input = model_->BuildDecoderInput(prev); + Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); + + Ort::Value cur_encoder_out = + GetEncoderOutFrame(model_->Allocator(), &encoder_out, t); + cur_encoder_out = + Repeat(model_->Allocator(), &cur_encoder_out, hyps_num_split); + Ort::Value logit = model_->RunJoiner( + std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); + float *p_logit = logit.GetTensorMutableData(); + + for (int32_t b = 0; b < batch_size; ++b) { + int32_t start = hyps_num_split[b]; + int32_t end = hyps_num_split[b + 1]; + LogSoftmax(p_logit, vocab_size, (end - start)); + auto topk = + TopkIndex(p_logit, vocab_size * (end - start), max_active_paths_); + + Hypotheses hyps; + for (auto i : topk) { + int32_t hyp_index = i / vocab_size + start; + int32_t new_token = i % vocab_size; + + Hypothesis new_hyp = prev[hyp_index]; + if (new_token != 0) { + new_hyp.ys.push_back(new_token); + new_hyp.num_trailing_blanks = 0; + } else { + ++new_hyp.num_trailing_blanks; + } + new_hyp.log_prob += p_logit[i]; + hyps.Add(std::move(new_hyp)); + } + cur.push_back(std::move(hyps)); + p_logit += vocab_size * (end - start); + } + } + + for (int32_t b = 0; b != batch_size; ++b) { + (*result)[b].hyps = std::move(cur[b]); + } +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h new file mode 100644 index 00000000..f1443539 --- /dev/null +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h @@ -0,0 +1,37 @@ +// sherpa-onnx/csrc/online-transducer-modified_beam-search-decoder.h 
+// +// Copyright (c) 2023 Pingfeng Luo +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_ +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_ + +#include + +#include "sherpa-onnx/csrc/online-transducer-decoder.h" +#include "sherpa-onnx/csrc/online-transducer-model.h" + +namespace sherpa_onnx { + +class OnlineTransducerModifiedBeamSearchDecoder + : public OnlineTransducerDecoder { + public: + OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model, + int32_t max_active_paths) + : model_(model), max_active_paths_(max_active_paths) {} + + OnlineTransducerDecoderResult GetEmptyResult() const override; + + void StripLeadingBlanks(OnlineTransducerDecoderResult *r) const override; + + void Decode(Ort::Value encoder_out, + std::vector *result) override; + + private: + OnlineTransducerModel *model_; // Not owned + int32_t max_active_paths_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_ diff --git a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc index 7038274e..c0748179 100644 --- a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc +++ b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc @@ -461,24 +461,6 @@ OnlineZipformerTransducerModel::RunEncoder(Ort::Value features, return {std::move(encoder_out[0]), std::move(next_states)}; } -Ort::Value OnlineZipformerTransducerModel::BuildDecoderInput( - const std::vector &results) { - int32_t batch_size = static_cast(results.size()); - std::array shape{batch_size, context_size_}; - Ort::Value decoder_input = - Ort::Value::CreateTensor(allocator_, shape.data(), shape.size()); - int64_t *p = decoder_input.GetTensorMutableData(); - - for (const auto &r : results) { - const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size_; - const int64_t *end = r.tokens.data() + 
r.tokens.size(); - std::copy(begin, end, p); - p += context_size_; - } - - return decoder_input; -} - Ort::Value OnlineZipformerTransducerModel::RunDecoder( Ort::Value decoder_input) { auto decoder_out = decoder_sess_->Run( diff --git a/sherpa-onnx/csrc/online-zipformer-transducer-model.h b/sherpa-onnx/csrc/online-zipformer-transducer-model.h index 02a9742d..c2f237a3 100644 --- a/sherpa-onnx/csrc/online-zipformer-transducer-model.h +++ b/sherpa-onnx/csrc/online-zipformer-transducer-model.h @@ -41,9 +41,6 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel { std::pair> RunEncoder( Ort::Value features, std::vector states) override; - Ort::Value BuildDecoderInput( - const std::vector &results) override; - Ort::Value RunDecoder(Ort::Value decoder_input) override; Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override; diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc index 664aac03..8357bc31 100644 --- a/sherpa-onnx/csrc/onnx-utils.cc +++ b/sherpa-onnx/csrc/onnx-utils.cc @@ -44,6 +44,38 @@ void GetOutputNames(Ort::Session *sess, std::vector *output_names, } } +Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out, + int32_t t) { + std::vector encoder_out_shape = + encoder_out->GetTensorTypeAndShapeInfo().GetShape(); + + auto batch_size = encoder_out_shape[0]; + auto num_frames = encoder_out_shape[1]; + assert(t < num_frames); + + auto encoder_out_dim = encoder_out_shape[2]; + + auto offset = num_frames * encoder_out_dim; + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::array shape{batch_size, encoder_out_dim}; + + Ort::Value ans = + Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); + + float *dst = ans.GetTensorMutableData(); + const float *src = encoder_out->GetTensorData(); + + for (int32_t i = 0; i != batch_size; ++i) { + std::copy(src + t * encoder_out_dim, src + (t + 1) * encoder_out_dim, dst); + src += 
offset; + dst += encoder_out_dim; + } + return ans; +} + void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { Ort::AllocatorWithDefaultOptions allocator; std::vector v = diff --git a/sherpa-onnx/csrc/onnx-utils.h b/sherpa-onnx/csrc/onnx-utils.h index 1ac846f3..af4f3ccb 100644 --- a/sherpa-onnx/csrc/onnx-utils.h +++ b/sherpa-onnx/csrc/onnx-utils.h @@ -10,6 +10,7 @@ #include #endif +#include #include #include #include @@ -57,6 +58,17 @@ void GetInputNames(Ort::Session *sess, std::vector *input_names, void GetOutputNames(Ort::Session *sess, std::vector *output_names, std::vector *output_names_ptr); +/** + * Get the output frame of Encoder + * + * @param allocator allocator of onnxruntime + * @param encoder_out encoder out tensor + * @param t frame_index + * + */ +Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out, + int32_t t); + void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data); // NOLINT