diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt index a653b639..f8822a79 100644 --- a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt +++ b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt @@ -182,9 +182,10 @@ class MainActivity : AppCompatActivity() { val config = OnlineRecognizerConfig( featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80), modelConfig = getModelConfig(type = type)!!, + lmConfig = getOnlineLMConfig(type = type), endpointConfig = getEndpointConfig(), enableEndpoint = true, - decodingMethod = "greedy_search", + decodingMethod = "modified_beam_search", maxActivePaths = 4, ) diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt index 47a806ee..d3133829 100644 --- a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt +++ b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt @@ -23,6 +23,11 @@ data class OnlineTransducerModelConfig( var debug: Boolean = false, ) +data class OnlineLMConfig( + var model: String = "", + var scale: Float = 0.5f, +) + data class FeatureConfig( var sampleRate: Int = 16000, var featureDim: Int = 80, @@ -31,6 +36,7 @@ data class FeatureConfig( data class OnlineRecognizerConfig( var featConfig: FeatureConfig = FeatureConfig(), var modelConfig: OnlineTransducerModelConfig, + var lmConfig: OnlineLMConfig = OnlineLMConfig(), var endpointConfig: EndpointConfig = EndpointConfig(), var enableEndpoint: Boolean = true, var decodingMethod: String = "greedy_search", @@ -151,6 +157,32 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { return null; } +/* +Please see +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models. + +We only add a few here.
Please change the following code +to add your own LM model. (It should be straightforward to train a new NN LM model +by following the code, https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py) + +@param type +0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English) + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english + */ +fun getOnlineLMConfig(type : Int): OnlineLMConfig { + when (type) { + 0 -> { + val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20" + return OnlineLMConfig( + model = "$modelDir/with-state-epoch-99-avg-1.int8.onnx", + scale = 0.5f, + ) + } + } + return OnlineLMConfig(); +} + fun getEndpointConfig(): EndpointConfig { return EndpointConfig( rule1 = EndpointRule(false, 2.4f, 0.0f), diff --git a/kotlin-api-examples/Main.kt b/kotlin-api-examples/Main.kt index abf1cf2f..e62ad488 100644 --- a/kotlin-api-examples/Main.kt +++ b/kotlin-api-examples/Main.kt @@ -22,8 +22,11 @@ fun main() { var endpointConfig = EndpointConfig() + var lmConfig = OnlineLMConfig() + var config = OnlineRecognizerConfig( modelConfig = modelConfig, + lmConfig = lmConfig, featConfig = featConfig, endpointConfig = endpointConfig, enableEndpoint = true, diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 883eff96..3216af65 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -34,9 +34,11 @@ set(sources offline-transducer-model-config.cc offline-transducer-model.cc offline-transducer-modified-beam-search-decoder.cc + online-lm.cc online-lm-config.cc online-lstm-transducer-model.cc online-recognizer.cc + online-rnn-lm.cc online-stream.cc online-transducer-decoder.cc online-transducer-greedy-search-decoder.cc diff --git a/sherpa-onnx/csrc/hypothesis.cc b/sherpa-onnx/csrc/hypothesis.cc index f07a161c..3dbd33ba 100644 --- 
a/sherpa-onnx/csrc/hypothesis.cc +++ b/sherpa-onnx/csrc/hypothesis.cc @@ -1,6 +1,6 @@ /** * Copyright (c) 2023 Xiaomi Corporation - * + * Copyright (c) 2023 Pingfeng Luo */ #include "sherpa-onnx/csrc/hypothesis.h" diff --git a/sherpa-onnx/csrc/hypothesis.h b/sherpa-onnx/csrc/hypothesis.h index 198e40a7..a0097f52 100644 --- a/sherpa-onnx/csrc/hypothesis.h +++ b/sherpa-onnx/csrc/hypothesis.h @@ -1,5 +1,6 @@ /** * Copyright (c) 2023 Xiaomi Corporation + * Copyright (c) 2023 Pingfeng Luo * */ @@ -12,7 +13,9 @@ #include #include +#include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/math.h" +#include "sherpa-onnx/csrc/onnx-utils.h" namespace sherpa_onnx { @@ -31,6 +34,13 @@ struct Hypothesis { // LM log prob if any. double lm_log_prob = 0; + int32_t cur_scored_pos = 0; // cur scored tokens by RNN LM + std::vector nn_lm_states; + + // TODO(fangjun): Make it configurable + // the minimum of tokens in a chunk for streaming RNN LM + int32_t lm_rescore_min_chunk = 2; // a const + int32_t num_trailing_blanks = 0; Hypothesis() = default; diff --git a/sherpa-onnx/csrc/math.h b/sherpa-onnx/csrc/math.h index d63c3ed3..086b064e 100644 --- a/sherpa-onnx/csrc/math.h +++ b/sherpa-onnx/csrc/math.h @@ -96,17 +96,15 @@ void LogSoftmax(T *in, int32_t w, int32_t h) { } } -// TODO(fangjun): use std::partial_sort to replace std::sort. 
-// Remember also to fix sherpa-ncnn template std::vector TopkIndex(const T *vec, int32_t size, int32_t topk) { std::vector vec_index(size); std::iota(vec_index.begin(), vec_index.end(), 0); - std::sort(vec_index.begin(), vec_index.end(), - [vec](int32_t index_1, int32_t index_2) { - return vec[index_1] > vec[index_2]; - }); + std::partial_sort(vec_index.begin(), vec_index.begin() + std::min(size, topk), + vec_index.end(), [vec](int32_t index_1, int32_t index_2) { + return vec[index_1] > vec[index_2]; + }); int32_t k_num = std::min(size, topk); std::vector index(vec_index.begin(), vec_index.begin() + k_num); diff --git a/sherpa-onnx/csrc/offline-lm-config.h b/sherpa-onnx/csrc/offline-lm-config.h index 4a1044af..1a35dc85 100644 --- a/sherpa-onnx/csrc/offline-lm-config.h +++ b/sherpa-onnx/csrc/offline-lm-config.h @@ -15,7 +15,7 @@ struct OfflineLMConfig { std::string model; // LM scale - float scale = 1.0; + float scale = 0.5; OfflineLMConfig() = default; diff --git a/sherpa-onnx/csrc/online-lm-config.h b/sherpa-onnx/csrc/online-lm-config.h index 29687764..8bb4ab53 100644 --- a/sherpa-onnx/csrc/online-lm-config.h +++ b/sherpa-onnx/csrc/online-lm-config.h @@ -15,7 +15,7 @@ struct OnlineLMConfig { std::string model; // LM scale - float scale = 1.0; + float scale = 0.5; OnlineLMConfig() = default; diff --git a/sherpa-onnx/csrc/online-lm.cc b/sherpa-onnx/csrc/online-lm.cc new file mode 100644 index 00000000..11283e11 --- /dev/null +++ b/sherpa-onnx/csrc/online-lm.cc @@ -0,0 +1,92 @@ +// sherpa-onnx/csrc/online-lm.cc +// +// Copyright (c) 2023 Pingfeng Luo +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-lm.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/online-rnn-lm.h" + +namespace sherpa_onnx { + +static std::vector Convert(std::vector values) { + std::vector ans; + ans.reserve(values.size()); + + for (auto &v : values) { + ans.emplace_back(std::move(v)); + } + + return ans; +} + +static std::vector Convert(std::vector values) { +
std::vector ans; + ans.reserve(values.size()); + + for (auto &v : values) { + ans.emplace_back(std::move(v.value)); + } + + return ans; +} + +std::unique_ptr OnlineLM::Create(const OnlineLMConfig &config) { + return std::make_unique(config); +} + +void OnlineLM::ComputeLMScore(float scale, int32_t context_size, + std::vector *hyps) { + Ort::AllocatorWithDefaultOptions allocator; + + for (auto &hyp : *hyps) { + for (auto &h_m : hyp) { + auto &h = h_m.second; + auto &ys = h.ys; + const int32_t token_num_in_chunk = + ys.size() - context_size - h.cur_scored_pos - 1; + + if (token_num_in_chunk < 1) { + continue; + } + + if (h.nn_lm_states.empty()) { + h.nn_lm_states = Convert(GetInitStates()); + } + + if (token_num_in_chunk >= h.lm_rescore_min_chunk) { + std::array x_shape{1, token_num_in_chunk}; + // shape of x and y are same + Ort::Value x = Ort::Value::CreateTensor( + allocator, x_shape.data(), x_shape.size()); + Ort::Value y = Ort::Value::CreateTensor( + allocator, x_shape.data(), x_shape.size()); + int64_t *p_x = x.GetTensorMutableData(); + int64_t *p_y = y.GetTensorMutableData(); + std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1, + p_x); + std::copy(ys.begin() + context_size + h.cur_scored_pos + 1, ys.end(), + p_y); + + // streaming forward by NN LM + auto out = Rescore(std::move(x), std::move(y), + Convert(std::move(h.nn_lm_states))); + + // update NN LM score in hyp + const float *p_nll = out.first.GetTensorData(); + h.lm_log_prob = -scale * (*p_nll); + + // update NN LM states in hyp + h.nn_lm_states = Convert(std::move(out.second)); + + h.cur_scored_pos += token_num_in_chunk; + } + } + } +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-lm.h b/sherpa-onnx/csrc/online-lm.h index 9e2dbaf1..cc4a7de6 100644 --- a/sherpa-onnx/csrc/online-lm.h +++ b/sherpa-onnx/csrc/online-lm.h @@ -34,7 +34,7 @@ class OnlineLM { * * Caution: It returns negative log likelihood (nll), not log likelihood */ - std::pair> Ort::Value Rescore( + 
virtual std::pair> Rescore( Ort::Value x, Ort::Value y, std::vector states) = 0; // This function updates hyp.lm_lob_prob of hyps. @@ -44,19 +44,6 @@ class OnlineLM { // @param hyps It is changed in-place. void ComputeLMScore(float scale, int32_t context_size, std::vector *hyps); - /** TODO(fangjun): - * - * 1. Add two fields to Hypothesis - * (a) int32_t lm_cur_pos = 0; number of scored tokens so far - * (b) std::vector lm_states; - * 2. When we want to score a hypothesis, we construct x and y as follows: - * - * std::vector x = {hyp.ys.begin() + context_size + lm_cur_pos, - * hyp.ys.end() - 1}; - * std::vector y = {hyp.ys.begin() + context_size + lm_cur_pos + 1 - * hyp.ys.end()}; - * hyp.lm_cur_pos += hyp.ys.size() - context_size - lm_cur_pos; - */ }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index cc15487b..0cd68653 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -16,6 +16,8 @@ #include "nlohmann/json.hpp" #include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/online-lm.h" #include "sherpa-onnx/csrc/online-transducer-decoder.h" #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" #include "sherpa-onnx/csrc/online-transducer-model.h" @@ -80,6 +82,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { feat_config.Register(po); model_config.Register(po); endpoint_config.Register(po); + lm_config.Register(po); po->Register("enable-endpoint", &enable_endpoint, "True to enable endpoint detection. False to disable it."); @@ -91,6 +94,14 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { } bool OnlineRecognizerConfig::Validate() const { + if (decoding_method == "modified_beam_search" && !lm_config.model.empty()) { + if (max_active_paths <= 0) { + SHERPA_ONNX_LOGE("max_active_paths is less than 0! 
Given: %d", + max_active_paths); + return false; + } + if (!lm_config.Validate()) return false; + } return model_config.Validate(); } @@ -100,6 +111,7 @@ std::string OnlineRecognizerConfig::ToString() const { os << "OnlineRecognizerConfig("; os << "feat_config=" << feat_config.ToString() << ", "; os << "model_config=" << model_config.ToString() << ", "; + os << "lm_config=" << lm_config.ToString() << ", "; os << "endpoint_config=" << endpoint_config.ToString() << ", "; os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", "; os << "max_active_paths=" << max_active_paths << ", "; @@ -116,8 +128,13 @@ class OnlineRecognizer::Impl { sym_(config.model_config.tokens), endpoint_(config_.endpoint_config) { if (config.decoding_method == "modified_beam_search") { + if (!config_.lm_config.model.empty()) { + lm_ = OnlineLM::Create(config.lm_config); + } + decoder_ = std::make_unique( - model_.get(), config_.max_active_paths); + model_.get(), lm_.get(), config_.max_active_paths, + config_.lm_config.scale); } else if (config.decoding_method == "greedy_search") { decoder_ = std::make_unique(model_.get()); @@ -136,7 +153,8 @@ class OnlineRecognizer::Impl { endpoint_(config_.endpoint_config) { if (config.decoding_method == "modified_beam_search") { decoder_ = std::make_unique( - model_.get(), config_.max_active_paths); + model_.get(), lm_.get(), config_.max_active_paths, + config_.lm_config.scale); } else if (config.decoding_method == "greedy_search") { decoder_ = std::make_unique(model_.get()); @@ -246,6 +264,7 @@ class OnlineRecognizer::Impl { private: OnlineRecognizerConfig config_; std::unique_ptr model_; + std::unique_ptr lm_; std::unique_ptr decoder_; SymbolTable sym_; Endpoint endpoint_; diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index 6a59dc3e..f2427ed3 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -16,6 +16,7 @@ #include "sherpa-onnx/csrc/endpoint.h" #include 
"sherpa-onnx/csrc/features.h" +#include "sherpa-onnx/csrc/online-lm-config.h" #include "sherpa-onnx/csrc/online-stream.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h" #include "sherpa-onnx/csrc/parse-options.h" @@ -67,10 +68,11 @@ struct OnlineRecognizerResult { struct OnlineRecognizerConfig { FeatureExtractorConfig feat_config; OnlineTransducerModelConfig model_config; + OnlineLMConfig lm_config; EndpointConfig endpoint_config; bool enable_endpoint = true; - std::string decoding_method = "greedy_search"; + std::string decoding_method = "modified_beam_search"; // now support modified_beam_search and greedy_search int32_t max_active_paths = 4; // used only for modified_beam_search @@ -79,6 +81,7 @@ struct OnlineRecognizerConfig { OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, const OnlineTransducerModelConfig &model_config, + const OnlineLMConfig &lm_config, const EndpointConfig &endpoint_config, bool enable_endpoint, const std::string &decoding_method, diff --git a/sherpa-onnx/csrc/online-rnn-lm.cc b/sherpa-onnx/csrc/online-rnn-lm.cc new file mode 100644 index 00000000..611e0c40 --- /dev/null +++ b/sherpa-onnx/csrc/online-rnn-lm.cc @@ -0,0 +1,140 @@ +// sherpa-onnx/csrc/on-rnn-lm.cc +// +// Copyright (c) 2023 Pingfeng Luo +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-rnn-lm.h" + +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +class OnlineRnnLM::Impl { + public: + explicit Impl(const OnlineLMConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_{}, + allocator_{} { + Init(config); + } + + std::pair> Rescore( + Ort::Value x, Ort::Value y, std::vector states) { + std::array inputs = { + std::move(x), std::move(y), std::move(states[0]), std::move(states[1])}; + + auto out = + sess_->Run({}, 
input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); + + std::vector next_states; + next_states.reserve(2); + next_states.push_back(std::move(out[1])); + next_states.push_back(std::move(out[2])); + + return {std::move(out[0]), std::move(next_states)}; + } + + std::vector GetInitStates() const { + std::vector ans; + ans.reserve(init_states_.size()); + + for (const auto &s : init_states_) { + ans.emplace_back(Clone(allocator_, &s)); + } + + return ans; + } + + private: + void Init(const OnlineLMConfig &config) { + auto buf = ReadFile(config_.model); + + sess_ = std::make_unique(env_, buf.data(), buf.size(), + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(rnn_num_layers_, "num_layers"); + SHERPA_ONNX_READ_META_DATA(rnn_hidden_size_, "hidden_size"); + SHERPA_ONNX_READ_META_DATA(sos_id_, "sos_id"); + + ComputeInitStates(); + } + + void ComputeInitStates() { + constexpr int32_t kBatchSize = 1; + std::array h_shape{rnn_num_layers_, kBatchSize, + rnn_hidden_size_}; + std::array c_shape{rnn_num_layers_, kBatchSize, + rnn_hidden_size_}; + Ort::Value h = Ort::Value::CreateTensor(allocator_, h_shape.data(), + h_shape.size()); + Ort::Value c = Ort::Value::CreateTensor(allocator_, c_shape.data(), + c_shape.size()); + Fill(&h, 0); + Fill(&c, 0); + std::array x_shape{1, 1}; + // shape of x and y are same + Ort::Value x = Ort::Value::CreateTensor(allocator_, x_shape.data(), + x_shape.size()); + Ort::Value y = Ort::Value::CreateTensor(allocator_, x_shape.data(), + x_shape.size()); + *x.GetTensorMutableData() = sos_id_; + *y.GetTensorMutableData() = sos_id_; + + std::vector states; + states.push_back(std::move(h)); + states.push_back(std::move(c)); + auto pair = 
Rescore(std::move(x), std::move(y), std::move(states)); + + init_states_ = std::move(pair.second); + } + + private: + OnlineLMConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + std::vector init_states_; + + int32_t rnn_num_layers_ = 2; + int32_t rnn_hidden_size_ = 512; + int32_t sos_id_ = 1; +}; + +OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config) + : impl_(std::make_unique(config)) {} + +OnlineRnnLM::~OnlineRnnLM() = default; + +std::vector OnlineRnnLM::GetInitStates() { + return impl_->GetInitStates(); +} + +std::pair> OnlineRnnLM::Rescore( + Ort::Value x, Ort::Value y, std::vector states) { + return impl_->Rescore(std::move(x), std::move(y), std::move(states)); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-rnn-lm.h b/sherpa-onnx/csrc/online-rnn-lm.h new file mode 100644 index 00000000..fcb2b17e --- /dev/null +++ b/sherpa-onnx/csrc/online-rnn-lm.h @@ -0,0 +1,48 @@ +// sherpa-onnx/csrc/online-rnn-lm.h +// +// Copyright (c) 2023 Pingfeng Luo +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_RNN_LM_H_ +#define SHERPA_ONNX_CSRC_ONLINE_RNN_LM_H_ + +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/online-lm-config.h" +#include "sherpa-onnx/csrc/online-lm.h" + +namespace sherpa_onnx { + +class OnlineRnnLM : public OnlineLM { + public: + ~OnlineRnnLM() override; + + explicit OnlineRnnLM(const OnlineLMConfig &config); + + std::vector GetInitStates() override; + + /** Rescore a batch of sentences. + * + * @param x A 2-D tensor of shape (N, L) with data type int64. + * @param y A 2-D tensor of shape (N, L) with data type int64. 
+ * @param states It contains the states for the LM model + * @return Return a pair containing + * - negative loglike + * - updated states + * + * Caution: It returns negative log likelihood (nll), not log likelihood + */ + std::pair> Rescore( + Ort::Value x, Ort::Value y, std::vector states) override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_RNN_LM_H_ diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc index 46e5319a..1a0cf760 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc @@ -156,6 +156,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( } // for (int32_t b = 0; b != batch_size; ++b) } + if (lm_) { + lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur); + } + for (int32_t b = 0; b != batch_size; ++b) { auto &hyps = cur[b]; auto best_hyp = hyps.GetMostProbable(true); diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h index 86df4d72..5fbf6a31 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h @@ -8,6 +8,7 @@ #include +#include "sherpa-onnx/csrc/online-lm.h" #include "sherpa-onnx/csrc/online-transducer-decoder.h" #include "sherpa-onnx/csrc/online-transducer-model.h" @@ -17,8 +18,13 @@ class OnlineTransducerModifiedBeamSearchDecoder : public OnlineTransducerDecoder { public: OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model, - int32_t max_active_paths) - : model_(model), max_active_paths_(max_active_paths) {} + OnlineLM *lm, + int32_t max_active_paths, + float lm_scale) + : model_(model), + lm_(lm), + max_active_paths_(max_active_paths), +
lm_scale_(lm_scale) {} OnlineTransducerDecoderResult GetEmptyResult() const override; @@ -31,7 +37,10 @@ class OnlineTransducerModifiedBeamSearchDecoder private: OnlineTransducerModel *model_; // Not owned + OnlineLM *lm_; // Not owned + int32_t max_active_paths_; + float lm_scale_; // used only when lm_ is not nullptr }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc index 84ea8b26..883f5afd 100644 --- a/sherpa-onnx/csrc/onnx-utils.cc +++ b/sherpa-onnx/csrc/onnx-utils.cc @@ -1,13 +1,13 @@ // sherpa-onnx/csrc/onnx-utils.cc // // Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2023 Pingfeng Luo #include "sherpa-onnx/csrc/onnx-utils.h" #include #include #include #include -#include #if __ANDROID_API__ >= 9 #include "android/asset_manager.h" @@ -218,4 +218,31 @@ Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, return ans; } +CopyableOrtValue::CopyableOrtValue(const CopyableOrtValue &other) { + *this = other; +} + +CopyableOrtValue &CopyableOrtValue::operator=(const CopyableOrtValue &other) { + if (this == &other) { + return *this; + } + if (other.value) { + Ort::AllocatorWithDefaultOptions allocator; + value = Clone(allocator, &other.value); + } + return *this; +} + +CopyableOrtValue::CopyableOrtValue(CopyableOrtValue &&other) { + *this = std::move(other); +} + +CopyableOrtValue &CopyableOrtValue::operator=(CopyableOrtValue &&other) { + if (this == &other) { + return *this; + } + value = std::move(other.value); + return *this; +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/onnx-utils.h b/sherpa-onnx/csrc/onnx-utils.h index 68aca4d4..113fa5de 100644 --- a/sherpa-onnx/csrc/onnx-utils.h +++ b/sherpa-onnx/csrc/onnx-utils.h @@ -1,6 +1,7 @@ // sherpa-onnx/csrc/onnx-utils.h // // Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2023 Pingfeng Luo #ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_ #define SHERPA_ONNX_CSRC_ONNX_UTILS_H_ @@ -13,6 +14,7 @@ #include #include #include 
+#include #include #if __ANDROID_API__ >= 9 @@ -89,6 +91,24 @@ std::vector ReadFile(AAssetManager *mgr, const std::string &filename); // TODO(fangjun): Document it Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, const std::vector &hyps_num_split); + +struct CopyableOrtValue { + Ort::Value value{nullptr}; + + CopyableOrtValue() = default; + + /*explicit*/ CopyableOrtValue(Ort::Value v) // NOLINT + : value(std::move(v)) {} + + CopyableOrtValue(const CopyableOrtValue &other); + + CopyableOrtValue &operator=(const CopyableOrtValue &other); + + CopyableOrtValue(CopyableOrtValue &&other); + + CopyableOrtValue &operator=(CopyableOrtValue &&other); +}; + } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ diff --git a/sherpa-onnx/csrc/sherpa-onnx.cc b/sherpa-onnx/csrc/sherpa-onnx.cc index 578e8ba1..5cb6ca3e 100644 --- a/sherpa-onnx/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/csrc/sherpa-onnx.cc @@ -13,8 +13,9 @@ #include "sherpa-onnx/csrc/symbol-table.h" #include "sherpa-onnx/csrc/wave-reader.h" +// TODO(fangjun): Use ParseOptions as we are getting more args int main(int32_t argc, char *argv[]) { - if (argc < 6 || argc > 8) { + if (argc < 6 || argc > 9) { const char *usage = R"usage( Usage: ./bin/sherpa-onnx \ @@ -22,7 +23,7 @@ Usage: /path/to/encoder.onnx \ /path/to/decoder.onnx \ /path/to/joiner.onnx \ - /path/to/foo.wav [num_threads [decoding_method]] + /path/to/foo.wav [num_threads [decoding_method [/path/to/rnn_lm.onnx]]] Default value for num_threads is 2. Valid values for decoding_method: greedy_search (default), modified_beam_search. @@ -53,10 +54,12 @@ for a list of pre-trained models to download. 
if (argc == 7 && atoi(argv[6]) > 0) { config.model_config.num_threads = atoi(argv[6]); } - if (argc == 8) { config.decoding_method = argv[7]; } + if (argc == 9) { + config.lm_config.model = argv[8]; + } config.max_active_paths = 4; fprintf(stderr, "%s\n", config.ToString().c_str()); diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index 62165462..ab5cec5d 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -16,9 +16,8 @@ #if __ANDROID_API__ >= 9 #include "android/asset_manager.h" #include "android/asset_manager_jni.h" -#else -#include #endif +#include #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-recognizer.h" @@ -188,6 +187,21 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { fid = env->GetFieldID(model_config_cls, "debug", "Z"); ans.model_config.debug = env->GetBooleanField(model_config, fid); + //---------- rnn lm model config ---------- + fid = env->GetFieldID(cls, "lmConfig", + "Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;"); + jobject lm_model_config = env->GetObjectField(config, fid); + jclass lm_model_config_cls = env->GetObjectClass(lm_model_config); + + fid = env->GetFieldID(lm_model_config_cls, "model", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(lm_model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.lm_config.model = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(lm_model_config_cls, "scale", "F"); + ans.lm_config.scale = env->GetFloatField(lm_model_config, fid); + return ans; } diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index 5e5fb23f..ce62a36c 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -11,6 +11,7 @@ pybind11_add_module(_sherpa_onnx offline-recognizer.cc offline-stream.cc offline-transducer-model-config.cc + online-lm-config.cc online-recognizer.cc online-stream.cc online-transducer-model-config.cc diff --git 
a/sherpa-onnx/python/csrc/online-lm-config.cc b/sherpa-onnx/python/csrc/online-lm-config.cc new file mode 100644 index 00000000..f7097e49 --- /dev/null +++ b/sherpa-onnx/python/csrc/online-lm-config.cc @@ -0,0 +1,23 @@ +// sherpa-onnx/python/csrc/online-lm-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/online-lm-config.h" + +#include + +#include "sherpa-onnx//csrc/online-lm-config.h" + +namespace sherpa_onnx { + +void PybindOnlineLMConfig(py::module *m) { + using PyClass = OnlineLMConfig; + py::class_(*m, "OnlineLMConfig") + .def(py::init(), py::arg("model") = "", + py::arg("scale") = 0.5f) + .def_readwrite("model", &PyClass::model) + .def_readwrite("scale", &PyClass::scale) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/online-lm-config.h b/sherpa-onnx/python/csrc/online-lm-config.h new file mode 100644 index 00000000..d41030f8 --- /dev/null +++ b/sherpa-onnx/python/csrc/online-lm-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/online-lm-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_LM_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_LM_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOnlineLMConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_LM_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc index d5cb70f6..54d97e80 100644 --- a/sherpa-onnx/python/csrc/online-recognizer.cc +++ b/sherpa-onnx/python/csrc/online-recognizer.cc @@ -21,11 +21,13 @@ static void PybindOnlineRecognizerConfig(py::module *m) { using PyClass = OnlineRecognizerConfig; py::class_(*m, "OnlineRecognizerConfig") .def(py::init(), + const OnlineTransducerModelConfig &, const OnlineLMConfig &, + const EndpointConfig &, bool, const std::string &, + int32_t>(), py::arg("feat_config"), 
py::arg("model_config"), - py::arg("endpoint_config"), py::arg("enable_endpoint"), - py::arg("decoding_method"), py::arg("max_active_paths")) + py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"), + py::arg("enable_endpoint"), py::arg("decoding_method"), + py::arg("max_active_paths")) .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("endpoint_config", &PyClass::endpoint_config) diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index 7d3b0d05..0850ee34 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -11,6 +11,7 @@ #include "sherpa-onnx/python/csrc/offline-model-config.h" #include "sherpa-onnx/python/csrc/offline-recognizer.h" #include "sherpa-onnx/python/csrc/offline-stream.h" +#include "sherpa-onnx/python/csrc/online-lm-config.h" #include "sherpa-onnx/python/csrc/online-recognizer.h" #include "sherpa-onnx/python/csrc/online-stream.h" #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" @@ -22,6 +23,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { PybindFeatures(&m); PybindOnlineTransducerModelConfig(&m); + PybindOnlineLMConfig(&m); PybindOnlineStream(&m); PybindEndpoint(&m); PybindOnlineRecognizer(&m);