// sherpa-onnx/csrc/online-lm.cc // // Copyright (c) 2023 Pingfeng Luo // Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/csrc/online-lm.h" #include #include #include #include "sherpa-onnx/csrc/online-rnn-lm.h" namespace sherpa_onnx { static std::vector Convert(std::vector values) { std::vector ans; ans.reserve(values.size()); for (auto &v : values) { ans.emplace_back(std::move(v)); } return ans; } static std::vector Convert(std::vector values) { std::vector ans; ans.reserve(values.size()); for (auto &v : values) { ans.emplace_back(std::move(v.value)); } return ans; } std::unique_ptr OnlineLM::Create(const OnlineLMConfig &config) { return std::make_unique(config); } void OnlineLM::ComputeLMScore(float scale, int32_t context_size, std::vector *hyps) { Ort::AllocatorWithDefaultOptions allocator; for (auto &hyp : *hyps) { for (auto &h_m : hyp) { auto &h = h_m.second; auto &ys = h.ys; const int32_t token_num_in_chunk = ys.size() - context_size - h.cur_scored_pos - 1; if (token_num_in_chunk < 1) { continue; } if (h.nn_lm_states.empty()) { h.nn_lm_states = Convert(GetInitStates()); } if (token_num_in_chunk >= h.lm_rescore_min_chunk) { std::array x_shape{1, token_num_in_chunk}; // shape of x and y are same Ort::Value x = Ort::Value::CreateTensor( allocator, x_shape.data(), x_shape.size()); Ort::Value y = Ort::Value::CreateTensor( allocator, x_shape.data(), x_shape.size()); int64_t *p_x = x.GetTensorMutableData(); int64_t *p_y = y.GetTensorMutableData(); std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1, p_x); std::copy(ys.begin() + context_size + h.cur_scored_pos + 1, ys.end(), p_y); // streaming forward by NN LM auto out = Rescore(std::move(x), std::move(y), Convert(std::move(h.nn_lm_states))); // update NN LM score in hyp const float *p_nll = out.first.GetTensorData(); h.lm_log_prob = -scale * (*p_nll); // update NN LM states in hyp h.nn_lm_states = Convert(std::move(out.second)); h.cur_scored_pos += token_num_in_chunk; } } } } } // namespace sherpa_onnx