re-pull-request: allow tokens and hotwords to be loaded from a buffered string directly (#1339)

Co-authored-by: xiao <shawl336@163.com>
This commit is contained in:
lxiao336
2024-09-13 09:58:17 +08:00
committed by GitHub
parent 6b6e7635ed
commit 65cfa7548a
12 changed files with 414 additions and 16 deletions

View File

@@ -73,6 +73,12 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer(
recognizer_config.model_config.tokens =
SHERPA_ONNX_OR(config->model_config.tokens, "");
if (config->model_config.tokens_buf &&
config->model_config.tokens_buf_size > 0) {
recognizer_config.model_config.tokens_buf = std::string(
config->model_config.tokens_buf, config->model_config.tokens_buf_size);
}
recognizer_config.model_config.num_threads =
SHERPA_ONNX_OR(config->model_config.num_threads, 1);
recognizer_config.model_config.provider_config.provider =
@@ -120,6 +126,10 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer(
recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, "");
recognizer_config.hotwords_score =
SHERPA_ONNX_OR(config->hotwords_score, 1.5);
if (config->hotwords_buf && config->hotwords_buf_size > 0) {
recognizer_config.hotwords_buf =
std::string(config->hotwords_buf, config->hotwords_buf_size);
}
recognizer_config.blank_penalty = config->blank_penalty;

View File

@@ -88,6 +88,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig {
// - cjkchar+bpe
const char *modeling_unit;
const char *bpe_vocab;
/// If non-null (and tokens_buf_size > 0), the tokens are loaded directly from
/// this buffered string, in priority.
const char *tokens_buf;
/// byte size excluding the trailing '\0'
int32_t tokens_buf_size;
} SherpaOnnxOnlineModelConfig;
/// It expects 16 kHz 16-bit single channel wave format.
@@ -147,6 +152,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig {
const char *rule_fsts;
const char *rule_fars;
float blank_penalty;
/// If non-null (and hotwords_buf_size > 0), the hotwords are loaded directly
/// from this buffered string, in priority.
const char *hotwords_buf;
/// byte size excluding the trailing '\0'
int32_t hotwords_buf_size;
} SherpaOnnxOnlineRecognizerConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerResult {

View File

@@ -56,8 +56,19 @@ bool OnlineModelConfig::Validate() const {
return false;
}
if (!FileExists(tokens)) {
SHERPA_ONNX_LOGE("tokens: '%s' does not exist", tokens.c_str());
if (!tokens_buf.empty() && FileExists(tokens)) {
SHERPA_ONNX_LOGE(
"you can not provide a tokens_buf and a tokens file: '%s', "
"at the same time, which is confusing",
tokens.c_str());
return false;
}
if (tokens_buf.empty() && !FileExists(tokens)) {
SHERPA_ONNX_LOGE(
"tokens: '%s' does not exist, you should provide "
"either a tokens buffer or a tokens file",
tokens.c_str());
return false;
}

View File

@@ -45,6 +45,11 @@ struct OnlineModelConfig {
std::string modeling_unit = "cjkchar";
std::string bpe_vocab;
/// if tokens_buf is non-empty,
/// the tokens will be loaded from the buffered string instead of from the
/// ${tokens} file
std::string tokens_buf;
OnlineModelConfig() = default;
OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
const OnlineParaformerModelConfig &paraformer,
@@ -53,8 +58,7 @@ struct OnlineModelConfig {
const OnlineNeMoCtcModelConfig &nemo_ctc,
const ProviderConfig &provider_config,
const std::string &tokens, int32_t num_threads,
int32_t warm_up, bool debug,
const std::string &model_type,
int32_t warm_up, bool debug, const std::string &model_type,
const std::string &modeling_unit,
const std::string &bpe_vocab)
: transducer(transducer),

View File

@@ -83,8 +83,14 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
: OnlineRecognizerImpl(config),
config_(config),
model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) {
if (!config.model_config.tokens_buf.empty()) {
sym_ = SymbolTable(config.model_config.tokens_buf, false);
} else {
/// assumes tokens_buf and tokens are guaranteed not to both be empty
sym_ = SymbolTable(config.model_config.tokens, true);
}
if (sym_.Contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
@@ -97,7 +103,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
config_.model_config.bpe_vocab);
}
if (!config_.hotwords_file.empty()) {
if (!config_.hotwords_buf.empty()) {
InitHotwordsFromBufStr();
} else if (!config_.hotwords_file.empty()) {
InitHotwords();
}
@@ -108,8 +116,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
config_.blank_penalty,
config_.temperature_scale);
config_.blank_penalty, config_.temperature_scale);
} else if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
@@ -158,8 +165,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
config_.blank_penalty,
config_.temperature_scale);
config_.blank_penalty, config_.temperature_scale);
} else if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
@@ -446,6 +452,20 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
#endif
/// Build hotwords_graph_ from the in-memory string config_.hotwords_buf
/// instead of reading a hotwords file.
///
/// The buffer is wrapped in a string stream and passed to EncodeHotwords(),
/// which receives hotwords_ and boost_scores_ as output arguments. A false
/// return is only logged; the context graph is still constructed afterwards.
void InitHotwordsFromBufStr() {
  // NOTE(review): the buffer presumably uses the same line format as the
  // hotwords file (space-separated words per line) -- confirm against
  // EncodeHotwords().
  std::istringstream hotwords_stream(config_.hotwords_buf);

  const bool all_encoded = EncodeHotwords(
      hotwords_stream, config_.model_config.modeling_unit, sym_,
      bpe_encoder_.get(), &hotwords_, &boost_scores_);

  if (!all_encoded) {
    // Not treated as fatal here: report it and carry on.
    SHERPA_ONNX_LOGE(
        "Failed to encode some hotwords, skip them already, see logs above "
        "for details.");
  }

  hotwords_graph_ = std::make_shared<ContextGraph>(
      hotwords_, config_.hotwords_score, boost_scores_);
}
void InitOnlineStream(OnlineStream *stream) const {
auto r = decoder_->GetEmptyResult();

View File

@@ -44,10 +44,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
const OnlineRecognizerConfig &config)
: OnlineRecognizerImpl(config),
config_(config),
symbol_table_(config.model_config.tokens),
endpoint_(config_.endpoint_config),
model_(
std::make_unique<OnlineTransducerNeMoModel>(config.model_config)) {
if (!config.model_config.tokens_buf.empty()) {
symbol_table_ = SymbolTable(config.model_config.tokens_buf, false);
} else {
/// assumes tokens_buf and tokens are guaranteed not to both be empty
symbol_table_ = SymbolTable(config.model_config.tokens, true);
}
if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
model_.get(), config_.blank_penalty);

View File

@@ -106,6 +106,11 @@ struct OnlineRecognizerConfig {
// If there are multiple FST archives, they are applied from left to right.
std::string rule_fars;
/// used only for modified_beam_search, if hotwords_buf is non-empty,
/// the hotwords will be loaded from the buffered string instead of from
/// ${hotwords_file}
std::string hotwords_buf;
OnlineRecognizerConfig() = default;
OnlineRecognizerConfig(

View File

@@ -20,9 +20,14 @@
namespace sherpa_onnx {
SymbolTable::SymbolTable(const std::string &filename) {
std::ifstream is(filename);
Init(is);
SymbolTable::SymbolTable(const std::string &filename, bool is_file) {
if (is_file) {
std::ifstream is(filename);
Init(is);
} else {
std::istringstream iss(filename);
Init(iss);
}
}
#if __ANDROID_API__ >= 9

View File

@@ -19,13 +19,13 @@ namespace sherpa_onnx {
class SymbolTable {
public:
SymbolTable() = default;
/// Construct a symbol table from a file.
/// Construct a symbol table from a file or from a buffered string.
/// Each line in the file contains two fields:
///
/// sym ID
///
/// Fields are separated by space(s).
explicit SymbolTable(const std::string &filename);
explicit SymbolTable(const std::string &filename, bool is_file = true);
#if __ANDROID_API__ >= 9
SymbolTable(AAssetManager *mgr, const std::string &filename);