re-pull-request allow tokens and hotwords be loaded from buffered string driectly (#1339)
Co-authored-by: xiao <shawl336@163.com>
This commit is contained in:
@@ -73,6 +73,12 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer(
|
||||
|
||||
recognizer_config.model_config.tokens =
|
||||
SHERPA_ONNX_OR(config->model_config.tokens, "");
|
||||
if (config->model_config.tokens_buf &&
|
||||
config->model_config.tokens_buf_size > 0) {
|
||||
recognizer_config.model_config.tokens_buf = std::string(
|
||||
config->model_config.tokens_buf, config->model_config.tokens_buf_size);
|
||||
}
|
||||
|
||||
recognizer_config.model_config.num_threads =
|
||||
SHERPA_ONNX_OR(config->model_config.num_threads, 1);
|
||||
recognizer_config.model_config.provider_config.provider =
|
||||
@@ -120,6 +126,10 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer(
|
||||
recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, "");
|
||||
recognizer_config.hotwords_score =
|
||||
SHERPA_ONNX_OR(config->hotwords_score, 1.5);
|
||||
if (config->hotwords_buf && config->hotwords_buf_size > 0) {
|
||||
recognizer_config.hotwords_buf =
|
||||
std::string(config->hotwords_buf, config->hotwords_buf_size);
|
||||
}
|
||||
|
||||
recognizer_config.blank_penalty = config->blank_penalty;
|
||||
|
||||
|
||||
@@ -88,6 +88,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig {
|
||||
// - cjkchar+bpe
|
||||
const char *modeling_unit;
|
||||
const char *bpe_vocab;
|
||||
/// if non-null, loading the tokens from the buffered string directly in
|
||||
/// prioriy
|
||||
const char *tokens_buf;
|
||||
/// byte size excluding the tailing '\0'
|
||||
int32_t tokens_buf_size;
|
||||
} SherpaOnnxOnlineModelConfig;
|
||||
|
||||
/// It expects 16 kHz 16-bit single channel wave format.
|
||||
@@ -147,6 +152,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig {
|
||||
const char *rule_fsts;
|
||||
const char *rule_fars;
|
||||
float blank_penalty;
|
||||
|
||||
/// if non-nullptr, loading the hotwords from the buffered string directly in
|
||||
const char *hotwords_buf;
|
||||
/// byte size excluding the tailing '\0'
|
||||
int32_t hotwords_buf_size;
|
||||
} SherpaOnnxOnlineRecognizerConfig;
|
||||
|
||||
SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerResult {
|
||||
|
||||
@@ -56,8 +56,19 @@ bool OnlineModelConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(tokens)) {
|
||||
SHERPA_ONNX_LOGE("tokens: '%s' does not exist", tokens.c_str());
|
||||
if (!tokens_buf.empty() && FileExists(tokens)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"you can not provide a tokens_buf and a tokens file: '%s', "
|
||||
"at the same time, which is confusing",
|
||||
tokens.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tokens_buf.empty() && !FileExists(tokens)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"tokens: '%s' does not exist, you should provide "
|
||||
"either a tokens buffer or a tokens file",
|
||||
tokens.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -45,6 +45,11 @@ struct OnlineModelConfig {
|
||||
std::string modeling_unit = "cjkchar";
|
||||
std::string bpe_vocab;
|
||||
|
||||
/// if tokens_buf is non-empty,
|
||||
/// the tokens will be loaded from the buffered string instead of from the
|
||||
/// ${tokens} file
|
||||
std::string tokens_buf;
|
||||
|
||||
OnlineModelConfig() = default;
|
||||
OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
|
||||
const OnlineParaformerModelConfig ¶former,
|
||||
@@ -53,8 +58,7 @@ struct OnlineModelConfig {
|
||||
const OnlineNeMoCtcModelConfig &nemo_ctc,
|
||||
const ProviderConfig &provider_config,
|
||||
const std::string &tokens, int32_t num_threads,
|
||||
int32_t warm_up, bool debug,
|
||||
const std::string &model_type,
|
||||
int32_t warm_up, bool debug, const std::string &model_type,
|
||||
const std::string &modeling_unit,
|
||||
const std::string &bpe_vocab)
|
||||
: transducer(transducer),
|
||||
|
||||
@@ -83,8 +83,14 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
: OnlineRecognizerImpl(config),
|
||||
config_(config),
|
||||
model_(OnlineTransducerModel::Create(config.model_config)),
|
||||
sym_(config.model_config.tokens),
|
||||
endpoint_(config_.endpoint_config) {
|
||||
if (!config.model_config.tokens_buf.empty()) {
|
||||
sym_ = SymbolTable(config.model_config.tokens_buf, false);
|
||||
} else {
|
||||
/// assuming tokens_buf and tokens are guaranteed not being both empty
|
||||
sym_ = SymbolTable(config.model_config.tokens, true);
|
||||
}
|
||||
|
||||
if (sym_.Contains("<unk>")) {
|
||||
unk_id_ = sym_["<unk>"];
|
||||
}
|
||||
@@ -97,7 +103,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
config_.model_config.bpe_vocab);
|
||||
}
|
||||
|
||||
if (!config_.hotwords_file.empty()) {
|
||||
if (!config_.hotwords_buf.empty()) {
|
||||
InitHotwordsFromBufStr();
|
||||
} else if (!config_.hotwords_file.empty()) {
|
||||
InitHotwords();
|
||||
}
|
||||
|
||||
@@ -108,8 +116,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
|
||||
config_.blank_penalty,
|
||||
config_.temperature_scale);
|
||||
config_.blank_penalty, config_.temperature_scale);
|
||||
|
||||
} else if (config.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||
@@ -158,8 +165,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
|
||||
config_.blank_penalty,
|
||||
config_.temperature_scale);
|
||||
config_.blank_penalty, config_.temperature_scale);
|
||||
|
||||
} else if (config.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||
@@ -446,6 +452,20 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
}
|
||||
#endif
|
||||
|
||||
void InitHotwordsFromBufStr() {
|
||||
// each line in hotwords_file contains space-separated words
|
||||
|
||||
std::istringstream iss(config_.hotwords_buf);
|
||||
if (!EncodeHotwords(iss, config_.model_config.modeling_unit, sym_,
|
||||
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Failed to encode some hotwords, skip them already, see logs above "
|
||||
"for details.");
|
||||
}
|
||||
hotwords_graph_ = std::make_shared<ContextGraph>(
|
||||
hotwords_, config_.hotwords_score, boost_scores_);
|
||||
}
|
||||
|
||||
void InitOnlineStream(OnlineStream *stream) const {
|
||||
auto r = decoder_->GetEmptyResult();
|
||||
|
||||
|
||||
@@ -44,10 +44,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
|
||||
const OnlineRecognizerConfig &config)
|
||||
: OnlineRecognizerImpl(config),
|
||||
config_(config),
|
||||
symbol_table_(config.model_config.tokens),
|
||||
endpoint_(config_.endpoint_config),
|
||||
model_(
|
||||
std::make_unique<OnlineTransducerNeMoModel>(config.model_config)) {
|
||||
if (!config.model_config.tokens_buf.empty()) {
|
||||
symbol_table_ = SymbolTable(config.model_config.tokens_buf, false);
|
||||
} else {
|
||||
/// assuming tokens_buf and tokens are guaranteed not being both empty
|
||||
symbol_table_ = SymbolTable(config.model_config.tokens, true);
|
||||
}
|
||||
|
||||
if (config.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
|
||||
model_.get(), config_.blank_penalty);
|
||||
|
||||
@@ -106,6 +106,11 @@ struct OnlineRecognizerConfig {
|
||||
// If there are multiple FST archives, they are applied from left to right.
|
||||
std::string rule_fars;
|
||||
|
||||
/// used only for modified_beam_search, if hotwords_buf is non-empty,
|
||||
/// the hotwords will be loaded from the buffered string instead of from
|
||||
/// ${hotwords_file}
|
||||
std::string hotwords_buf;
|
||||
|
||||
OnlineRecognizerConfig() = default;
|
||||
|
||||
OnlineRecognizerConfig(
|
||||
|
||||
@@ -20,9 +20,14 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
SymbolTable::SymbolTable(const std::string &filename) {
|
||||
std::ifstream is(filename);
|
||||
Init(is);
|
||||
SymbolTable::SymbolTable(const std::string &filename, bool is_file) {
|
||||
if (is_file) {
|
||||
std::ifstream is(filename);
|
||||
Init(is);
|
||||
} else {
|
||||
std::istringstream iss(filename);
|
||||
Init(iss);
|
||||
}
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
|
||||
@@ -19,13 +19,13 @@ namespace sherpa_onnx {
|
||||
class SymbolTable {
|
||||
public:
|
||||
SymbolTable() = default;
|
||||
/// Construct a symbol table from a file.
|
||||
/// Construct a symbol table from a file or from a buffered string.
|
||||
/// Each line in the file contains two fields:
|
||||
///
|
||||
/// sym ID
|
||||
///
|
||||
/// Fields are separated by space(s).
|
||||
explicit SymbolTable(const std::string &filename);
|
||||
explicit SymbolTable(const std::string &filename, bool is_file = true);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
SymbolTable(AAssetManager *mgr, const std::string &filename);
|
||||
|
||||
Reference in New Issue
Block a user