Add LODR support to online and offline recognizers (#2026)
This PR integrates LODR (Low-Order Density Ratio) support from Icefall into both online and offline recognizers, enabling LODR for LM shallow fusion and LM rescoring. - Extended OnlineLMConfig and OfflineLMConfig to include lodr_fst, lodr_scale, and lodr_backoff_id. - Implemented LodrFst and LodrStateCost classes and wired them into RNN LM scoring in both online and offline code paths. - Updated Python bindings, CLI entry points, examples, and CI test scripts to accept and exercise the new LODR options.
This commit is contained in:
committed by
GitHub
parent
6122a678f5
commit
f0960342ad
34
.github/scripts/test-offline-transducer.sh
vendored
34
.github/scripts/test-offline-transducer.sh
vendored
@@ -281,7 +281,39 @@ time $EXE \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
rm -rf $repo
|
||||
lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
|
||||
log "Download pre-trained RNN-LM model from ${lm_repo_url}"
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url
|
||||
lm_repo=$(basename $lm_repo_url)
|
||||
pushd $lm_repo
|
||||
git lfs pull --include "exp/no-state-epoch-99-avg-1.onnx"
|
||||
popd
|
||||
|
||||
bigram_repo_url=https://huggingface.co/vsd-vector/librispeech_bigram_sherpa-onnx-zipformer-large-en-2023-06-26
|
||||
log "Download bi-gram LM from ${bigram_repo_url}"
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url
|
||||
bigramlm_repo=$(basename $bigram_repo_url)
|
||||
pushd $bigramlm_repo
|
||||
git lfs pull --include "2gram.fst"
|
||||
popd
|
||||
|
||||
log "Start testing with LM and bi-gram LODR"
|
||||
# TODO: find test examples that change with the LODR
|
||||
time $EXE \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder-epoch-99-avg-1.onnx \
|
||||
--decoder=$repo/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner=$repo/joiner-epoch-99-avg-1.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding_method="modified_beam_search" \
|
||||
--lm=$lm_repo/exp/no-state-epoch-99-avg-1.onnx \
|
||||
--lodr-fst=$bigramlm_repo/2gram.fst \
|
||||
--lodr-scale=-0.5 \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
rm -rf $repo $lm_repo $bigramlm_repo
|
||||
|
||||
log "------------------------------------------------------------"
|
||||
log "Run Paraformer (Chinese)"
|
||||
|
||||
55
.github/scripts/test-online-transducer.sh
vendored
55
.github/scripts/test-online-transducer.sh
vendored
@@ -174,7 +174,60 @@ for wave in ${waves[@]}; do
|
||||
$wave
|
||||
done
|
||||
|
||||
rm -rf $repo
|
||||
lm_repo_url=https://huggingface.co/vsd-vector/icefall-librispeech-rnn-lm
|
||||
log "Download pre-trained RNN-LM model from ${lm_repo_url}"
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url
|
||||
lm_repo=$(basename $lm_repo_url)
|
||||
pushd $lm_repo
|
||||
git lfs pull --include "with-state-epoch-99-avg-1.onnx"
|
||||
popd
|
||||
|
||||
bigram_repo_url=https://huggingface.co/vsd-vector/librispeech_bigram_sherpa-onnx-zipformer-large-en-2023-06-26
|
||||
log "Download bi-gram LM from ${bigram_repo_url}"
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url
|
||||
bigramlm_repo=$(basename $bigram_repo_url)
|
||||
pushd $bigramlm_repo
|
||||
git lfs pull --include "2gram.fst"
|
||||
popd
|
||||
|
||||
log "Start testing LODR"
|
||||
|
||||
waves=(
|
||||
$repo/test_wavs/0.wav
|
||||
$repo/test_wavs/1.wav
|
||||
$repo/test_wavs/8k.wav
|
||||
)
|
||||
|
||||
for wave in ${waves[@]}; do
|
||||
time $EXE \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder-epoch-99-avg-1.onnx \
|
||||
--decoder=$repo/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner=$repo/joiner-epoch-99-avg-1.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding_method="modified_beam_search" \
|
||||
--lm=$lm_repo/with-state-epoch-99-avg-1.onnx \
|
||||
--lodr-fst=$bigramlm_repo/2gram.fst \
|
||||
--lodr-scale=-0.5 \
|
||||
$wave
|
||||
done
|
||||
|
||||
for wave in ${waves[@]}; do
|
||||
time $EXE \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder-epoch-99-avg-1.onnx \
|
||||
--decoder=$repo/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner=$repo/joiner-epoch-99-avg-1.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding_method="modified_beam_search" \
|
||||
--lm=$lm_repo/with-state-epoch-99-avg-1.onnx \
|
||||
--lodr-fst=$bigramlm_repo/2gram.fst \
|
||||
--lodr-scale=-0.5 \
|
||||
--lm-shallow-fusion=true \
|
||||
$wave
|
||||
done
|
||||
|
||||
rm -rf $repo $bigramlm_repo $lm_repo
|
||||
|
||||
log "------------------------------------------------------------"
|
||||
log "Run streaming Zipformer transducer (Bilingual, Chinese + English)"
|
||||
|
||||
32
.github/scripts/test-python.sh
vendored
32
.github/scripts/test-python.sh
vendored
@@ -562,9 +562,39 @@ python3 ./python-api-examples/offline-decode-files.py \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
|
||||
log "Download pre-trained RNN-LM model from ${lm_repo_url}"
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url
|
||||
lm_repo=$(basename $lm_repo_url)
|
||||
pushd $lm_repo
|
||||
git lfs pull --include "exp/no-state-epoch-99-avg-1.onnx"
|
||||
popd
|
||||
|
||||
bigram_repo_url=https://huggingface.co/vsd-vector/librispeech_bigram_sherpa-onnx-zipformer-large-en-2023-06-26
|
||||
log "Download bi-gram LM from ${bigram_repo_url}"
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url
|
||||
bigramlm_repo=$(basename $bigram_repo_url)
|
||||
pushd $bigramlm_repo
|
||||
git lfs pull --include "2gram.fst"
|
||||
popd
|
||||
|
||||
log "Perform offline decoding with RNN-LM and LODR"
|
||||
python3 ./python-api-examples/offline-decode-files.py \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder-epoch-99-avg-1.onnx \
|
||||
--decoder=$repo/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner=$repo/joiner-epoch-99-avg-1.onnx \
|
||||
--decoding-method=modified_beam_search \
|
||||
--lm=$lm_repo/exp/no-state-epoch-99-avg-1.onnx \
|
||||
--lodr-fst=$bigramlm_repo/2gram.fst \
|
||||
--lodr-scale=-0.5 \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
|
||||
|
||||
rm -rf $repo
|
||||
rm -rf $repo $lm_repo $bigramlm_repo
|
||||
|
||||
log "Test non-streaming paraformer models"
|
||||
|
||||
|
||||
@@ -35,6 +35,25 @@ file(s) with a non-streaming model.
|
||||
/path/to/0.wav \
|
||||
/path/to/1.wav
|
||||
|
||||
also with RNN LM rescoring and LODR (optional):
|
||||
|
||||
./python-api-examples/offline-decode-files.py \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--encoder=/path/to/encoder.onnx \
|
||||
--decoder=/path/to/decoder.onnx \
|
||||
--joiner=/path/to/joiner.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding-method=modified_beam_search \
|
||||
--debug=false \
|
||||
--sample-rate=16000 \
|
||||
--feature-dim=80 \
|
||||
--lm=/path/to/lm.onnx \
|
||||
--lm-scale=0.1 \
|
||||
--lodr-fst=/path/to/lodr.fst \
|
||||
--lodr-scale=-0.1 \
|
||||
/path/to/0.wav \
|
||||
/path/to/1.wav
|
||||
|
||||
(3) For CTC models from NeMo
|
||||
|
||||
python3 ./python-api-examples/offline-decode-files.py \
|
||||
@@ -269,6 +288,39 @@ def get_args():
|
||||
default="greedy_search",
|
||||
help="Valid values are greedy_search and modified_beam_search",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lm",
|
||||
metavar="file",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to RNN LM model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lm-scale",
|
||||
metavar="lm_scale",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="LM model scale for rescoring",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lodr-fst",
|
||||
metavar="file",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to LODR FST model. Used only when --lm is given.",
|
||||
)
|
||||
|
||||
# Scale applied to the LODR score; the flag it depends on is spelled
# --lodr-fst (with a hyphen), so the help text must name it that way.
parser.add_argument(
    "--lodr-scale",
    metavar="lodr_scale",
    type=float,
    default=-0.1,
    help="LODR scale for rescoring. Used only when --lodr-fst is given.",
)
|
||||
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
type=bool,
|
||||
@@ -364,6 +416,10 @@ def main():
|
||||
num_threads=args.num_threads,
|
||||
sample_rate=args.sample_rate,
|
||||
feature_dim=args.feature_dim,
|
||||
lm=args.lm,
|
||||
lm_scale=args.lm_scale,
|
||||
lodr_fst=args.lodr_fst,
|
||||
lodr_scale=args.lodr_scale,
|
||||
decoding_method=args.decoding_method,
|
||||
hotwords_file=args.hotwords_file,
|
||||
hotwords_score=args.hotwords_score,
|
||||
|
||||
@@ -21,6 +21,22 @@ rm sherpa-onnx-streaming-zipformer-en-2023-06-26.tar.bz2
|
||||
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \
|
||||
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav
|
||||
|
||||
or with RNN LM rescoring and LODR:
|
||||
|
||||
./python-api-examples/online-decode-files.py \
|
||||
--tokens=./sherpa-onnx-streaming-zipformer-en-2023-06-26/tokens.txt \
|
||||
--encoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \
|
||||
--decoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \
|
||||
--joiner=./sherpa-onnx-streaming-zipformer-en-2023-06-26/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \
|
||||
--decoding-method=modified_beam_search \
|
||||
--lm=/path/to/lm.onnx \
|
||||
--lm-scale=0.1 \
|
||||
--lodr-fst=/path/to/lodr.fst \
|
||||
--lodr-scale=-0.1 \
|
||||
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav \
|
||||
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \
|
||||
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav
|
||||
|
||||
(2) Streaming paraformer
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
|
||||
@@ -186,6 +202,22 @@ def get_args():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lodr-fst",
|
||||
metavar="file",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to LODR FST model. Used only when --lm is given.",
|
||||
)
|
||||
|
||||
# Scale applied to the LODR score; the flag it depends on is spelled
# --lodr-fst (with a hyphen), so the help text must name it that way.
parser.add_argument(
    "--lodr-scale",
    metavar="lodr_scale",
    type=float,
    default=-0.1,
    help="LODR scale for rescoring. Used only when --lodr-fst is given.",
)
|
||||
|
||||
parser.add_argument(
|
||||
"--provider",
|
||||
type=str,
|
||||
@@ -320,6 +352,8 @@ def main():
|
||||
max_active_paths=args.max_active_paths,
|
||||
lm=args.lm,
|
||||
lm_scale=args.lm_scale,
|
||||
lodr_fst=args.lodr_fst,
|
||||
lodr_scale=args.lodr_scale,
|
||||
hotwords_file=args.hotwords_file,
|
||||
hotwords_score=args.hotwords_score,
|
||||
modeling_unit=args.modeling_unit,
|
||||
|
||||
@@ -25,6 +25,7 @@ set(sources
|
||||
jieba.cc
|
||||
keyword-spotter-impl.cc
|
||||
keyword-spotter.cc
|
||||
lodr-fst.cc
|
||||
offline-canary-model-config.cc
|
||||
offline-canary-model.cc
|
||||
offline-ctc-fst-decoder-config.cc
|
||||
|
||||
@@ -12,9 +12,11 @@
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/lodr-fst.h"
|
||||
#include "sherpa-onnx/csrc/math.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
@@ -61,6 +63,9 @@ struct Hypothesis {
|
||||
// the nn lm states
|
||||
std::vector<CopyableOrtValue> nn_lm_states;
|
||||
|
||||
// the LODR states
|
||||
std::shared_ptr<LodrStateCost> lodr_state;
|
||||
|
||||
const ContextState *context_state;
|
||||
|
||||
// TODO(fangjun): Make it configurable
|
||||
|
||||
191
sherpa-onnx/csrc/lodr-fst.cc
Normal file
191
sherpa-onnx/csrc/lodr-fst.cc
Normal file
@@ -0,0 +1,191 @@
|
||||
// sherpa-onnx/csrc/lodr-fst.cc
|
||||
//
|
||||
// Contains code copied from icefall/utils/ngram_lm.py
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
//
|
||||
// Copyright (c) 2025 Tilde SIA (Askars Salimbajevs)
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/lodr-fst.h"
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
int32_t LodrFst::FindBackoffId() {
|
||||
// assume that the backoff id is the only input label with epsilon output
|
||||
|
||||
for (int32_t state = 0; state < fst_->NumStates(); ++state) {
|
||||
fst::ArcIterator<fst::StdConstFst> arc_iter(*fst_, state);
|
||||
for ( ; !arc_iter.Done(); arc_iter.Next()) {
|
||||
const auto& arc = arc_iter.Value();
|
||||
if (arc.olabel == 0) { // Check if the output label is epsilon (0)
|
||||
return arc.ilabel; // Return the input label
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return -1; // Return -1 if no such input symbol is found
|
||||
}
|
||||
|
||||
// Loads the LODR FST from `fst_path` and determines the backoff label.
//
// @param fst_path   Path to the FST file on disk.
// @param backoff_id Input label of the backoff arcs. A negative value means
//                   "not provided"; the label is then detected with
//                   FindBackoffId().
//
// The process is terminated via SHERPA_ONNX_EXIT(-1) when the file cannot be
// read or when no backoff label can be found.
LodrFst::LodrFst(const std::string &fst_path, int32_t backoff_id)
    : backoff_id_(backoff_id) {
  // Read() yields a null pointer when the file cannot be opened or parsed.
  // Handing that to CastOrConvertToConstFst() would risk a null dereference,
  // so fail with a clear message instead.
  fst::StdVectorFst *raw_fst = fst::StdVectorFst::Read(fst_path);
  if (raw_fst == nullptr) {
    SHERPA_ONNX_LOGE("Failed to initialize LODR: cannot read FST from '%s'",
                     fst_path.c_str());
    SHERPA_ONNX_EXIT(-1);
  }

  fst_ = std::unique_ptr<fst::StdConstFst>(CastOrConvertToConstFst(raw_fst));

  if (backoff_id < 0) {
    // backoff_id_ is not provided, find it automatically
    backoff_id_ = FindBackoffId();
    if (backoff_id_ < 0) {
      std::string err_msg = "Failed to initialize LODR: No backoff arc found";
      SHERPA_ONNX_LOGE("%s", err_msg.c_str());
      SHERPA_ONNX_EXIT(-1);
    }
  }
}
|
||||
|
||||
std::vector<std::tuple<int32_t, float>> LodrFst::ProcessBackoffArcs(
|
||||
int32_t state, float cost) {
|
||||
std::vector<std::tuple<int32_t, float>> ans;
|
||||
auto next = GetNextStatesCostsNoBackoff(state, backoff_id_);
|
||||
if (!next.has_value()) {
|
||||
return ans;
|
||||
}
|
||||
auto [next_state, next_cost] = next.value();
|
||||
ans.emplace_back(next_state, next_cost + cost);
|
||||
auto recursive_result = ProcessBackoffArcs(next_state, next_cost + cost);
|
||||
ans.insert(ans.end(), recursive_result.begin(), recursive_result.end());
|
||||
return ans;
|
||||
}
|
||||
|
||||
std::optional<std::tuple<int32_t, float>> LodrFst::GetNextStatesCostsNoBackoff(
|
||||
int32_t state, int32_t label) {
|
||||
fst::ArcIterator<fst::StdConstFst> arc_iter(*fst_, state);
|
||||
int32_t num_arcs = fst_->NumArcs(state);
|
||||
|
||||
int32_t left = 0, right = num_arcs - 1;
|
||||
while (left <= right) {
|
||||
int32_t mid = (left + right) / 2;
|
||||
arc_iter.Seek(mid);
|
||||
auto arc = arc_iter.Value();
|
||||
if (arc.ilabel < label) {
|
||||
left = mid + 1;
|
||||
} else if (arc.ilabel > label) {
|
||||
right = mid - 1;
|
||||
} else {
|
||||
return std::make_tuple(arc.nextstate, arc.weight.Value());
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::pair<std::vector<int32_t>, std::vector<float>> LodrFst::GetNextStateCosts(
|
||||
int32_t state, int32_t label) {
|
||||
std::vector<int32_t> states = {state};
|
||||
std::vector<float> costs = {0};
|
||||
|
||||
auto extra_states_costs = ProcessBackoffArcs(state, 0);
|
||||
for (const auto& [s, c] : extra_states_costs) {
|
||||
states.push_back(s);
|
||||
costs.push_back(c);
|
||||
}
|
||||
|
||||
std::vector<int32_t> next_states;
|
||||
std::vector<float> next_costs;
|
||||
for (size_t i = 0; i < states.size(); ++i) {
|
||||
auto next = GetNextStatesCostsNoBackoff(states[i], label);
|
||||
if (next.has_value()) {
|
||||
auto [ns, nc] = next.value();
|
||||
next_states.push_back(ns);
|
||||
next_costs.push_back(costs[i] + nc);
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(next_states, next_costs);
|
||||
}
|
||||
|
||||
void LodrFst::ComputeScore(float scale, Hypothesis *hyp, int32_t offset) {
|
||||
if (scale == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
hyp->lodr_state = std::make_unique<LodrStateCost>(this);
|
||||
|
||||
// Walk through the FST with the input text from the hypothesis
|
||||
for (size_t i = offset; i < hyp->ys.size(); ++i) {
|
||||
*hyp->lodr_state = hyp->lodr_state->ForwardOneStep(hyp->ys[i]);
|
||||
}
|
||||
|
||||
float lodr_score = hyp->lodr_state->FinalScore();
|
||||
|
||||
if (lodr_score == -std::numeric_limits<float>::infinity()) {
|
||||
SHERPA_ONNX_LOGE("Failed to compute LODR. Empty or mismatched FST?");
|
||||
return;
|
||||
}
|
||||
|
||||
// Update the hyp score
|
||||
hyp->log_prob += scale * lodr_score;
|
||||
}
|
||||
|
||||
float LodrFst::GetFinalCost(int32_t state) {
|
||||
auto final_weight = fst_->Final(state);
|
||||
if (final_weight == fst::StdArc::Weight::Zero()) {
|
||||
return 0.0;
|
||||
}
|
||||
return final_weight.Value();
|
||||
}
|
||||
|
||||
// Creates a state-cost set bound to `fst`.
//
// @param fst        FST used for all later transitions; stored as a raw
//                   pointer, not owned.
// @param state_cost Map from FST state to accumulated cost. When it is empty
//                   the set is initialised to state 0 with cost 0.
LodrStateCost::LodrStateCost(
    LodrFst* fst, const std::unordered_map<int32_t, float> &state_cost)
    : fst_(fst), state_cost_(state_cost) {
  if (state_cost_.empty()) {
    state_cost_[0] = 0.0;
  }
}
|
||||
|
||||
// Advances every tracked state by one token.
//
// For each (state, cost) currently tracked, all states reachable by consuming
// `label` are collected; when several paths reach the same state only the
// smallest accumulated cost is kept.
//
// NOTE(review): when no tracked state has a transition for `label`, the map
// passed to the constructor is empty and the constructor then resets the
// result to state 0 with cost 0 -- confirm this is the intended fallback.
//
// @return A new LodrStateCost holding the resulting states; *this is not
//         modified.
LodrStateCost LodrStateCost::ForwardOneStep(int32_t label) {
  std::unordered_map<int32_t, float> best;

  for (const auto &tracked : state_cost_) {
    const float base_cost = tracked.second;
    const auto reached = fst_->GetNextStateCosts(tracked.first, label);
    const auto &dst_states = reached.first;
    const auto &dst_costs = reached.second;

    for (size_t k = 0; k < dst_states.size(); ++k) {
      // Unseen destinations start at +infinity so the min() below always
      // picks the first real cost.
      auto slot = best.emplace(dst_states[k],
                               std::numeric_limits<float>::infinity())
                      .first;
      slot->second = std::min(slot->second, base_cost + dst_costs[k]);
    }
  }

  return LodrStateCost(fst_, best);
}
|
||||
|
||||
float LodrStateCost::Score() const {
|
||||
if (state_cost_.empty()) {
|
||||
return -std::numeric_limits<float>::infinity();
|
||||
}
|
||||
auto min_cost = std::min_element(state_cost_.begin(), state_cost_.end(),
|
||||
[](const auto& a, const auto& b) {
|
||||
return a.second < b.second;
|
||||
});
|
||||
return -min_cost->second;
|
||||
}
|
||||
|
||||
float LodrStateCost::FinalScore() const {
|
||||
if (state_cost_.empty()) {
|
||||
return -std::numeric_limits<float>::infinity();
|
||||
}
|
||||
auto min_cost = std::min_element(state_cost_.begin(), state_cost_.end(),
|
||||
[](const auto& a, const auto& b) {
|
||||
return a.second < b.second;
|
||||
});
|
||||
return -(min_cost->second +
|
||||
fst_->GetFinalCost(min_cost->first));
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
75
sherpa-onnx/csrc/lodr-fst.h
Normal file
75
sherpa-onnx/csrc/lodr-fst.h
Normal file
@@ -0,0 +1,75 @@
|
||||
// sherpa-onnx/csrc/lodr-fst.h
|
||||
//
|
||||
// Contains code copied from icefall/utils/ngram_lm.py
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
//
|
||||
// Copyright (c) 2025 Tilde SIA (Askars Salimbajevs)
|
||||
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_LODR_FST_H_
|
||||
#define SHERPA_ONNX_CSRC_LODR_FST_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <optional>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
#include <limits>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
#include "kaldifst/csrc/kaldi-fst-io.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class Hypothesis;
|
||||
|
||||
class LodrFst {
|
||||
public:
|
||||
explicit LodrFst(const std::string &fst_path, int32_t backoff_id = -1);
|
||||
|
||||
std::pair<std::vector<int32_t>, std::vector<float>> GetNextStateCosts(
|
||||
int32_t state, int32_t label);
|
||||
|
||||
float GetFinalCost(int32_t state);
|
||||
|
||||
void ComputeScore(float scale, Hypothesis *hyp, int32_t offset);
|
||||
|
||||
private:
|
||||
fst::StdVectorFst YsToFst(const std::vector<int64_t> &ys, int32_t offset);
|
||||
|
||||
std::vector<std::tuple<int32_t, float>> ProcessBackoffArcs(
|
||||
int32_t state, float cost);
|
||||
|
||||
std::optional<std::tuple<int32_t, float>> GetNextStatesCostsNoBackoff(
|
||||
int32_t state, int32_t label);
|
||||
|
||||
int32_t FindBackoffId();
|
||||
|
||||
|
||||
int32_t backoff_id_ = -1;
|
||||
std::unique_ptr<fst::StdConstFst> fst_; // owned by this class
|
||||
};
|
||||
|
||||
class LodrStateCost {
|
||||
public:
|
||||
explicit LodrStateCost(
|
||||
LodrFst* fst,
|
||||
const std::unordered_map<int32_t, float> &state_cost = {});
|
||||
|
||||
LodrStateCost ForwardOneStep(int32_t label);
|
||||
|
||||
float Score() const;
|
||||
float FinalScore() const;
|
||||
|
||||
private:
|
||||
// The fst_ is not owned by this class and borrowed from the caller
|
||||
// (e.g. OnlineRnnLM).
|
||||
LodrFst* fst_;
|
||||
std::unordered_map<int32_t, float> state_cost_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_LODR_FST_H_
|
||||
@@ -18,6 +18,10 @@ void OfflineLMConfig::Register(ParseOptions *po) {
|
||||
"Number of threads to run the neural network of LM model");
|
||||
po->Register("lm-provider", &lm_provider,
|
||||
"Specify a provider to LM model use: cpu, cuda, coreml");
|
||||
po->Register("lodr-fst", &lodr_fst, "Path to LODR FST model.");
|
||||
po->Register("lodr-scale", &lodr_scale, "LODR scale.");
|
||||
po->Register("lodr-backoff-id", &lodr_backoff_id,
|
||||
"ID of the backoff in the LODR FST. -1 means autodetect");
|
||||
}
|
||||
|
||||
bool OfflineLMConfig::Validate() const {
|
||||
@@ -26,6 +30,11 @@ bool OfflineLMConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!lodr_fst.empty() && !FileExists(lodr_fst)) {
|
||||
SHERPA_ONNX_LOGE("'%s' does not exist", lodr_fst.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -34,7 +43,10 @@ std::string OfflineLMConfig::ToString() const {
|
||||
|
||||
os << "OfflineLMConfig(";
|
||||
os << "model=\"" << model << "\", ";
|
||||
os << "scale=" << scale << ")";
|
||||
os << "scale=" << scale << ", ";
|
||||
os << "lodr_scale=" << lodr_scale << ", ";
|
||||
os << "lodr_fst=\"" << lodr_fst << "\", ";
|
||||
os << "lodr_backoff_id=" << lodr_backoff_id << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -19,14 +19,23 @@ struct OfflineLMConfig {
|
||||
int32_t lm_num_threads = 1;
|
||||
std::string lm_provider = "cpu";
|
||||
|
||||
// LODR
|
||||
std::string lodr_fst;
|
||||
float lodr_scale = 0.01;
|
||||
int32_t lodr_backoff_id = -1; // -1 means not set
|
||||
|
||||
OfflineLMConfig() = default;
|
||||
|
||||
OfflineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
|
||||
const std::string &lm_provider)
|
||||
const std::string &lm_provider, const std::string &lodr_fst,
|
||||
float lodr_scale, int32_t lodr_backoff_id)
|
||||
: model(model),
|
||||
scale(scale),
|
||||
lm_num_threads(lm_num_threads),
|
||||
lm_provider(lm_provider) {}
|
||||
lm_provider(lm_provider),
|
||||
lodr_fst(lodr_fst),
|
||||
lodr_scale(lodr_scale),
|
||||
lodr_backoff_id(lodr_backoff_id) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/lodr-fst.h"
|
||||
#include "sherpa-onnx/csrc/offline-rnn-lm.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -74,11 +75,17 @@ void OfflineLM::ComputeLMScore(float scale, int32_t context_size,
|
||||
}
|
||||
auto negative_loglike = Rescore(std::move(x), std::move(x_lens));
|
||||
const float *p_nll = negative_loglike.GetTensorData<float>();
|
||||
// We scale LODR scale with LM scale to replicate Icefall code
|
||||
auto lodr_scale = config_.lodr_scale * scale;
|
||||
for (auto &h : *hyps) {
|
||||
for (auto &t : h) {
|
||||
// Use -scale here since we want to change negative loglike to loglike.
|
||||
t.second.lm_log_prob = -scale * (*p_nll);
|
||||
++p_nll;
|
||||
// apply LODR to hyp score
|
||||
if (lodr_fst_ != nullptr) {
|
||||
lodr_fst_->ComputeScore(lodr_scale, &t.second, context_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,12 +10,24 @@
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||
#include "sherpa-onnx/csrc/lodr-fst.h"
|
||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineLM {
|
||||
public:
|
||||
explicit OfflineLM(const OfflineLMConfig &config) : config_(config) {
|
||||
if (!config_.lodr_fst.empty()) {
|
||||
try {
|
||||
lodr_fst_ = std::make_unique<LodrFst>(LodrFst(config_.lodr_fst,
|
||||
config_.lodr_backoff_id));
|
||||
} catch (const std::exception& e) {
|
||||
throw std::runtime_error("Failed to load LODR FST from: " +
|
||||
config_.lodr_fst + ". Error: " + e.what());
|
||||
}
|
||||
}
|
||||
}
|
||||
virtual ~OfflineLM() = default;
|
||||
|
||||
static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config);
|
||||
@@ -43,6 +55,11 @@ class OfflineLM {
|
||||
// @param hyps It is changed in-place.
|
||||
void ComputeLMScore(float scale, int32_t context_size,
|
||||
std::vector<Hypotheses> *hyps);
|
||||
|
||||
private:
|
||||
std::unique_ptr<LodrFst> lodr_fst_;
|
||||
float lodr_scale_;
|
||||
OfflineLMConfig config_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -83,11 +83,11 @@ class OfflineRnnLM::Impl {
|
||||
};
|
||||
|
||||
OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
: impl_(std::make_unique<Impl>(config)), OfflineLM(config) {}
|
||||
|
||||
template <typename Manager>
|
||||
OfflineRnnLM::OfflineRnnLM(Manager *mgr, const OfflineLMConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
: impl_(std::make_unique<Impl>(mgr, config)), OfflineLM(config) {}
|
||||
|
||||
OfflineRnnLM::~OfflineRnnLM() = default;
|
||||
|
||||
|
||||
@@ -20,6 +20,10 @@ void OnlineLMConfig::Register(ParseOptions *po) {
|
||||
"Specify a provider to LM model use: cpu, cuda, coreml");
|
||||
po->Register("lm-shallow-fusion", &shallow_fusion,
|
||||
"Boolean whether to use shallow fusion or rescore.");
|
||||
po->Register("lodr-fst", &lodr_fst, "Path to LODR FST model.");
|
||||
po->Register("lodr-scale", &lodr_scale, "LODR scale.");
|
||||
po->Register("lodr-backoff-id", &lodr_backoff_id,
|
||||
"ID of the backoff in the LODR FST. -1 means autodetect");
|
||||
}
|
||||
|
||||
bool OnlineLMConfig::Validate() const {
|
||||
@@ -28,6 +32,11 @@ bool OnlineLMConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!lodr_fst.empty() && !FileExists(lodr_fst)) {
|
||||
SHERPA_ONNX_LOGE("'%s' does not exist", lodr_fst.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -37,6 +46,9 @@ std::string OnlineLMConfig::ToString() const {
|
||||
os << "OnlineLMConfig(";
|
||||
os << "model=\"" << model << "\", ";
|
||||
os << "scale=" << scale << ", ";
|
||||
os << "lodr_scale=" << lodr_scale << ", ";
|
||||
os << "lodr_fst=\"" << lodr_fst << "\", ";
|
||||
os << "lodr_backoff_id=" << lodr_backoff_id << ", ";
|
||||
os << "shallow_fusion=" << (shallow_fusion ? "True" : "False") << ")";
|
||||
|
||||
return os.str();
|
||||
|
||||
@@ -18,18 +18,26 @@ struct OnlineLMConfig {
|
||||
float scale = 0.5;
|
||||
int32_t lm_num_threads = 1;
|
||||
std::string lm_provider = "cpu";
|
||||
std::string lodr_fst;
|
||||
float lodr_scale = 0.01;
|
||||
int32_t lodr_backoff_id = -1; // -1 means not set
|
||||
// enable shallow fusion
|
||||
bool shallow_fusion = true;
|
||||
|
||||
OnlineLMConfig() = default;
|
||||
|
||||
OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
|
||||
const std::string &lm_provider, bool shallow_fusion)
|
||||
const std::string &lm_provider, bool shallow_fusion,
|
||||
const std::string &lodr_fst, float lodr_scale,
|
||||
int32_t lodr_backoff_id)
|
||||
: model(model),
|
||||
scale(scale),
|
||||
lm_num_threads(lm_num_threads),
|
||||
lm_provider(lm_provider),
|
||||
shallow_fusion(shallow_fusion) {}
|
||||
shallow_fusion(shallow_fusion),
|
||||
lodr_fst(lodr_fst),
|
||||
lodr_scale(lodr_scale),
|
||||
lodr_backoff_id(lodr_backoff_id) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/lodr-fst.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
@@ -35,12 +36,27 @@ class OnlineRnnLM::Impl {
|
||||
auto init_states = GetInitStatesSF();
|
||||
hyp->nn_lm_scores.value = std::move(init_states.first);
|
||||
hyp->nn_lm_states = Convert(std::move(init_states.second));
|
||||
// if LODR enabled, we need to initialize the LODR state
|
||||
if (lodr_fst_ != nullptr) {
|
||||
hyp->lodr_state = std::make_unique<LodrStateCost>(lodr_fst_.get());
|
||||
}
|
||||
}
|
||||
|
||||
// get lm score for cur token given the hyp->ys[:-1] and save to lm_log_prob
|
||||
const float *nn_lm_scores = hyp->nn_lm_scores.value.GetTensorData<float>();
|
||||
hyp->lm_log_prob += nn_lm_scores[hyp->ys.back()] * scale;
|
||||
|
||||
// if LODR enabled, we need to update the LODR state
|
||||
if (lodr_fst_ != nullptr) {
|
||||
auto next_lodr_state = std::make_unique<LodrStateCost>(
|
||||
hyp->lodr_state->ForwardOneStep(hyp->ys.back()));
|
||||
// calculate the score of the latest token
|
||||
auto score = next_lodr_state->Score() - hyp->lodr_state->Score();
|
||||
hyp->lodr_state = std::move(next_lodr_state);
|
||||
// apply LODR to hyp score
|
||||
hyp->lm_log_prob += score * config_.lodr_scale;
|
||||
}
|
||||
|
||||
// get lm scores for next tokens given the hyp->ys[:] and save to
|
||||
// nn_lm_scores
|
||||
std::array<int64_t, 2> x_shape{1, 1};
|
||||
@@ -89,6 +105,12 @@ class OnlineRnnLM::Impl {
|
||||
const float *p_nll = out.first.GetTensorData<float>();
|
||||
h.lm_log_prob = -scale * (*p_nll);
|
||||
|
||||
// apply LODR to hyp score
|
||||
if (lodr_fst_ != nullptr) {
|
||||
// We scale LODR scale with LM scale to replicate Icefall code
|
||||
lodr_fst_->ComputeScore(config_.lodr_scale*scale, &h, context_size);
|
||||
}
|
||||
|
||||
// update NN LM states in hyp
|
||||
h.nn_lm_states = Convert(std::move(out.second));
|
||||
|
||||
@@ -154,6 +176,11 @@ class OnlineRnnLM::Impl {
|
||||
SHERPA_ONNX_READ_META_DATA(sos_id_, "sos_id");
|
||||
|
||||
ComputeInitStates();
|
||||
|
||||
if (!config_.lodr_fst.empty()) {
|
||||
lodr_fst_ = std::make_unique<LodrFst>(LodrFst(config_.lodr_fst,
|
||||
config_.lodr_backoff_id));
|
||||
}
|
||||
}
|
||||
|
||||
void ComputeInitStates() {
|
||||
@@ -203,6 +230,8 @@ class OnlineRnnLM::Impl {
|
||||
int32_t rnn_num_layers_ = 2;
|
||||
int32_t rnn_hidden_size_ = 512;
|
||||
int32_t sos_id_ = 1;
|
||||
|
||||
std::unique_ptr<LodrFst> lodr_fst_;
|
||||
};
|
||||
|
||||
OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config)
|
||||
|
||||
@@ -13,13 +13,19 @@ namespace sherpa_onnx {
|
||||
void PybindOfflineLMConfig(py::module *m) {
|
||||
using PyClass = OfflineLMConfig;
|
||||
py::class_<PyClass>(*m, "OfflineLMConfig")
|
||||
.def(py::init<const std::string &, float, int32_t, const std::string &>(),
|
||||
.def(py::init<const std::string &, float, int32_t, const std::string &,
|
||||
const std::string &, float, int32_t>(),
|
||||
py::arg("model"), py::arg("scale") = 0.5f,
|
||||
py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu")
|
||||
py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu",
|
||||
py::arg("lodr_fst") = "", py::arg("lodr_scale") = 0.0f,
|
||||
py::arg("lodr_backoff_id") = -1)
|
||||
.def_readwrite("model", &PyClass::model)
|
||||
.def_readwrite("scale", &PyClass::scale)
|
||||
.def_readwrite("lm_provider", &PyClass::lm_provider)
|
||||
.def_readwrite("lm_num_threads", &PyClass::lm_num_threads)
|
||||
.def_readwrite("lodr_fst", &PyClass::lodr_fst)
|
||||
.def_readwrite("lodr_scale", &PyClass::lodr_scale)
|
||||
.def_readwrite("lodr_backoff_id", &PyClass::lodr_backoff_id)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
|
||||
@@ -14,15 +14,21 @@ void PybindOnlineLMConfig(py::module *m) {
|
||||
using PyClass = OnlineLMConfig;
|
||||
py::class_<PyClass>(*m, "OnlineLMConfig")
|
||||
.def(py::init<const std::string &, float, int32_t,
|
||||
const std::string &, bool>(),
|
||||
const std::string &, bool, const std::string &,
|
||||
float, int>(),
|
||||
py::arg("model") = "", py::arg("scale") = 0.5f,
|
||||
py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu",
|
||||
py::arg("shallow_fusion") = true)
|
||||
py::arg("shallow_fusion") = true, py::arg("lodr_fst") = "",
|
||||
py::arg("lodr_scale") = 0.0f, py::arg("lodr_backoff_id") = -1)
|
||||
.def_readwrite("model", &PyClass::model)
|
||||
.def_readwrite("scale", &PyClass::scale)
|
||||
.def_readwrite("lm_provider", &PyClass::lm_provider)
|
||||
.def_readwrite("lm_num_threads", &PyClass::lm_num_threads)
|
||||
.def_readwrite("shallow_fusion", &PyClass::shallow_fusion)
|
||||
.def_readwrite("lodr_fst", &PyClass::lodr_fst)
|
||||
.def_readwrite("lodr_scale", &PyClass::lodr_scale)
|
||||
.def_readwrite("lodr_backoff_id", &PyClass::lodr_backoff_id)
|
||||
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
|
||||
@@ -69,6 +69,8 @@ class OfflineRecognizer(object):
|
||||
hr_dict_dir: str = "",
|
||||
hr_rule_fsts: str = "",
|
||||
hr_lexicon: str = "",
|
||||
lodr_fst: str = "",
|
||||
lodr_scale: float = 0.0,
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -133,6 +135,10 @@ class OfflineRecognizer(object):
|
||||
rule_fars:
|
||||
If not empty, it specifies fst archives for inverse text normalization.
|
||||
If there are multiple archives, they are separated by a comma.
|
||||
lodr_fst:
|
||||
Path to the LODR FST file in binary format. If empty, LODR is disabled.
|
||||
lodr_scale:
|
||||
Scale factor for LODR rescoring. Only used when lodr_fst is provided.
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
model_config = OfflineModelConfig(
|
||||
@@ -173,6 +179,8 @@ class OfflineRecognizer(object):
|
||||
scale=lm_scale,
|
||||
lm_num_threads=num_threads,
|
||||
lm_provider=provider,
|
||||
lodr_fst=lodr_fst,
|
||||
lodr_scale=lodr_scale,
|
||||
)
|
||||
|
||||
recognizer_config = OfflineRecognizerConfig(
|
||||
|
||||
@@ -89,6 +89,8 @@ class OnlineRecognizer(object):
|
||||
hr_dict_dir: str = "",
|
||||
hr_rule_fsts: str = "",
|
||||
hr_lexicon: str = "",
|
||||
lodr_fst: str = "",
|
||||
lodr_scale: float = 0.0,
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -216,6 +218,10 @@ class OnlineRecognizer(object):
|
||||
"Set path for storing timing cache." TensorRT EP
|
||||
trt_dump_subgraphs: bool = False,
|
||||
"Dump optimized subgraphs for debugging." TensorRT EP
|
||||
lodr_fst:
|
||||
Path to the LODR FST file in binary format. If empty, LODR is disabled.
|
||||
lodr_scale:
|
||||
Scale factor for LODR rescoring. Only used when lodr_fst is provided.
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
_assert_file_exists(tokens)
|
||||
@@ -298,6 +304,8 @@ class OnlineRecognizer(object):
|
||||
model=lm,
|
||||
scale=lm_scale,
|
||||
shallow_fusion=lm_shallow_fusion,
|
||||
lodr_fst=lodr_fst,
|
||||
lodr_scale=lodr_scale,
|
||||
)
|
||||
|
||||
recognizer_config = OnlineRecognizerConfig(
|
||||
|
||||
Reference in New Issue
Block a user