From f0960342ad484273e6de29cf1a5f1101c0141c42 Mon Sep 17 00:00:00 2001 From: Askars Salimbajevs <19928242+vsd-vector@users.noreply.github.com> Date: Wed, 9 Jul 2025 11:23:46 +0300 Subject: [PATCH] Add LODR support to online and offline recognizers (#2026) This PR integrates LODR (Low-Order Density Ratio) support from Icefall into both online and offline recognizers, enabling LODR for LM shallow fusion and LM rescore. - Extended OnlineLMConfig and OfflineLMConfig to include lodr_fst, lodr_scale, and lodr_backoff_id. - Implemented LodrFst and LodrStateCost classes and wired them into RNN LM scoring in both online and offline code paths. - Updated Python bindings, CLI entry points, examples, and CI test scripts to accept and exercise the new LODR options. --- .github/scripts/test-offline-transducer.sh | 34 +++- .github/scripts/test-online-transducer.sh | 55 ++++- .github/scripts/test-python.sh | 32 ++- python-api-examples/offline-decode-files.py | 56 +++++ python-api-examples/online-decode-files.py | 34 ++++ sherpa-onnx/csrc/CMakeLists.txt | 1 + sherpa-onnx/csrc/hypothesis.h | 5 + sherpa-onnx/csrc/lodr-fst.cc | 191 ++++++++++++++++++ sherpa-onnx/csrc/lodr-fst.h | 75 +++++++ sherpa-onnx/csrc/offline-lm-config.cc | 14 +- sherpa-onnx/csrc/offline-lm-config.h | 13 +- sherpa-onnx/csrc/offline-lm.cc | 7 + sherpa-onnx/csrc/offline-lm.h | 17 ++ sherpa-onnx/csrc/offline-rnn-lm.cc | 4 +- sherpa-onnx/csrc/online-lm-config.cc | 12 ++ sherpa-onnx/csrc/online-lm-config.h | 12 +- sherpa-onnx/csrc/online-rnn-lm.cc | 29 +++ sherpa-onnx/python/csrc/offline-lm-config.cc | 10 +- sherpa-onnx/python/csrc/online-lm-config.cc | 10 +- .../python/sherpa_onnx/offline_recognizer.py | 8 + .../python/sherpa_onnx/online_recognizer.py | 8 + 21 files changed, 613 insertions(+), 14 deletions(-) create mode 100644 sherpa-onnx/csrc/lodr-fst.cc create mode 100644 sherpa-onnx/csrc/lodr-fst.h diff --git a/.github/scripts/test-offline-transducer.sh b/.github/scripts/test-offline-transducer.sh 
index d0544847..d4207dec 100755 --- a/.github/scripts/test-offline-transducer.sh +++ b/.github/scripts/test-offline-transducer.sh @@ -281,7 +281,39 @@ time $EXE \ $repo/test_wavs/1.wav \ $repo/test_wavs/8k.wav -rm -rf $repo +lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm +log "Download pre-trained RNN-LM model from ${lm_repo_url}" +GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url +lm_repo=$(basename $lm_repo_url) +pushd $lm_repo +git lfs pull --include "exp/no-state-epoch-99-avg-1.onnx" +popd + +bigram_repo_url=https://huggingface.co/vsd-vector/librispeech_bigram_sherpa-onnx-zipformer-large-en-2023-06-26 +log "Download bi-gram LM from ${bigram_repo_url}" +GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url +bigramlm_repo=$(basename $bigram_repo_url) +pushd $bigramlm_repo +git lfs pull --include "2gram.fst" +popd + +log "Start testing with LM and bi-gram LODR" +# TODO: find test examples that change with the LODR +time $EXE \ + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder-epoch-99-avg-1.onnx \ + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ + --joiner=$repo/joiner-epoch-99-avg-1.onnx \ + --num-threads=2 \ + --decoding_method="modified_beam_search" \ + --lm=$lm_repo/exp/no-state-epoch-99-avg-1.onnx \ + --lodr-fst=$bigramlm_repo/2gram.fst \ + --lodr-scale=-0.5 \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/8k.wav + +rm -rf $repo $lm_repo $bigramlm_repo log "------------------------------------------------------------" log "Run Paraformer (Chinese)" diff --git a/.github/scripts/test-online-transducer.sh b/.github/scripts/test-online-transducer.sh index ceb2be47..b2af34cf 100755 --- a/.github/scripts/test-online-transducer.sh +++ b/.github/scripts/test-online-transducer.sh @@ -174,7 +174,60 @@ for wave in ${waves[@]}; do $wave done -rm -rf $repo +lm_repo_url=https://huggingface.co/vsd-vector/icefall-librispeech-rnn-lm +log "Download pre-trained RNN-LM model from ${lm_repo_url}" +GIT_LFS_SKIP_SMUDGE=1 git clone 
$lm_repo_url +lm_repo=$(basename $lm_repo_url) +pushd $lm_repo +git lfs pull --include "with-state-epoch-99-avg-1.onnx" +popd + +bigram_repo_url=https://huggingface.co/vsd-vector/librispeech_bigram_sherpa-onnx-zipformer-large-en-2023-06-26 +log "Download bi-gram LM from ${bigram_repo_url}" +GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url +bigramlm_repo=$(basename $bigram_repo_url) +pushd $bigramlm_repo +git lfs pull --include "2gram.fst" +popd + +log "Start testing LODR" + +waves=( +$repo/test_wavs/0.wav +$repo/test_wavs/1.wav +$repo/test_wavs/8k.wav +) + +for wave in ${waves[@]}; do + time $EXE \ + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder-epoch-99-avg-1.onnx \ + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ + --joiner=$repo/joiner-epoch-99-avg-1.onnx \ + --num-threads=2 \ + --decoding_method="modified_beam_search" \ + --lm=$lm_repo/with-state-epoch-99-avg-1.onnx \ + --lodr-fst=$bigramlm_repo/2gram.fst \ + --lodr-scale=-0.5 \ + $wave +done + +for wave in ${waves[@]}; do + time $EXE \ + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder-epoch-99-avg-1.onnx \ + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ + --joiner=$repo/joiner-epoch-99-avg-1.onnx \ + --num-threads=2 \ + --decoding_method="modified_beam_search" \ + --lm=$lm_repo/with-state-epoch-99-avg-1.onnx \ + --lodr-fst=$bigramlm_repo/2gram.fst \ + --lodr-scale=-0.5 \ + --lm-shallow-fusion=true \ + $wave +done + +rm -rf $repo $bigramlm_repo $lm_repo log "------------------------------------------------------------" log "Run streaming Zipformer transducer (Bilingual, Chinese + English)" diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index 08eb11de..80f78106 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -562,9 +562,39 @@ python3 ./python-api-examples/offline-decode-files.py \ $repo/test_wavs/1.wav \ $repo/test_wavs/8k.wav +lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm +log "Download pre-trained RNN-LM 
model from ${lm_repo_url}" +GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url +lm_repo=$(basename $lm_repo_url) +pushd $lm_repo +git lfs pull --include "exp/no-state-epoch-99-avg-1.onnx" +popd + +bigram_repo_url=https://huggingface.co/vsd-vector/librispeech_bigram_sherpa-onnx-zipformer-large-en-2023-06-26 +log "Download bi-gram LM from ${bigram_repo_url}" +GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url +bigramlm_repo=$(basename $bigram_repo_url) +pushd $bigramlm_repo +git lfs pull --include "2gram.fst" +popd + +log "Perform offline decoding with RNN-LM and LODR" +python3 ./python-api-examples/offline-decode-files.py \ + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder-epoch-99-avg-1.onnx \ + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ + --joiner=$repo/joiner-epoch-99-avg-1.onnx \ + --decoding-method=modified_beam_search \ + --lm=$lm_repo/exp/no-state-epoch-99-avg-1.onnx \ + --lodr-fst=$bigramlm_repo/2gram.fst \ + --lodr-scale=-0.5 \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/8k.wav + python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose -rm -rf $repo +rm -rf $repo $lm_repo $bigramlm_repo log "Test non-streaming paraformer models" diff --git a/python-api-examples/offline-decode-files.py b/python-api-examples/offline-decode-files.py index 0f87284e..d37ad935 100755 --- a/python-api-examples/offline-decode-files.py +++ b/python-api-examples/offline-decode-files.py @@ -35,6 +35,25 @@ file(s) with a non-streaming model. 
/path/to/0.wav \ /path/to/1.wav + also with RNN LM rescoring and LODR (optional): + + ./python-api-examples/offline-decode-files.py \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --num-threads=2 \ + --decoding-method=modified_beam_search \ + --debug=false \ + --sample-rate=16000 \ + --feature-dim=80 \ + --lm=/path/to/lm.onnx \ + --lm-scale=0.1 \ + --lodr-fst=/path/to/lodr.fst \ + --lodr-scale=-0.1 \ + /path/to/0.wav \ + /path/to/1.wav + (3) For CTC models from NeMo python3 ./python-api-examples/offline-decode-files.py \ @@ -269,6 +288,39 @@ def get_args(): default="greedy_search", help="Valid values are greedy_search and modified_beam_search", ) + + parser.add_argument( + "--lm", + metavar="file", + type=str, + default="", + help="Path to RNN LM model", + ) + + parser.add_argument( + "--lm-scale", + metavar="lm_scale", + type=float, + default=0.1, + help="LM model scale for rescoring", + ) + + parser.add_argument( + "--lodr-fst", + metavar="file", + type=str, + default="", + help="Path to LODR FST model. 
Used only when --lm is given.", + ) + + parser.add_argument( + "--lodr-scale", + metavar="lodr_scale", + type=float, + default=-0.1, + help="LODR scale for rescoring. Used only when --lodr-fst is given.", + ) + parser.add_argument( "--debug", type=bool, @@ -364,6 +416,10 @@ def main(): num_threads=args.num_threads, sample_rate=args.sample_rate, feature_dim=args.feature_dim, + lm=args.lm, + lm_scale=args.lm_scale, + lodr_fst=args.lodr_fst, + lodr_scale=args.lodr_scale, decoding_method=args.decoding_method, hotwords_file=args.hotwords_file, hotwords_score=args.hotwords_score, diff --git a/python-api-examples/online-decode-files.py b/python-api-examples/online-decode-files.py index 0188f049..586741ff 100755 --- a/python-api-examples/online-decode-files.py +++ b/python-api-examples/online-decode-files.py @@ -21,6 +21,22 @@ rm sherpa-onnx-streaming-zipformer-en-2023-06-26.tar.bz2 ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \ ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav +or with RNN LM rescoring and LODR: + +./python-api-examples/online-decode-files.py \ + --tokens=./sherpa-onnx-streaming-zipformer-en-2023-06-26/tokens.txt \ + --encoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \ + --decoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \ + --joiner=./sherpa-onnx-streaming-zipformer-en-2023-06-26/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \ + --decoding-method=modified_beam_search \ + --lm=/path/to/lm.onnx \ + --lm-scale=0.1 \ + --lodr-fst=/path/to/lodr.fst \ + --lodr-scale=-0.1 \ + ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav \ + ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \ + ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav + (2) Streaming paraformer curl -SL -O 
https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 @@ -186,6 +202,22 @@ def get_args(): """, ) + parser.add_argument( + "--lodr-fst", + metavar="file", + type=str, + default="", + help="Path to LODR FST model. Used only when --lm is given.", + ) + + parser.add_argument( + "--lodr-scale", + metavar="lodr_scale", + type=float, + default=-0.1, + help="LODR scale for rescoring. Used only when --lodr-fst is given.", + ) + parser.add_argument( "--provider", type=str, @@ -320,6 +352,8 @@ def main(): max_active_paths=args.max_active_paths, lm=args.lm, lm_scale=args.lm_scale, + lodr_fst=args.lodr_fst, + lodr_scale=args.lodr_scale, hotwords_file=args.hotwords_file, hotwords_score=args.hotwords_score, modeling_unit=args.modeling_unit, diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 37d1d869..33cd9939 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -25,6 +25,7 @@ set(sources jieba.cc keyword-spotter-impl.cc keyword-spotter.cc + lodr-fst.cc offline-canary-model-config.cc offline-canary-model.cc offline-ctc-fst-decoder-config.cc diff --git a/sherpa-onnx/csrc/hypothesis.h b/sherpa-onnx/csrc/hypothesis.h index 428a74fa..474723fa 100644 --- a/sherpa-onnx/csrc/hypothesis.h +++ b/sherpa-onnx/csrc/hypothesis.h @@ -12,9 +12,11 @@ #include #include #include +#include #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/context-graph.h" +#include "sherpa-onnx/csrc/lodr-fst.h" #include "sherpa-onnx/csrc/math.h" #include "sherpa-onnx/csrc/onnx-utils.h" @@ -61,6 +63,9 @@ struct Hypothesis { // the nn lm states std::vector nn_lm_states; + // the LODR states + std::shared_ptr lodr_state; + const ContextState *context_state; // TODO(fangjun): Make it configurable diff --git a/sherpa-onnx/csrc/lodr-fst.cc b/sherpa-onnx/csrc/lodr-fst.cc new file mode 100644 index 00000000..a5d0d218 --- /dev/null +++ b/sherpa-onnx/csrc/lodr-fst.cc @@ 
-0,0 +1,191 @@ +// sherpa-onnx/csrc/lodr-fst.cc +// +// Contains code copied from icefall/utils/ngram_lm.py +// Copyright (c) 2023 Xiaomi Corporation +// +// Copyright (c) 2025 Tilde SIA (Askars Salimbajevs) + +#include +#include +#include + +#include "sherpa-onnx/csrc/lodr-fst.h" +#include "sherpa-onnx/csrc/log.h" +#include "sherpa-onnx/csrc/hypothesis.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +int32_t LodrFst::FindBackoffId() { + // assume that the backoff id is the only input label with epsilon output + + for (int32_t state = 0; state < fst_->NumStates(); ++state) { + fst::ArcIterator arc_iter(*fst_, state); + for ( ; !arc_iter.Done(); arc_iter.Next()) { + const auto& arc = arc_iter.Value(); + if (arc.olabel == 0) { // Check if the output label is epsilon (0) + return arc.ilabel; // Return the input label + } + } + } + + return -1; // Return -1 if no such input symbol is found +} + +LodrFst::LodrFst(const std::string &fst_path, int32_t backoff_id) + : backoff_id_(backoff_id) { + fst_ = std::unique_ptr( + CastOrConvertToConstFst(fst::StdVectorFst::Read(fst_path))); + + if (backoff_id < 0) { + // backoff_id_ is not provided, find it automatically + backoff_id_ = FindBackoffId(); + if (backoff_id_ < 0) { + std::string err_msg = "Failed to initialize LODR: No backoff arc found"; + SHERPA_ONNX_LOGE("%s", err_msg.c_str()); + SHERPA_ONNX_EXIT(-1); + } + } +} + +std::vector> LodrFst::ProcessBackoffArcs( + int32_t state, float cost) { + std::vector> ans; + auto next = GetNextStatesCostsNoBackoff(state, backoff_id_); + if (!next.has_value()) { + return ans; + } + auto [next_state, next_cost] = next.value(); + ans.emplace_back(next_state, next_cost + cost); + auto recursive_result = ProcessBackoffArcs(next_state, next_cost + cost); + ans.insert(ans.end(), recursive_result.begin(), recursive_result.end()); + return ans; +} + +std::optional> LodrFst::GetNextStatesCostsNoBackoff( + int32_t state, int32_t label) { + fst::ArcIterator arc_iter(*fst_, 
state); + int32_t num_arcs = fst_->NumArcs(state); + + int32_t left = 0, right = num_arcs - 1; + while (left <= right) { + int32_t mid = (left + right) / 2; + arc_iter.Seek(mid); + auto arc = arc_iter.Value(); + if (arc.ilabel < label) { + left = mid + 1; + } else if (arc.ilabel > label) { + right = mid - 1; + } else { + return std::make_tuple(arc.nextstate, arc.weight.Value()); + } + } + return std::nullopt; +} + +std::pair, std::vector> LodrFst::GetNextStateCosts( + int32_t state, int32_t label) { + std::vector states = {state}; + std::vector costs = {0}; + + auto extra_states_costs = ProcessBackoffArcs(state, 0); + for (const auto& [s, c] : extra_states_costs) { + states.push_back(s); + costs.push_back(c); + } + + std::vector next_states; + std::vector next_costs; + for (size_t i = 0; i < states.size(); ++i) { + auto next = GetNextStatesCostsNoBackoff(states[i], label); + if (next.has_value()) { + auto [ns, nc] = next.value(); + next_states.push_back(ns); + next_costs.push_back(costs[i] + nc); + } + } + + return std::make_pair(next_states, next_costs); +} + +void LodrFst::ComputeScore(float scale, Hypothesis *hyp, int32_t offset) { + if (scale == 0) { + return; + } + + hyp->lodr_state = std::make_unique(this); + + // Walk through the FST with the input text from the hypothesis + for (size_t i = offset; i < hyp->ys.size(); ++i) { + *hyp->lodr_state = hyp->lodr_state->ForwardOneStep(hyp->ys[i]); + } + + float lodr_score = hyp->lodr_state->FinalScore(); + + if (lodr_score == -std::numeric_limits::infinity()) { + SHERPA_ONNX_LOGE("Failed to compute LODR. 
Empty or mismatched FST?"); + return; + } + + // Update the hyp score + hyp->log_prob += scale * lodr_score; +} + +float LodrFst::GetFinalCost(int32_t state) { + auto final_weight = fst_->Final(state); + if (final_weight == fst::StdArc::Weight::Zero()) { + return 0.0; + } + return final_weight.Value(); +} + +LodrStateCost::LodrStateCost( + LodrFst* fst, const std::unordered_map &state_cost) + : fst_(fst) { + if (state_cost.empty()) { + state_cost_[0] = 0.0; + } else { + state_cost_ = state_cost; + } +} + +LodrStateCost LodrStateCost::ForwardOneStep(int32_t label) { + std::unordered_map state_cost; + for (const auto& [s, c] : state_cost_) { + auto [next_states, next_costs] = fst_->GetNextStateCosts(s, label); + for (size_t i = 0; i < next_states.size(); ++i) { + int32_t ns = next_states[i]; + float nc = next_costs[i]; + if (state_cost.find(ns) == state_cost.end()) { + state_cost[ns] = std::numeric_limits::infinity(); + } + state_cost[ns] = std::min(state_cost[ns], c + nc); + } + } + return LodrStateCost(fst_, state_cost); +} + +float LodrStateCost::Score() const { + if (state_cost_.empty()) { + return -std::numeric_limits::infinity(); + } + auto min_cost = std::min_element(state_cost_.begin(), state_cost_.end(), + [](const auto& a, const auto& b) { + return a.second < b.second; + }); + return -min_cost->second; +} + +float LodrStateCost::FinalScore() const { + if (state_cost_.empty()) { + return -std::numeric_limits::infinity(); + } + auto min_cost = std::min_element(state_cost_.begin(), state_cost_.end(), + [](const auto& a, const auto& b) { + return a.second < b.second; + }); + return -(min_cost->second + + fst_->GetFinalCost(min_cost->first)); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/lodr-fst.h b/sherpa-onnx/csrc/lodr-fst.h new file mode 100644 index 00000000..74adcfd3 --- /dev/null +++ b/sherpa-onnx/csrc/lodr-fst.h @@ -0,0 +1,75 @@ +// sherpa-onnx/csrc/lodr-fst.h +// +// Contains code copied from icefall/utils/ngram_lm.py +// Copyright (c) 
2023 Xiaomi Corporation +// +// Copyright (c) 2025 Tilde SIA (Askars Salimbajevs) + + +#ifndef SHERPA_ONNX_CSRC_LODR_FST_H_ +#define SHERPA_ONNX_CSRC_LODR_FST_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "kaldifst/csrc/kaldi-fst-io.h" + +namespace sherpa_onnx { + +struct Hypothesis; + +class LodrFst { + public: + explicit LodrFst(const std::string &fst_path, int32_t backoff_id = -1); + + std::pair, std::vector> GetNextStateCosts( + int32_t state, int32_t label); + + float GetFinalCost(int32_t state); + + void ComputeScore(float scale, Hypothesis *hyp, int32_t offset); + + private: + fst::StdVectorFst YsToFst(const std::vector &ys, int32_t offset); + + std::vector> ProcessBackoffArcs( + int32_t state, float cost); + + std::optional> GetNextStatesCostsNoBackoff( + int32_t state, int32_t label); + + int32_t FindBackoffId(); + + + int32_t backoff_id_ = -1; + std::unique_ptr fst_; // owned by this class +}; + +class LodrStateCost { + public: + explicit LodrStateCost( + LodrFst* fst, + const std::unordered_map &state_cost = {}); + + LodrStateCost ForwardOneStep(int32_t label); + + float Score() const; + float FinalScore() const; + + private: + // The fst_ is not owned by this class and borrowed from the caller + // (e.g. OnlineRnnLM). 
+ LodrFst* fst_; + std::unordered_map state_cost_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_LODR_FST_H_ diff --git a/sherpa-onnx/csrc/offline-lm-config.cc b/sherpa-onnx/csrc/offline-lm-config.cc index 791fa11a..b9528d0b 100644 --- a/sherpa-onnx/csrc/offline-lm-config.cc +++ b/sherpa-onnx/csrc/offline-lm-config.cc @@ -18,6 +18,10 @@ void OfflineLMConfig::Register(ParseOptions *po) { "Number of threads to run the neural network of LM model"); po->Register("lm-provider", &lm_provider, "Specify a provider to LM model use: cpu, cuda, coreml"); + po->Register("lodr-fst", &lodr_fst, "Path to LODR FST model."); + po->Register("lodr-scale", &lodr_scale, "LODR scale."); + po->Register("lodr-backoff-id", &lodr_backoff_id, + "ID of the backoff in the LODR FST. -1 means autodetect"); } bool OfflineLMConfig::Validate() const { @@ -26,6 +30,11 @@ bool OfflineLMConfig::Validate() const { return false; } + if (!lodr_fst.empty() && !FileExists(lodr_fst)) { + SHERPA_ONNX_LOGE("'%s' does not exist", lodr_fst.c_str()); + return false; + } + return true; } @@ -34,7 +43,10 @@ std::string OfflineLMConfig::ToString() const { os << "OfflineLMConfig("; os << "model=\"" << model << "\", "; - os << "scale=" << scale << ")"; + os << "scale=" << scale << ", "; + os << "lodr_scale=" << lodr_scale << ", "; + os << "lodr_fst=\"" << lodr_fst << "\", "; + os << "lodr_backoff_id=" << lodr_backoff_id << ")"; return os.str(); } diff --git a/sherpa-onnx/csrc/offline-lm-config.h b/sherpa-onnx/csrc/offline-lm-config.h index 3468c58a..839a67c8 100644 --- a/sherpa-onnx/csrc/offline-lm-config.h +++ b/sherpa-onnx/csrc/offline-lm-config.h @@ -19,14 +19,23 @@ struct OfflineLMConfig { int32_t lm_num_threads = 1; std::string lm_provider = "cpu"; + // LODR + std::string lodr_fst; + float lodr_scale = 0.01; + int32_t lodr_backoff_id = -1; // -1 means not set + OfflineLMConfig() = default; OfflineLMConfig(const std::string &model, float scale, int32_t lm_num_threads, - const std::string 
&lm_provider) + const std::string &lm_provider, const std::string &lodr_fst, + float lodr_scale, int32_t lodr_backoff_id) : model(model), scale(scale), lm_num_threads(lm_num_threads), - lm_provider(lm_provider) {} + lm_provider(lm_provider), + lodr_fst(lodr_fst), + lodr_scale(lodr_scale), + lodr_backoff_id(lodr_backoff_id) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-lm.cc b/sherpa-onnx/csrc/offline-lm.cc index 0c42b7ff..a452915c 100644 --- a/sherpa-onnx/csrc/offline-lm.cc +++ b/sherpa-onnx/csrc/offline-lm.cc @@ -17,6 +17,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/lodr-fst.h" #include "sherpa-onnx/csrc/offline-rnn-lm.h" namespace sherpa_onnx { @@ -74,11 +75,17 @@ void OfflineLM::ComputeLMScore(float scale, int32_t context_size, } auto negative_loglike = Rescore(std::move(x), std::move(x_lens)); const float *p_nll = negative_loglike.GetTensorData(); + // We scale LODR scale with LM scale to replicate Icefall code + auto lodr_scale = config_.lodr_scale * scale; for (auto &h : *hyps) { for (auto &t : h) { // Use -scale here since we want to change negative loglike to loglike. 
t.second.lm_log_prob = -scale * (*p_nll); ++p_nll; + // apply LODR to hyp score + if (lodr_fst_ != nullptr) { + lodr_fst_->ComputeScore(lodr_scale, &t.second, context_size); + } } } } diff --git a/sherpa-onnx/csrc/offline-lm.h b/sherpa-onnx/csrc/offline-lm.h index a9af8202..e7c0937a 100644 --- a/sherpa-onnx/csrc/offline-lm.h +++ b/sherpa-onnx/csrc/offline-lm.h @@ -10,12 +10,24 @@ #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/hypothesis.h" +#include "sherpa-onnx/csrc/lodr-fst.h" #include "sherpa-onnx/csrc/offline-lm-config.h" namespace sherpa_onnx { class OfflineLM { public: + explicit OfflineLM(const OfflineLMConfig &config) : config_(config) { + if (!config_.lodr_fst.empty()) { + try { + lodr_fst_ = std::make_unique(LodrFst(config_.lodr_fst, + config_.lodr_backoff_id)); + } catch (const std::exception& e) { + throw std::runtime_error("Failed to load LODR FST from: " + + config_.lodr_fst + ". Error: " + e.what()); + } + } + } virtual ~OfflineLM() = default; static std::unique_ptr Create(const OfflineLMConfig &config); @@ -43,6 +55,11 @@ class OfflineLM { // @param hyps It is changed in-place. 
void ComputeLMScore(float scale, int32_t context_size, std::vector *hyps); + + private: + std::unique_ptr lodr_fst_; + float lodr_scale_; + OfflineLMConfig config_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-rnn-lm.cc b/sherpa-onnx/csrc/offline-rnn-lm.cc index 8f9425da..bdc2f903 100644 --- a/sherpa-onnx/csrc/offline-rnn-lm.cc +++ b/sherpa-onnx/csrc/offline-rnn-lm.cc @@ -83,11 +83,11 @@ class OfflineRnnLM::Impl { }; OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config) - : impl_(std::make_unique(config)) {} + : OfflineLM(config), impl_(std::make_unique(config)) {} template OfflineRnnLM::OfflineRnnLM(Manager *mgr, const OfflineLMConfig &config) - : impl_(std::make_unique(mgr, config)) {} + : OfflineLM(config), impl_(std::make_unique(mgr, config)) {} OfflineRnnLM::~OfflineRnnLM() = default; diff --git a/sherpa-onnx/csrc/online-lm-config.cc b/sherpa-onnx/csrc/online-lm-config.cc index 9611c7f3..43c84032 100644 --- a/sherpa-onnx/csrc/online-lm-config.cc +++ b/sherpa-onnx/csrc/online-lm-config.cc @@ -20,6 +20,10 @@ void OnlineLMConfig::Register(ParseOptions *po) { "Specify a provider to LM model use: cpu, cuda, coreml"); po->Register("lm-shallow-fusion", &shallow_fusion, "Boolean whether to use shallow fusion or rescore."); + po->Register("lodr-fst", &lodr_fst, "Path to LODR FST model."); + po->Register("lodr-scale", &lodr_scale, "LODR scale."); + po->Register("lodr-backoff-id", &lodr_backoff_id, + "ID of the backoff in the LODR FST. 
-1 means autodetect"); } bool OnlineLMConfig::Validate() const { @@ -28,6 +32,11 @@ bool OnlineLMConfig::Validate() const { return false; } + if (!lodr_fst.empty() && !FileExists(lodr_fst)) { + SHERPA_ONNX_LOGE("'%s' does not exist", lodr_fst.c_str()); + return false; + } + return true; } @@ -37,6 +46,9 @@ std::string OnlineLMConfig::ToString() const { os << "OnlineLMConfig("; os << "model=\"" << model << "\", "; os << "scale=" << scale << ", "; + os << "lodr_scale=" << lodr_scale << ", "; + os << "lodr_fst=\"" << lodr_fst << "\", "; + os << "lodr_backoff_id=" << lodr_backoff_id << ", "; os << "shallow_fusion=" << (shallow_fusion ? "True" : "False") << ")"; return os.str(); diff --git a/sherpa-onnx/csrc/online-lm-config.h b/sherpa-onnx/csrc/online-lm-config.h index 8d5b1670..f3c92cf3 100644 --- a/sherpa-onnx/csrc/online-lm-config.h +++ b/sherpa-onnx/csrc/online-lm-config.h @@ -18,18 +18,26 @@ struct OnlineLMConfig { float scale = 0.5; int32_t lm_num_threads = 1; std::string lm_provider = "cpu"; + std::string lodr_fst; + float lodr_scale = 0.01; + int32_t lodr_backoff_id = -1; // -1 means not set // enable shallow fusion bool shallow_fusion = true; OnlineLMConfig() = default; OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads, - const std::string &lm_provider, bool shallow_fusion) + const std::string &lm_provider, bool shallow_fusion, + const std::string &lodr_fst, float lodr_scale, + int32_t lodr_backoff_id) : model(model), scale(scale), lm_num_threads(lm_num_threads), lm_provider(lm_provider), - shallow_fusion(shallow_fusion) {} + shallow_fusion(shallow_fusion), + lodr_fst(lodr_fst), + lodr_scale(lodr_scale), + lodr_backoff_id(lodr_backoff_id) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/online-rnn-lm.cc b/sherpa-onnx/csrc/online-rnn-lm.cc index 4e5261ce..8c5b1b45 100644 --- a/sherpa-onnx/csrc/online-rnn-lm.cc +++ b/sherpa-onnx/csrc/online-rnn-lm.cc @@ -12,6 +12,7 @@ #include 
"onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/lodr-fst.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" @@ -35,12 +36,27 @@ class OnlineRnnLM::Impl { auto init_states = GetInitStatesSF(); hyp->nn_lm_scores.value = std::move(init_states.first); hyp->nn_lm_states = Convert(std::move(init_states.second)); + // if LODR enabled, we need to initialize the LODR state + if (lodr_fst_ != nullptr) { + hyp->lodr_state = std::make_unique(lodr_fst_.get()); + } } // get lm score for cur token given the hyp->ys[:-1] and save to lm_log_prob const float *nn_lm_scores = hyp->nn_lm_scores.value.GetTensorData(); hyp->lm_log_prob += nn_lm_scores[hyp->ys.back()] * scale; + // if LODR enabled, we need to update the LODR state + if (lodr_fst_ != nullptr) { + auto next_lodr_state = std::make_unique( + hyp->lodr_state->ForwardOneStep(hyp->ys.back())); + // calculate the score of the latest token + auto score = next_lodr_state->Score() - hyp->lodr_state->Score(); + hyp->lodr_state = std::move(next_lodr_state); + // apply LODR to hyp score + hyp->lm_log_prob += score * config_.lodr_scale; + } + // get lm scores for next tokens given the hyp->ys[:] and save to // nn_lm_scores std::array x_shape{1, 1}; @@ -89,6 +105,12 @@ class OnlineRnnLM::Impl { const float *p_nll = out.first.GetTensorData(); h.lm_log_prob = -scale * (*p_nll); + // apply LODR to hyp score + if (lodr_fst_ != nullptr) { + // We scale LODR scale with LM scale to replicate Icefall code + lodr_fst_->ComputeScore(config_.lodr_scale*scale, &h, context_size); + } + // update NN LM states in hyp h.nn_lm_states = Convert(std::move(out.second)); @@ -154,6 +176,11 @@ class OnlineRnnLM::Impl { SHERPA_ONNX_READ_META_DATA(sos_id_, "sos_id"); ComputeInitStates(); + + if (!config_.lodr_fst.empty()) { + lodr_fst_ = std::make_unique(LodrFst(config_.lodr_fst, + config_.lodr_backoff_id)); + } } void ComputeInitStates() { 
@@ -203,6 +230,8 @@ class OnlineRnnLM::Impl { int32_t rnn_num_layers_ = 2; int32_t rnn_hidden_size_ = 512; int32_t sos_id_ = 1; + + std::unique_ptr lodr_fst_; }; OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config) diff --git a/sherpa-onnx/python/csrc/offline-lm-config.cc b/sherpa-onnx/python/csrc/offline-lm-config.cc index 42edb656..c665fb43 100644 --- a/sherpa-onnx/python/csrc/offline-lm-config.cc +++ b/sherpa-onnx/python/csrc/offline-lm-config.cc @@ -13,13 +13,19 @@ namespace sherpa_onnx { void PybindOfflineLMConfig(py::module *m) { using PyClass = OfflineLMConfig; py::class_(*m, "OfflineLMConfig") - .def(py::init(), + .def(py::init(), py::arg("model"), py::arg("scale") = 0.5f, - py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu") + py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu", + py::arg("lodr_fst") = "", py::arg("lodr_scale") = 0.0f, + py::arg("lodr_backoff_id") = -1) .def_readwrite("model", &PyClass::model) .def_readwrite("scale", &PyClass::scale) .def_readwrite("lm_provider", &PyClass::lm_provider) .def_readwrite("lm_num_threads", &PyClass::lm_num_threads) + .def_readwrite("lodr_fst", &PyClass::lodr_fst) + .def_readwrite("lodr_scale", &PyClass::lodr_scale) + .def_readwrite("lodr_backoff_id", &PyClass::lodr_backoff_id) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/csrc/online-lm-config.cc b/sherpa-onnx/python/csrc/online-lm-config.cc index 0e9a0385..a01056f7 100644 --- a/sherpa-onnx/python/csrc/online-lm-config.cc +++ b/sherpa-onnx/python/csrc/online-lm-config.cc @@ -14,15 +14,21 @@ void PybindOnlineLMConfig(py::module *m) { using PyClass = OnlineLMConfig; py::class_(*m, "OnlineLMConfig") .def(py::init(), + const std::string &, bool, const std::string &, + float, int>(), py::arg("model") = "", py::arg("scale") = 0.5f, py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu", - py::arg("shallow_fusion") = true) + py::arg("shallow_fusion") = true, py::arg("lodr_fst") = "", + py::arg("lodr_scale") = 
0.0f, py::arg("lodr_backoff_id") = -1) .def_readwrite("model", &PyClass::model) .def_readwrite("scale", &PyClass::scale) .def_readwrite("lm_provider", &PyClass::lm_provider) .def_readwrite("lm_num_threads", &PyClass::lm_num_threads) .def_readwrite("shallow_fusion", &PyClass::shallow_fusion) + .def_readwrite("lodr_fst", &PyClass::lodr_fst) + .def_readwrite("lodr_scale", &PyClass::lodr_scale) + .def_readwrite("lodr_backoff_id", &PyClass::lodr_backoff_id) + .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index b8586d26..bc5714c6 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -69,6 +69,8 @@ class OfflineRecognizer(object): hr_dict_dir: str = "", hr_rule_fsts: str = "", hr_lexicon: str = "", + lodr_fst: str = "", + lodr_scale: float = 0.0, ): """ Please refer to @@ -133,6 +135,10 @@ class OfflineRecognizer(object): rule_fars: If not empty, it specifies fst archives for inverse text normalization. If there are multiple archives, they are separated by a comma. + lodr_fst: + Path to the LODR FST file in binary format. If empty, LODR is disabled. + lodr_scale: + Scale factor for LODR rescoring. Only used when lodr_fst is provided. 
""" self = cls.__new__(cls) model_config = OfflineModelConfig( @@ -173,6 +179,8 @@ class OfflineRecognizer(object): scale=lm_scale, lm_num_threads=num_threads, lm_provider=provider, + lodr_fst=lodr_fst, + lodr_scale=lodr_scale, ) recognizer_config = OfflineRecognizerConfig( diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index 747e4e50..09c507b0 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -89,6 +89,8 @@ class OnlineRecognizer(object): hr_dict_dir: str = "", hr_rule_fsts: str = "", hr_lexicon: str = "", + lodr_fst: str = "", + lodr_scale: float = 0.0, ): """ Please refer to @@ -216,6 +218,10 @@ class OnlineRecognizer(object): "Set path for storing timing cache." TensorRT EP trt_dump_subgraphs: bool = False, "Dump optimized subgraphs for debugging." TensorRT EP + lodr_fst: + Path to the LODR FST file in binary format. If empty, LODR is disabled. + lodr_scale: + Scale factor for LODR rescoring. Only used when lodr_fst is provided. """ self = cls.__new__(cls) _assert_file_exists(tokens) @@ -298,6 +304,8 @@ class OnlineRecognizer(object): model=lm, scale=lm_scale, shallow_fusion=lm_shallow_fusion, + lodr_fst=lodr_fst, + lodr_scale=lodr_scale, ) recognizer_config = OnlineRecognizerConfig(