diff --git a/.gitignore b/.gitignore index f28802d8..6971f74a 100644 --- a/.gitignore +++ b/.gitignore @@ -56,3 +56,4 @@ run-offline-decode-files.sh sherpa-onnx-nemo-ctc-en-citrinet-512 run-offline-decode-files-nemo-ctc.sh *.jar +sherpa-onnx-nemo-ctc-* diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e2e4293..5330ecfa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,6 +51,11 @@ if(DEFINED ANDROID_ABI) set(SHERPA_ONNX_ENABLE_JNI ON CACHE BOOL "" FORCE) endif() +if(SHERPA_ONNX_ENABLE_PYTHON AND NOT BUILD_SHARED_LIBS) + message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_PYTHON is ON") + set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE) +endif() + if(SHERPA_ONNX_ENABLE_JNI AND NOT BUILD_SHARED_LIBS) message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_JNI is ON") set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE) diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 8a50f770..883eff96 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -18,6 +18,8 @@ set(sources hypothesis.cc offline-ctc-greedy-search-decoder.cc offline-ctc-model.cc + offline-lm-config.cc + offline-lm.cc offline-model-config.cc offline-nemo-enc-dec-ctc-model-config.cc offline-nemo-enc-dec-ctc-model.cc @@ -26,10 +28,13 @@ set(sources offline-paraformer-model.cc offline-recognizer-impl.cc offline-recognizer.cc + offline-rnn-lm.cc offline-stream.cc offline-transducer-greedy-search-decoder.cc offline-transducer-model-config.cc offline-transducer-model.cc + offline-transducer-modified-beam-search-decoder.cc + online-lm-config.cc online-lstm-transducer-model.cc online-recognizer.cc online-stream.cc diff --git a/sherpa-onnx/csrc/hypothesis.cc b/sherpa-onnx/csrc/hypothesis.cc index 174e2880..5c2693d7 100644 --- a/sherpa-onnx/csrc/hypothesis.cc +++ b/sherpa-onnx/csrc/hypothesis.cc @@ -17,6 +17,9 @@ void Hypotheses::Add(Hypothesis hyp) { hyps_dict_[key] = std::move(hyp); } else { it->second.log_prob = LogAdd()(it->second.log_prob, hyp.log_prob); + + it->second.lm_log_prob = + LogAdd()(it->second.lm_log_prob, hyp.lm_log_prob); } } @@ -24,8 +27,8 @@ Hypothesis Hypotheses::GetMostProbable(bool length_norm) const { if (length_norm == false) { return std::max_element(hyps_dict_.begin(), hyps_dict_.end(), [](const auto &left, auto &right) -> bool { - return left.second.log_prob < - right.second.log_prob; + return left.second.TotalLogProb() < + right.second.TotalLogProb(); }) ->second; } else { @@ -33,8 +36,8 @@ Hypothesis Hypotheses::GetMostProbable(bool length_norm) const { return std::max_element( hyps_dict_.begin(), hyps_dict_.end(), [](const auto &left, const auto &right) -> bool { - return left.second.log_prob / left.second.ys.size() < - right.second.log_prob / right.second.ys.size(); + return left.second.TotalLogProb() / left.second.ys.size() < + right.second.TotalLogProb() / right.second.ys.size(); }) ->second; } @@ -47,15 +50,16 @@ std::vector Hypotheses::GetTopK(int32_t k, bool length_norm) const { std::vector all_hyps = Vec(); if (length_norm == false) { - std::partial_sort( - all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(), - [](const auto &a, const auto &b) { return a.log_prob > b.log_prob; }); + std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(), + [](const auto &a, const auto &b) { + return a.TotalLogProb() > b.TotalLogProb(); + }); } else { // for length_norm is true std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(), [](const auto &a, const auto &b) { - return a.log_prob / a.ys.size() > - b.log_prob / b.ys.size(); + return a.TotalLogProb() / a.ys.size() > + b.TotalLogProb() / b.ys.size(); }); } diff --git a/sherpa-onnx/csrc/hypothesis.h b/sherpa-onnx/csrc/hypothesis.h index 725dba07..86221729 100644 --- a/sherpa-onnx/csrc/hypothesis.h +++ b/sherpa-onnx/csrc/hypothesis.h @@ -25,14 +25,20 @@ struct Hypothesis { std::vector timestamps; // The total score of ys in log space. + // It contains only acoustic scores double log_prob = 0; + // LM log prob if any. + double lm_log_prob = 0; + int32_t num_trailing_blanks = 0; Hypothesis() = default; Hypothesis(const std::vector &ys, double log_prob) : ys(ys), log_prob(log_prob) {} + double TotalLogProb() const { return log_prob + lm_log_prob; } + // If two Hypotheses have the same `Key`, then they contain // the same token sequence. std::string Key() const { @@ -94,6 +100,9 @@ class Hypotheses { const auto begin() const { return hyps_dict_.begin(); } const auto end() const { return hyps_dict_.end(); } + auto begin() { return hyps_dict_.begin(); } + auto end() { return hyps_dict_.end(); } + void Clear() { hyps_dict_.clear(); } private: diff --git a/sherpa-onnx/csrc/math.h b/sherpa-onnx/csrc/math.h index c3345657..d63c3ed3 100644 --- a/sherpa-onnx/csrc/math.h +++ b/sherpa-onnx/csrc/math.h @@ -88,6 +88,16 @@ void LogSoftmax(T *input, int32_t input_len) { } } +template +void LogSoftmax(T *in, int32_t w, int32_t h) { + for (int32_t i = 0; i != h; ++i) { + LogSoftmax(in, w); + in += w; + } +} + +// TODO(fangjun): use std::partial_sort to replace std::sort. +// Remember also to fix sherpa-ncnn template std::vector TopkIndex(const T *vec, int32_t size, int32_t topk) { std::vector vec_index(size); diff --git a/sherpa-onnx/csrc/offline-lm-config.cc b/sherpa-onnx/csrc/offline-lm-config.cc new file mode 100644 index 00000000..429e5144 --- /dev/null +++ b/sherpa-onnx/csrc/offline-lm-config.cc @@ -0,0 +1,38 @@ +// sherpa-onnx/csrc/offline-lm-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-lm-config.h" + +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineLMConfig::Register(ParseOptions *po) { + po->Register("lm", &model, "Path to LM model."); + po->Register("lm-scale", &scale, "LM scale."); +} + +bool OfflineLMConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("%s does not exist", model.c_str()); + return false; + } + + return true; +} + +std::string OfflineLMConfig::ToString() const { + std::ostringstream os; + + os << "OfflineLMConfig("; + os << "model=\"" << model << "\", "; + os << "scale=" << scale << ")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-lm-config.h b/sherpa-onnx/csrc/offline-lm-config.h new file mode 100644 index 00000000..4a1044af --- /dev/null +++ b/sherpa-onnx/csrc/offline-lm-config.h @@ -0,0 +1,33 @@ +// sherpa-onnx/csrc/offline-lm-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineLMConfig { + // path to the onnx model + std::string model; + + // LM scale + float scale = 1.0; + + OfflineLMConfig() = default; + + OfflineLMConfig(const std::string &model, float scale) + : model(model), scale(scale) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-lm.cc b/sherpa-onnx/csrc/offline-lm.cc new file mode 100644 index 00000000..f76dcfd6 --- /dev/null +++ b/sherpa-onnx/csrc/offline-lm.cc @@ -0,0 +1,71 @@ +// sherpa-onnx/csrc/offline-lm.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-lm.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/offline-rnn-lm.h" + +namespace sherpa_onnx { + +std::unique_ptr OfflineLM::Create(const OfflineLMConfig &config) { + return std::make_unique(config); +} + +void OfflineLM::ComputeLMScore(float scale, int32_t context_size, + std::vector *hyps) { + // compute the max token seq so that we know how much space to allocate + int32_t max_token_seq = 0; + int32_t num_hyps = 0; + + // we subtract context_size below since each token sequence is prepended + // with context_size blanks + for (const auto &h : *hyps) { + num_hyps += h.Size(); + for (const auto &t : h) { + max_token_seq = + std::max(max_token_seq, t.second.ys.size() - context_size); + } + } + + Ort::AllocatorWithDefaultOptions allocator; + std::array x_shape{num_hyps, max_token_seq}; + Ort::Value x = Ort::Value::CreateTensor(allocator, x_shape.data(), + x_shape.size()); + + std::array x_lens_shape{num_hyps}; + Ort::Value x_lens = Ort::Value::CreateTensor( + allocator, x_lens_shape.data(), x_lens_shape.size()); + + int64_t *p = x.GetTensorMutableData(); + std::fill(p, p + num_hyps * max_token_seq, 0); + + int64_t *p_lens = x_lens.GetTensorMutableData(); + + for (const auto &h : *hyps) { + for (const auto &t : h) { + const auto &ys = t.second.ys; + int32_t len = ys.size() - context_size; + std::copy(ys.begin() + context_size, ys.end(), p); + *p_lens = len; + + p += max_token_seq; + ++p_lens; + } + } + auto negative_loglike = Rescore(std::move(x), std::move(x_lens)); + const float *p_nll = negative_loglike.GetTensorData(); + for (auto &h : *hyps) { + for (auto &t : h) { + // Use -scale here since we want to change negative loglike to loglike. + t.second.lm_log_prob = -scale * (*p_nll); + ++p_nll; + } + } +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-lm.h b/sherpa-onnx/csrc/offline-lm.h new file mode 100644 index 00000000..f99a8ad9 --- /dev/null +++ b/sherpa-onnx/csrc/offline-lm.h @@ -0,0 +1,46 @@ +// sherpa-onnx/csrc/offline-lm.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_LM_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_LM_H_ + +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/hypothesis.h" +#include "sherpa-onnx/csrc/offline-lm-config.h" + +namespace sherpa_onnx { + +class OfflineLM { + public: + virtual ~OfflineLM() = default; + + static std::unique_ptr Create(const OfflineLMConfig &config); + + /** Rescore a batch of sentences. + * + * @param x A 2-D tensor of shape (N, L) with data type int64. + * @param x_lens A 1-D tensor of shape (N,) with data type int64. + * It contains number of valid tokens in x before padding. + * @return Return a 1-D tensor of shape (N,) containing the negative log + * likelihood of each utterance. Its data type is float32. + * + * Caution: It returns negative log likelihood (nll), not log likelihood + */ + virtual Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) = 0; + + // This function updates hyp.lm_lob_prob of hyps. + // + // @param scale LM score + // @param context_size Context size of the transducer decoder model + // @param hyps It is changed in-place. + void ComputeLMScore(float scale, int32_t context_size, + std::vector *hyps); +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_LM_H_ diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h index 750951fc..d8360dce 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h @@ -16,6 +16,7 @@ #include "sherpa-onnx/csrc/offline-transducer-decoder.h" #include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h" #include "sherpa-onnx/csrc/offline-transducer-model.h" +#include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h" #include "sherpa-onnx/csrc/pad-sequence.h" #include "sherpa-onnx/csrc/symbol-table.h" @@ -57,8 +58,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { decoder_ = std::make_unique(model_.get()); } else if (config_.decoding_method == "modified_beam_search") { - SHERPA_ONNX_LOGE("TODO: modified_beam_search is to be implemented"); - exit(-1); + if (!config_.lm_config.model.empty()) { + lm_ = OfflineLM::Create(config.lm_config); + } + + decoder_ = std::make_unique( + model_.get(), lm_.get(), config_.max_active_paths, + config_.lm_config.scale); } else { SHERPA_ONNX_LOGE("Unsupported decoding method: %s", config_.decoding_method.c_str()); @@ -127,6 +133,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { SymbolTable symbol_table_; std::unique_ptr model_; std::unique_ptr decoder_; + std::unique_ptr lm_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-recognizer.cc b/sherpa-onnx/csrc/offline-recognizer.cc index e0e3a651..c5daa17e 100644 --- a/sherpa-onnx/csrc/offline-recognizer.cc +++ b/sherpa-onnx/csrc/offline-recognizer.cc @@ -8,6 +8,7 @@ #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-lm-config.h" #include "sherpa-onnx/csrc/offline-recognizer-impl.h" namespace sherpa_onnx { @@ -15,13 +16,28 @@ namespace sherpa_onnx { void OfflineRecognizerConfig::Register(ParseOptions *po) { feat_config.Register(po); model_config.Register(po); + lm_config.Register(po); - po->Register("decoding-method", &decoding_method, - "decoding method," - "Valid values: greedy_search."); + po->Register( + "decoding-method", &decoding_method, + "decoding method," + "Valid values: greedy_search, modified_beam_search. " + "modified_beam_search is applicable only for transducer models."); + + po->Register("max-active-paths", &max_active_paths, + "Used only when decoding_method is modified_beam_search"); } bool OfflineRecognizerConfig::Validate() const { + if (decoding_method == "modified_beam_search" && !lm_config.model.empty()) { + if (max_active_paths <= 0) { + SHERPA_ONNX_LOGE("max_active_paths is less than 0! Given: %d", + max_active_paths); + return false; + } + if (!lm_config.Validate()) return false; + } + return model_config.Validate(); } @@ -31,7 +47,9 @@ std::string OfflineRecognizerConfig::ToString() const { os << "OfflineRecognizerConfig("; os << "feat_config=" << feat_config.ToString() << ", "; os << "model_config=" << model_config.ToString() << ", "; - os << "decoding_method=\"" << decoding_method << "\")"; + os << "lm_config=" << lm_config.ToString() << ", "; + os << "decoding_method=\"" << decoding_method << "\", "; + os << "max_active_paths=" << max_active_paths << ")"; return os.str(); } diff --git a/sherpa-onnx/csrc/offline-recognizer.h b/sherpa-onnx/csrc/offline-recognizer.h index 9cfd7d58..d6fcb390 100644 --- a/sherpa-onnx/csrc/offline-recognizer.h +++ b/sherpa-onnx/csrc/offline-recognizer.h @@ -9,6 +9,7 @@ #include #include +#include "sherpa-onnx/csrc/offline-lm-config.h" #include "sherpa-onnx/csrc/offline-model-config.h" #include "sherpa-onnx/csrc/offline-stream.h" #include "sherpa-onnx/csrc/offline-transducer-model-config.h" @@ -21,18 +22,24 @@ struct OfflineRecognitionResult; struct OfflineRecognizerConfig { OfflineFeatureExtractorConfig feat_config; OfflineModelConfig model_config; + OfflineLMConfig lm_config; std::string decoding_method = "greedy_search"; + int32_t max_active_paths = 4; // only greedy_search is implemented // TODO(fangjun): Implement modified_beam_search OfflineRecognizerConfig() = default; OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config, const OfflineModelConfig &model_config, - const std::string &decoding_method) + const OfflineLMConfig &lm_config, + const std::string &decoding_method, + int32_t max_active_paths) : feat_config(feat_config), model_config(model_config), - decoding_method(decoding_method) {} + lm_config(lm_config), + decoding_method(decoding_method), + max_active_paths(max_active_paths) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-rnn-lm.cc b/sherpa-onnx/csrc/offline-rnn-lm.cc new file mode 100644 index 00000000..a16118a7 --- /dev/null +++ b/sherpa-onnx/csrc/offline-rnn-lm.cc @@ -0,0 +1,74 @@ +// sherpa-onnx/csrc/offline-rnn-lm.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-rnn-lm.h" + +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +class OfflineRnnLM::Impl { + public: + explicit Impl(const OfflineLMConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_{}, + allocator_{} { + Init(config); + } + + Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) { + std::array inputs = {std::move(x), std::move(x_lens)}; + + auto out = + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); + + return std::move(out[0]); + } + + private: + void Init(const OfflineLMConfig &config) { + auto buf = ReadFile(config_.model); + + sess_ = std::make_unique(env_, buf.data(), buf.size(), + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + } + + private: + OfflineLMConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; +}; + +OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config) + : impl_(std::make_unique(config)) {} + +OfflineRnnLM::~OfflineRnnLM() = default; + +Ort::Value OfflineRnnLM::Rescore(Ort::Value x, Ort::Value x_lens) { + return impl_->Rescore(std::move(x), std::move(x_lens)); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-rnn-lm.h b/sherpa-onnx/csrc/offline-rnn-lm.h new file mode 100644 index 00000000..e9f8cc97 --- /dev/null +++ b/sherpa-onnx/csrc/offline-rnn-lm.h @@ -0,0 +1,41 @@ +// sherpa-onnx/csrc/offline-rnn-lm.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_ + +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-lm-config.h" +#include "sherpa-onnx/csrc/offline-lm.h" + +namespace sherpa_onnx { + +class OfflineRnnLM : public OfflineLM { + public: + ~OfflineRnnLM() override; + + explicit OfflineRnnLM(const OfflineLMConfig &config); + + /** Rescore a batch of sentences. + * + * @param x A 2-D tensor of shape (N, L) with data type int64. + * @param x_lens A 1-D tensor of shape (N,) with data type int64. + * It contains number of valid tokens in x before padding. + * @return Return a 1-D tensor of shape (N,) containing the log likelihood + * of each utterance. Its data type is float32. + * + * Caution: It returns log likelihood, not negative log likelihood (nll). + */ + Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_ diff --git a/sherpa-onnx/csrc/offline-transducer-model.cc b/sherpa-onnx/csrc/offline-transducer-model.cc index 6ecb94f8..f5d8e773 100644 --- a/sherpa-onnx/csrc/offline-transducer-model.cc +++ b/sherpa-onnx/csrc/offline-transducer-model.cc @@ -95,6 +95,30 @@ class OfflineTransducerModel::Impl { std::copy(begin, end, p); p += context_size; } + + return decoder_input; + } + + Ort::Value BuildDecoderInput(const std::vector &results, + int32_t end_index) const { + assert(end_index <= results.size()); + + int32_t batch_size = end_index; + int32_t context_size = ContextSize(); + std::array shape{batch_size, context_size}; + + Ort::Value decoder_input = Ort::Value::CreateTensor( + Allocator(), shape.data(), shape.size()); + int64_t *p = decoder_input.GetTensorMutableData(); + + for (int32_t i = 0; i != batch_size; ++i) { + const auto &r = results[i]; + const int64_t *begin = r.ys.data() + r.ys.size() - context_size; + const int64_t *end = r.ys.data() + r.ys.size(); + std::copy(begin, end, p); + p += context_size; + } + return decoder_input; } @@ -234,4 +258,9 @@ Ort::Value OfflineTransducerModel::BuildDecoderInput( return impl_->BuildDecoderInput(results, end_index); } +Ort::Value OfflineTransducerModel::BuildDecoderInput( + const std::vector &results, int32_t end_index) const { + return impl_->BuildDecoderInput(results, end_index); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-transducer-model.h b/sherpa-onnx/csrc/offline-transducer-model.h index 7f7d24d6..0b42a2fb 100644 --- a/sherpa-onnx/csrc/offline-transducer-model.h +++ b/sherpa-onnx/csrc/offline-transducer-model.h @@ -9,6 +9,7 @@ #include #include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/hypothesis.h" #include "sherpa-onnx/csrc/offline-model-config.h" namespace sherpa_onnx { @@ -79,13 +80,16 @@ class OfflineTransducerModel { * * @param results Current decoded results. * @param end_index We only use results[0:end_index] to build - * the decoder_input. + * the decoder_input. results[end_index] is not used. * @return Return a tensor of shape (results.size(), ContextSize()) */ Ort::Value BuildDecoderInput( const std::vector &results, int32_t end_index) const; + Ort::Value BuildDecoderInput(const std::vector &results, + int32_t end_index) const; + private: class Impl; std::unique_ptr impl_; diff --git a/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc new file mode 100644 index 00000000..c4aa2a75 --- /dev/null +++ b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc @@ -0,0 +1,163 @@ +// sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/hypothesis.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/packed-sequence.h" +#include "sherpa-onnx/csrc/slice.h" + +namespace sherpa_onnx { + +static std::vector GetHypsRowSplits( + const std::vector &hyps) { + std::vector row_splits; + row_splits.reserve(hyps.size() + 1); + + row_splits.push_back(0); + int32_t s = 0; + for (const auto &h : hyps) { + s += h.Size(); + row_splits.push_back(s); + } + + return row_splits; +} + +std::vector +OfflineTransducerModifiedBeamSearchDecoder::Decode( + Ort::Value encoder_out, Ort::Value encoder_out_length) { + PackedSequence packed_encoder_out = PackPaddedSequence( + model_->Allocator(), &encoder_out, &encoder_out_length); + + int32_t batch_size = + static_cast(packed_encoder_out.sorted_indexes.size()); + + int32_t vocab_size = model_->VocabSize(); + int32_t context_size = model_->ContextSize(); + + std::vector blanks(context_size, 0); + Hypotheses blank_hyp({{blanks, 0}}); + + std::deque finalized; + std::vector cur(batch_size, blank_hyp); + std::vector prev; + + int32_t start = 0; + int32_t t = 0; + for (auto n : packed_encoder_out.batch_sizes) { + Ort::Value cur_encoder_out = packed_encoder_out.Get(start, n); + start += n; + + if (n < static_cast(cur.size())) { + for (int32_t k = static_cast(cur.size()) - 1; k >= n; --k) { + finalized.push_front(std::move(cur[k])); + } + + cur.erase(cur.begin() + n, cur.end()); + } // if (n < static_cast(cur.size())) + + // Due to merging paths with identical token sequences, + // not all utterances have "max_active_paths" paths. + auto hyps_row_splits = GetHypsRowSplits(cur); + int32_t num_hyps = hyps_row_splits.back(); + + prev.clear(); + prev.reserve(num_hyps); + + for (auto &hyps : cur) { + for (auto &h : hyps) { + prev.push_back(std::move(h.second)); + } + } + cur.clear(); + cur.reserve(n); + + auto decoder_input = model_->BuildDecoderInput(prev, num_hyps); + // decoder_input shape: (num_hyps, context_size) + + auto decoder_out = model_->RunDecoder(std::move(decoder_input)); + // decoder_out is (num_hyps, joiner_dim) + + cur_encoder_out = + Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits); + // now cur_encoder_out is of shape (num_hyps, joiner_dim) + + Ort::Value logit = model_->RunJoiner( + std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); + + float *p_logit = logit.GetTensorMutableData(); + LogSoftmax(p_logit, vocab_size, num_hyps); + + // now p_logit contains log_softmax output, we rename it to p_logprob + // to match what it actually contains + float *p_logprob = p_logit; + + // add log_prob of each hypothesis to p_logprob before taking top_k + for (int32_t i = 0; i != num_hyps; ++i) { + float log_prob = prev[i].log_prob; + for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) { + *p_logprob += log_prob; + } + } + p_logprob = p_logit; // we changed p_logprob in the above for loop + + // Now compute top_k for each utterance + for (int32_t i = 0; i != n; ++i) { + int32_t start = hyps_row_splits[i]; + int32_t end = hyps_row_splits[i + 1]; + auto topk = + TopkIndex(p_logprob, vocab_size * (end - start), max_active_paths_); + + Hypotheses hyps; + for (auto k : topk) { + int32_t hyp_index = k / vocab_size + start; + int32_t new_token = k % vocab_size; + Hypothesis new_hyp = prev[hyp_index]; + + if (new_token != 0) { + // blank id is fixed to 0 + new_hyp.ys.push_back(new_token); + new_hyp.timestamps.push_back(t); + } + + new_hyp.log_prob = p_logprob[k]; + hyps.Add(std::move(new_hyp)); + } // for (auto k : topk) + p_logprob += (end - start) * vocab_size; + cur.push_back(std::move(hyps)); + } // for (int32_t i = 0; i != n; ++i) + + ++t; + } // for (auto n : packed_encoder_out.batch_sizes) + + for (auto &h : finalized) { + cur.push_back(std::move(h)); + } + + if (lm_) { + // use LM for rescoring + lm_->ComputeLMScore(lm_scale_, context_size, &cur); + } + + std::vector unsorted_ans(batch_size); + for (int32_t i = 0; i != batch_size; ++i) { + Hypothesis hyp = cur[i].GetMostProbable(true); + + auto &r = unsorted_ans[packed_encoder_out.sorted_indexes[i]]; + + // strip leading blanks + r.tokens = {hyp.ys.begin() + context_size, hyp.ys.end()}; + r.timestamps = std::move(hyp.timestamps); + } + + return unsorted_ans; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h new file mode 100644 index 00000000..5f40dc29 --- /dev/null +++ b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h @@ -0,0 +1,41 @@ +// sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_ + +#include + +#include "sherpa-onnx/csrc/offline-lm.h" +#include "sherpa-onnx/csrc/offline-transducer-decoder.h" +#include "sherpa-onnx/csrc/offline-transducer-model.h" + +namespace sherpa_onnx { + +class OfflineTransducerModifiedBeamSearchDecoder + : public OfflineTransducerDecoder { + public: + OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model, + OfflineLM *lm, + int32_t max_active_paths, + float lm_scale) + : model_(model), + lm_(lm), + max_active_paths_(max_active_paths), + lm_scale_(lm_scale) {} + + std::vector Decode( + Ort::Value encoder_out, Ort::Value encoder_out_length) override; + + private: + OfflineTransducerModel *model_; // Not owned + OfflineLM *lm_; // Not owned; may be nullptr + + int32_t max_active_paths_; + float lm_scale_; // used only when lm_ is not nullptr +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_ diff --git a/sherpa-onnx/csrc/online-lm-config.cc b/sherpa-onnx/csrc/online-lm-config.cc new file mode 100644 index 00000000..80f597e7 --- /dev/null +++ b/sherpa-onnx/csrc/online-lm-config.cc @@ -0,0 +1,38 @@ +// sherpa-onnx/csrc/online-lm-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-lm-config.h" + +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OnlineLMConfig::Register(ParseOptions *po) { + po->Register("lm", &model, "Path to LM model."); + po->Register("lm-scale", &scale, "LM scale."); +} + +bool OnlineLMConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("%s does not exist", model.c_str()); + return false; + } + + return true; +} + +std::string OnlineLMConfig::ToString() const { + std::ostringstream os; + + os << "OnlineLMConfig("; + os << "model=\"" << model << "\", "; + os << "scale=" << scale << ")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-lm-config.h b/sherpa-onnx/csrc/online-lm-config.h new file mode 100644 index 00000000..29687764 --- /dev/null +++ b/sherpa-onnx/csrc/online-lm-config.h @@ -0,0 +1,33 @@ +// sherpa-onnx/csrc/online-lm-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_ +#define SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OnlineLMConfig { + // path to the onnx model + std::string model; + + // LM scale + float scale = 1.0; + + OnlineLMConfig() = default; + + OnlineLMConfig(const std::string &model, float scale) + : model(model), scale(scale) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_ diff --git a/sherpa-onnx/csrc/online-lm.h b/sherpa-onnx/csrc/online-lm.h new file mode 100644 index 00000000..9e2dbaf1 --- /dev/null +++ b/sherpa-onnx/csrc/online-lm.h @@ -0,0 +1,64 @@ +// sherpa-onnx/csrc/online-lm.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_LM_H_ +#define SHERPA_ONNX_CSRC_ONLINE_LM_H_ + +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/hypothesis.h" +#include "sherpa-onnx/csrc/online-lm-config.h" + +namespace sherpa_onnx { + +class OnlineLM { + public: + virtual ~OnlineLM() = default; + + static std::unique_ptr Create(const OnlineLMConfig &config); + + virtual std::vector GetInitStates() = 0; + + /** Rescore a batch of sentences. + * + * @param x A 2-D tensor of shape (N, L) with data type int64. + * @param y A 2-D tensor of shape (N, L) with data type int64. + * @param states It contains the states for the LM model + * @return Return a pair containingo + * - negative loglike + * - updated states + * + * Caution: It returns negative log likelihood (nll), not log likelihood + */ + std::pair> Ort::Value Rescore( + Ort::Value x, Ort::Value y, std::vector states) = 0; + + // This function updates hyp.lm_lob_prob of hyps. + // + // @param scale LM score + // @param context_size Context size of the transducer decoder model + // @param hyps It is changed in-place. + void ComputeLMScore(float scale, int32_t context_size, + std::vector *hyps); + /** TODO(fangjun): + * + * 1. Add two fields to Hypothesis + * (a) int32_t lm_cur_pos = 0; number of scored tokens so far + * (b) std::vector lm_states; + * 2. When we want to score a hypothesis, we construct x and y as follows: + * + * std::vector x = {hyp.ys.begin() + context_size + lm_cur_pos, + * hyp.ys.end() - 1}; + * std::vector y = {hyp.ys.begin() + context_size + lm_cur_pos + 1 + * hyp.ys.end()}; + * hyp.lm_cur_pos += hyp.ys.size() - context_size - lm_cur_pos; + */ +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_LM_H_ diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc index bc6f8553..b3ca7c2a 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc @@ -36,38 +36,6 @@ static void UseCachedDecoderOut( } } -static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, - const std::vector &hyps_num_split) { - std::vector cur_encoder_out_shape = - cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape(); - - std::array ans_shape{hyps_num_split.back(), - cur_encoder_out_shape[1]}; - - Ort::Value ans = Ort::Value::CreateTensor(allocator, ans_shape.data(), - ans_shape.size()); - - const float *src = cur_encoder_out->GetTensorData(); - float *dst = ans.GetTensorMutableData(); - int32_t batch_size = static_cast(hyps_num_split.size()) - 1; - for (int32_t b = 0; b != batch_size; ++b) { - int32_t cur_stream_hyps_num = hyps_num_split[b + 1] - hyps_num_split[b]; - for (int32_t i = 0; i != cur_stream_hyps_num; ++i) { - std::copy(src, src + cur_encoder_out_shape[1], dst); - dst += cur_encoder_out_shape[1]; - } - src += cur_encoder_out_shape[1]; - } - return ans; -} - -static void LogSoftmax(float *in, int32_t w, int32_t h) { - for (int32_t i = 0; i != h; ++i) { - LogSoftmax(in, w); - in += w; - } -} - OnlineTransducerDecoderResult OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const { int32_t context_size = model_->ContextSize(); diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc index 133d3df3..84ea8b26 100644 --- a/sherpa-onnx/csrc/onnx-utils.cc +++ b/sherpa-onnx/csrc/onnx-utils.cc @@ -193,4 +193,29 @@ std::vector ReadFile(AAssetManager *mgr, const std::string &filename) { } #endif +Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, + const std::vector &hyps_num_split) { + std::vector cur_encoder_out_shape = + cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape(); + + std::array ans_shape{hyps_num_split.back(), + cur_encoder_out_shape[1]}; + + Ort::Value ans = Ort::Value::CreateTensor(allocator, ans_shape.data(), + ans_shape.size()); + + const float *src = cur_encoder_out->GetTensorData(); + float *dst = ans.GetTensorMutableData(); + int32_t batch_size = static_cast(hyps_num_split.size()) - 1; + for (int32_t b = 0; b != batch_size; ++b) { + int32_t cur_stream_hyps_num = hyps_num_split[b + 1] - hyps_num_split[b]; + for (int32_t i = 0; i != cur_stream_hyps_num; ++i) { + std::copy(src, src + cur_encoder_out_shape[1], dst); + dst += cur_encoder_out_shape[1]; + } + src += cur_encoder_out_shape[1]; + } + return ans; +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/onnx-utils.h b/sherpa-onnx/csrc/onnx-utils.h index 8bdba7a4..68aca4d4 100644 --- a/sherpa-onnx/csrc/onnx-utils.h +++ b/sherpa-onnx/csrc/onnx-utils.h @@ -86,6 +86,9 @@ std::vector ReadFile(const std::string &filename); std::vector ReadFile(AAssetManager *mgr, const std::string &filename); #endif +// TODO(fangjun): Document it +Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, + const std::vector &hyps_num_split); } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline.cc b/sherpa-onnx/csrc/sherpa-onnx-offline.cc index 601da3ea..b6d8916f 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline.cc @@ -111,6 +111,9 @@ for a list of pre-trained models to download. fprintf(stderr, "num threads: %d\n", config.model_config.num_threads); fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str()); + if (config.decoding_method == "modified_beam_search") { + fprintf(stderr, "max active paths: %d\n", config.max_active_paths); + } fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); float rtf = elapsed_seconds / duration; diff --git a/sherpa-onnx/csrc/sherpa-onnx.cc b/sherpa-onnx/csrc/sherpa-onnx.cc index 04fcbeab..578e8ba1 100644 --- a/sherpa-onnx/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/csrc/sherpa-onnx.cc @@ -117,6 +117,9 @@ for a list of pre-trained models to download. fprintf(stderr, "num threads: %d\n", config.model_config.num_threads); fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str()); + if (config.decoding_method == "modified_beam_search") { + fprintf(stderr, "max active paths: %d\n", config.max_active_paths); + } fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); float rtf = elapsed_seconds / duration; diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index f32735db..5e5fb23f 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -4,6 +4,7 @@ pybind11_add_module(_sherpa_onnx display.cc endpoint.cc features.cc + offline-lm-config.cc offline-model-config.cc offline-nemo-enc-dec-ctc-model-config.cc offline-paraformer-model-config.cc diff --git a/sherpa-onnx/python/csrc/offline-lm-config.cc b/sherpa-onnx/python/csrc/offline-lm-config.cc new file mode 100644 index 00000000..a5f58cfd --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-lm-config.cc @@ -0,0 +1,23 @@ +// sherpa-onnx/python/csrc/offline-lm-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/offline-lm-config.h" + +#include + +#include "sherpa-onnx//csrc/offline-lm-config.h" + +namespace sherpa_onnx { + +void PybindOfflineLMConfig(py::module *m) { + using PyClass = OfflineLMConfig; + py::class_(*m, "OfflineLMConfig") + .def(py::init(), py::arg("model"), + py::arg("scale")) + .def_readwrite("model", &PyClass::model) + .def_readwrite("scale", &PyClass::scale) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-lm-config.h b/sherpa-onnx/python/csrc/offline-lm-config.h new file mode 100644 index 00000000..35b90af4 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-lm-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-lm-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflineLMConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/offline-recognizer.cc b/sherpa-onnx/python/csrc/offline-recognizer.cc index 8a7779ba..7458181c 100644 --- a/sherpa-onnx/python/csrc/offline-recognizer.cc +++ b/sherpa-onnx/python/csrc/offline-recognizer.cc @@ -15,12 +15,17 @@ static void PybindOfflineRecognizerConfig(py::module *m) { using PyClass = OfflineRecognizerConfig; py::class_(*m, "OfflineRecognizerConfig") .def(py::init(), + const OfflineModelConfig &, const OfflineLMConfig &, + const std::string &, int32_t>(), py::arg("feat_config"), py::arg("model_config"), - py::arg("decoding_method")) + py::arg("lm_config") = OfflineLMConfig(), + py::arg("decoding_method") = "greedy_search", + py::arg("max_active_paths") = 4) .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) + .def_readwrite("lm_config", &PyClass::lm_config) .def_readwrite("decoding_method", &PyClass::decoding_method) + .def_readwrite("max_active_paths", &PyClass::max_active_paths) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index e099bd79..7d3b0d05 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -7,6 +7,7 @@ #include "sherpa-onnx/python/csrc/display.h" #include "sherpa-onnx/python/csrc/endpoint.h" #include "sherpa-onnx/python/csrc/features.h" +#include "sherpa-onnx/python/csrc/offline-lm-config.h" #include "sherpa-onnx/python/csrc/offline-model-config.h" #include "sherpa-onnx/python/csrc/offline-recognizer.h" #include "sherpa-onnx/python/csrc/offline-stream.h" @@ -28,6 +29,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { PybindDisplay(&m); PybindOfflineStream(&m); + PybindOfflineLMConfig(&m); PybindOfflineModelConfig(&m); PybindOfflineRecognizer(&m); }