// sherpa-onnx/csrc/offline-lm.cc // // Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/csrc/offline-lm.h" #include #include #include #include "sherpa-onnx/csrc/offline-rnn-lm.h" namespace sherpa_onnx { std::unique_ptr OfflineLM::Create(const OfflineLMConfig &config) { return std::make_unique(config); } #if __ANDROID_API__ >= 9 std::unique_ptr OfflineLM::Create(AAssetManager *mgr, const OfflineLMConfig &config) { return std::make_unique(mgr, config); } #endif void OfflineLM::ComputeLMScore(float scale, int32_t context_size, std::vector *hyps) { // compute the max token seq so that we know how much space to allocate int32_t max_token_seq = 0; int32_t num_hyps = 0; // we subtract context_size below since each token sequence is prepended // with context_size blanks for (const auto &h : *hyps) { num_hyps += h.Size(); for (const auto &t : h) { max_token_seq = std::max(max_token_seq, t.second.ys.size() - context_size); } } Ort::AllocatorWithDefaultOptions allocator; std::array x_shape{num_hyps, max_token_seq}; Ort::Value x = Ort::Value::CreateTensor(allocator, x_shape.data(), x_shape.size()); std::array x_lens_shape{num_hyps}; Ort::Value x_lens = Ort::Value::CreateTensor( allocator, x_lens_shape.data(), x_lens_shape.size()); int64_t *p = x.GetTensorMutableData(); std::fill(p, p + num_hyps * max_token_seq, 0); int64_t *p_lens = x_lens.GetTensorMutableData(); for (const auto &h : *hyps) { for (const auto &t : h) { const auto &ys = t.second.ys; int32_t len = ys.size() - context_size; std::copy(ys.begin() + context_size, ys.end(), p); *p_lens = len; p += max_token_seq; ++p_lens; } } auto negative_loglike = Rescore(std::move(x), std::move(x_lens)); const float *p_nll = negative_loglike.GetTensorData(); for (auto &h : *hyps) { for (auto &t : h) { // Use -scale here since we want to change negative loglike to loglike. t.second.lm_log_prob = -scale * (*p_nll); ++p_nll; } } } } // namespace sherpa_onnx