@@ -31,6 +31,7 @@
|
||||
#include "sherpa-onnx/csrc/pad-sequence.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/utils.h"
|
||||
#include "ssentencepiece/csrc/ssentencepiece.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -76,9 +77,6 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
: config_(config),
|
||||
symbol_table_(config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
|
||||
if (!config_.hotwords_file.empty()) {
|
||||
InitHotwords();
|
||||
}
|
||||
if (config_.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
|
||||
model_.get(), config_.blank_penalty);
|
||||
@@ -87,6 +85,15 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
lm_ = OfflineLM::Create(config.lm_config);
|
||||
}
|
||||
|
||||
if (!config_.model_config.bpe_vocab.empty()) {
|
||||
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
|
||||
config_.model_config.bpe_vocab);
|
||||
}
|
||||
|
||||
if (!config_.hotwords_file.empty()) {
|
||||
InitHotwords();
|
||||
}
|
||||
|
||||
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale, config_.blank_penalty);
|
||||
@@ -112,6 +119,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
lm_ = OfflineLM::Create(mgr, config.lm_config);
|
||||
}
|
||||
|
||||
if (!config_.model_config.bpe_vocab.empty()) {
|
||||
auto buf = ReadFile(mgr, config_.model_config.bpe_vocab);
|
||||
std::istringstream iss(std::string(buf.begin(), buf.end()));
|
||||
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss);
|
||||
}
|
||||
|
||||
if (!config_.hotwords_file.empty()) {
|
||||
InitHotwords(mgr);
|
||||
}
|
||||
|
||||
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale, config_.blank_penalty);
|
||||
@@ -128,7 +145,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
|
||||
std::istringstream is(hws);
|
||||
std::vector<std::vector<int32_t>> current;
|
||||
if (!EncodeHotwords(is, symbol_table_, &current)) {
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
|
||||
bpe_encoder_.get(), &current)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
|
||||
hotwords.c_str());
|
||||
}
|
||||
@@ -207,19 +225,47 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (!EncodeHotwords(is, symbol_table_, &hotwords_)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed.");
|
||||
exit(-1);
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
|
||||
bpe_encoder_.get(), &hotwords_)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Failed to encode some hotwords, skip them already, see logs above "
|
||||
"for details.");
|
||||
}
|
||||
hotwords_graph_ =
|
||||
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
void InitHotwords(AAssetManager *mgr) {
  // Android variant: the hotwords file is fetched through the asset manager
  // instead of the regular file system.
  // Each line in hotwords_file contains space-separated words.
  auto bytes = ReadFile(mgr, config_.hotwords_file);
  std::string text(bytes.begin(), bytes.end());
  std::istringstream stream(text);

  // NOTE(review): a string-backed stream is normally in a good state right
  // after construction, so this guard probably never fires - confirm whether
  // an empty-buffer check was intended instead.
  if (stream.fail()) {
    SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
                     config_.hotwords_file.c_str());
    exit(-1);
  }

  // An encoding failure is only reported, not fatal: the context graph below
  // is still built from whatever ended up in hotwords_.
  const auto encoded =
      EncodeHotwords(stream, config_.model_config.modeling_unit, symbol_table_,
                     bpe_encoder_.get(), &hotwords_);
  if (!encoded) {
    SHERPA_ONNX_LOGE(
        "Failed to encode some hotwords, skip them already, see logs above "
        "for details.");
  }

  hotwords_graph_ =
      std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
}
|
||||
#endif
|
||||
|
||||
private:
|
||||
// Copy of the recognizer configuration handed to the constructor.
OfflineRecognizerConfig config_;
|
||||
// Symbol table built from config_.model_config.tokens.
SymbolTable symbol_table_;
|
||||
// Encoded hotwords, one inner vector per entry; filled by EncodeHotwords()
// from InitHotwords(). Presumably token ids - TODO confirm.
std::vector<std::vector<int32_t>> hotwords_;
|
||||
// Context graph built from hotwords_ and config_.hotwords_score in
// InitHotwords().
ContextGraphPtr hotwords_graph_;
|
||||
// Only created when config_.model_config.bpe_vocab is non-empty; otherwise
// it stays null and EncodeHotwords() receives a null pointer. The
// constructors create it before calling InitHotwords().
std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
|
||||
// Transducer model created from config_.model_config.
std::unique_ptr<OfflineTransducerModel> model_;
|
||||
// Decoder chosen in the constructor (greedy search or modified beam search).
std::unique_ptr<OfflineTransducerDecoder> decoder_;
|
||||
// Language model created through OfflineLM::Create(); handed to the modified
// beam search decoder as a raw pointer. May be null - TODO confirm.
std::unique_ptr<OfflineLM> lm_;
|
||||
|
||||
Reference in New Issue
Block a user