Add RNN LM rescore for offline ASR with modified_beam_search (#125)

This commit is contained in:
Fangjun Kuang
2023-04-23 17:15:18 +08:00
committed by GitHub
parent d49a597431
commit 86017f9833
32 changed files with 842 additions and 52 deletions

View File

@@ -16,6 +16,7 @@
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/offline-transducer-model.h"
#include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h"
#include "sherpa-onnx/csrc/pad-sequence.h"
#include "sherpa-onnx/csrc/symbol-table.h"
@@ -57,8 +58,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
decoder_ =
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
} else if (config_.decoding_method == "modified_beam_search") {
SHERPA_ONNX_LOGE("TODO: modified_beam_search is to be implemented");
exit(-1);
if (!config_.lm_config.model.empty()) {
lm_ = OfflineLM::Create(config.lm_config);
}
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config_.decoding_method.c_str());
@@ -127,6 +133,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
SymbolTable symbol_table_;
std::unique_ptr<OfflineTransducerModel> model_;
std::unique_ptr<OfflineTransducerDecoder> decoder_;
std::unique_ptr<OfflineLM> lm_;
};
} // namespace sherpa_onnx