// sherpa-onnx/csrc/on-rnn-lm.cc // // Copyright (c) 2023 Pingfeng Luo // Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/csrc/online-rnn-lm.h" #include #include #include #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/text-utils.h" #include "sherpa-onnx/csrc/session.h" namespace sherpa_onnx { class OnlineRnnLM::Impl { public: explicit Impl(const OnlineLMConfig &config) : config_(config), env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_{GetSessionOptions(config)}, allocator_{} { Init(config); } void ComputeLMScore(float scale, Hypothesis *hyp) { if (hyp->nn_lm_states.empty()) { auto init_states = GetInitStates(); hyp->nn_lm_scores.value = std::move(init_states.first); hyp->nn_lm_states = Convert(std::move(init_states.second)); } // get lm score for cur token given the hyp->ys[:-1] and save to lm_log_prob const float *nn_lm_scores = hyp->nn_lm_scores.value.GetTensorData(); hyp->lm_log_prob += nn_lm_scores[hyp->ys.back()] * scale; // get lm scores for next tokens given the hyp->ys[:] and save to // nn_lm_scores std::array x_shape{1, 1}; Ort::Value x = Ort::Value::CreateTensor(allocator_, x_shape.data(), x_shape.size()); *x.GetTensorMutableData() = hyp->ys.back(); auto lm_out = ScoreToken(std::move(x), Convert(hyp->nn_lm_states)); hyp->nn_lm_scores.value = std::move(lm_out.first); hyp->nn_lm_states = Convert(std::move(lm_out.second)); } std::pair> ScoreToken( Ort::Value x, std::vector states) { std::array inputs = {std::move(x), std::move(states[0]), std::move(states[1])}; auto out = sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), output_names_ptr_.data(), output_names_ptr_.size()); std::vector next_states; next_states.reserve(2); next_states.push_back(std::move(out[1])); next_states.push_back(std::move(out[2])); return {std::move(out[0]), std::move(next_states)}; } std::pair> GetInitStates() const { std::vector ans; ans.reserve(init_states_.size()); for (const auto &s : init_states_) { ans.emplace_back(Clone(allocator_, &s)); } return {std::move(Clone(allocator_, &init_scores_.value)), std::move(ans)}; } private: void Init(const OnlineLMConfig &config) { auto buf = ReadFile(config_.model); sess_ = std::make_unique(env_, buf.data(), buf.size(), sess_opts_); GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); Ort::AllocatorWithDefaultOptions allocator; // used in the macro below SHERPA_ONNX_READ_META_DATA(rnn_num_layers_, "num_layers"); SHERPA_ONNX_READ_META_DATA(rnn_hidden_size_, "hidden_size"); SHERPA_ONNX_READ_META_DATA(sos_id_, "sos_id"); ComputeInitStates(); } void ComputeInitStates() { constexpr int32_t kBatchSize = 1; std::array h_shape{rnn_num_layers_, kBatchSize, rnn_hidden_size_}; std::array c_shape{rnn_num_layers_, kBatchSize, rnn_hidden_size_}; Ort::Value h = Ort::Value::CreateTensor(allocator_, h_shape.data(), h_shape.size()); Ort::Value c = Ort::Value::CreateTensor(allocator_, c_shape.data(), c_shape.size()); Fill(&h, 0); Fill(&c, 0); std::array x_shape{1, 1}; Ort::Value x = Ort::Value::CreateTensor(allocator_, x_shape.data(), x_shape.size()); *x.GetTensorMutableData() = sos_id_; std::vector states; states.push_back(std::move(h)); states.push_back(std::move(c)); auto pair = ScoreToken(std::move(x), std::move(states)); init_scores_.value = std::move(pair.first); init_states_ = std::move(pair.second); } private: OnlineLMConfig config_; Ort::Env env_; Ort::SessionOptions sess_opts_; Ort::AllocatorWithDefaultOptions allocator_; std::unique_ptr sess_; std::vector input_names_; std::vector input_names_ptr_; std::vector output_names_; std::vector output_names_ptr_; CopyableOrtValue init_scores_; std::vector init_states_; int32_t rnn_num_layers_ = 2; int32_t rnn_hidden_size_ = 512; int32_t sos_id_ = 1; }; OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config) : impl_(std::make_unique(config)) {} OnlineRnnLM::~OnlineRnnLM() = default; std::pair> OnlineRnnLM::GetInitStates() { return impl_->GetInitStates(); } std::pair> OnlineRnnLM::ScoreToken( Ort::Value x, std::vector states) { return impl_->ScoreToken(std::move(x), std::move(states)); } void OnlineRnnLM::ComputeLMScore(float scale, Hypothesis *hyp) { return impl_->ComputeLMScore(scale, hyp); } } // namespace sherpa_onnx