Allow more online models to load tokens file from the memory (#1352)
Co-authored-by: xiao <shawl336@6163.com>
This commit is contained in:
@@ -667,6 +667,12 @@ SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
|
||||
|
||||
spotter_config.model_config.tokens =
|
||||
SHERPA_ONNX_OR(config->model_config.tokens, "");
|
||||
if (config->model_config.tokens_buf &&
|
||||
config->model_config.tokens_buf_size > 0) {
|
||||
spotter_config.model_config.tokens_buf = std::string(
|
||||
config->model_config.tokens_buf, config->model_config.tokens_buf_size);
|
||||
}
|
||||
|
||||
spotter_config.model_config.num_threads =
|
||||
SHERPA_ONNX_OR(config->model_config.num_threads, 1);
|
||||
spotter_config.model_config.provider_config.provider =
|
||||
@@ -691,6 +697,10 @@ SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
|
||||
SHERPA_ONNX_OR(config->keywords_threshold, 0.25);
|
||||
|
||||
spotter_config.keywords_file = SHERPA_ONNX_OR(config->keywords_file, "");
|
||||
if (config->keywords_buf && config->keywords_buf_size > 0) {
|
||||
spotter_config.keywords_buf =
|
||||
std::string(config->keywords_buf, config->keywords_buf_size);
|
||||
}
|
||||
|
||||
if (config->model_config.debug) {
|
||||
SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str());
|
||||
|
||||
@@ -88,8 +88,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig {
|
||||
// - cjkchar+bpe
|
||||
const char *modeling_unit;
|
||||
const char *bpe_vocab;
|
||||
/// if non-null, loading the tokens from the buffered string directly in
|
||||
/// priority
|
||||
/// if non-null, loading the tokens from the buffer instead of from the
|
||||
/// "tokens" file
|
||||
const char *tokens_buf;
|
||||
/// byte size excluding the trailing '\0'
|
||||
int32_t tokens_buf_size;
|
||||
@@ -637,6 +637,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig {
|
||||
float keywords_score;
|
||||
float keywords_threshold;
|
||||
const char *keywords_file;
|
||||
/// if non-null, loading the keywords from the buffer instead of from the
|
||||
/// keywords_file
|
||||
const char *keywords_buf;
|
||||
/// byte size excluding the trailing '\0'
|
||||
int32_t keywords_buf_size;
|
||||
} SherpaOnnxKeywordSpotterConfig;
|
||||
|
||||
SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter
|
||||
|
||||
@@ -66,15 +66,25 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
|
||||
public:
|
||||
explicit KeywordSpotterTransducerImpl(const KeywordSpotterConfig &config)
|
||||
: config_(config),
|
||||
model_(OnlineTransducerModel::Create(config.model_config)),
|
||||
sym_(config.model_config.tokens) {
|
||||
model_(OnlineTransducerModel::Create(config.model_config)) {
|
||||
if (!config.model_config.tokens_buf.empty()) {
|
||||
sym_ = SymbolTable(config.model_config.tokens_buf, false);
|
||||
} else {
|
||||
/// assuming tokens_buf and tokens are guaranteed not to both be empty
|
||||
sym_ = SymbolTable(config.model_config.tokens, true);
|
||||
}
|
||||
|
||||
if (sym_.Contains("<unk>")) {
|
||||
unk_id_ = sym_["<unk>"];
|
||||
}
|
||||
|
||||
model_->SetFeatureDim(config.feat_config.feature_dim);
|
||||
|
||||
InitKeywords();
|
||||
if (config.keywords_buf.empty()) {
|
||||
InitKeywords();
|
||||
} else {
|
||||
InitKeywordsFromBufStr();
|
||||
}
|
||||
|
||||
decoder_ = std::make_unique<TransducerKeywordDecoder>(
|
||||
model_.get(), config_.max_active_paths, config_.num_trailing_blanks,
|
||||
@@ -305,6 +315,12 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
|
||||
}
|
||||
#endif
|
||||
|
||||
/// Initializes the keywords from the in-memory string config_.keywords_buf:
/// the buffer is wrapped in a std::istringstream and handed to InitKeywords.
void InitKeywordsFromBufStr() {
|
||||
// keywords_buf is expected to hold the same text as a keywords_file would
|
||||
std::istringstream is(config_.keywords_buf);
|
||||
InitKeywords(is);
|
||||
}
|
||||
|
||||
void InitOnlineStream(OnlineStream *stream) const {
|
||||
auto r = decoder_->GetEmptyResult();
|
||||
SHERPA_ONNX_CHECK_EQ(r.hyps.Size(), 1);
|
||||
|
||||
@@ -89,8 +89,17 @@ void KeywordSpotterConfig::Register(ParseOptions *po) {
|
||||
}
|
||||
|
||||
bool KeywordSpotterConfig::Validate() const {
|
||||
if (keywords_file.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --keywords-file.");
|
||||
if (!keywords_file.empty() && !keywords_buf.empty()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"you can not provide a keywords_buf and a keywords file: '%s', "
|
||||
"at the same time, which is confusing",
|
||||
keywords_file.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (keywords_file.empty() && keywords_buf.empty()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Please provide either a keywords-file or the keywords-buf");
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -99,7 +108,7 @@ bool KeywordSpotterConfig::Validate() const {
|
||||
// keywords file will be packaged into the sherpa-onnx-wasm-kws-main.data file
|
||||
// Solution: the keywords_file variable is directly
|
||||
// parsed as a string of keywords
|
||||
if (!std::ifstream(keywords_file.c_str()).good()) {
|
||||
if (keywords_buf.empty() && !std::ifstream(keywords_file.c_str()).good()) {
|
||||
SHERPA_ONNX_LOGE("Keywords file '%s' does not exist.",
|
||||
keywords_file.c_str());
|
||||
return false;
|
||||
|
||||
@@ -69,6 +69,11 @@ struct KeywordSpotterConfig {
|
||||
|
||||
std::string keywords_file;
|
||||
|
||||
/// if keywords_buf is non-empty,
|
||||
/// the keywords will be loaded from the buffer instead of from the
|
||||
/// "keywords_file"
|
||||
std::string keywords_buf;
|
||||
|
||||
KeywordSpotterConfig() = default;
|
||||
|
||||
KeywordSpotterConfig(const FeatureExtractorConfig &feat_config,
|
||||
|
||||
@@ -46,8 +46,8 @@ struct OnlineModelConfig {
|
||||
std::string bpe_vocab;
|
||||
|
||||
/// if tokens_buf is non-empty,
|
||||
/// the tokens will be loaded from the buffered string instead of from the
|
||||
/// ${tokens} file
|
||||
/// the tokens will be loaded from the buffer instead of from the
|
||||
/// "tokens" file
|
||||
std::string tokens_buf;
|
||||
|
||||
OnlineModelConfig() = default;
|
||||
|
||||
@@ -71,8 +71,14 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
: OnlineRecognizerImpl(config),
|
||||
config_(config),
|
||||
model_(OnlineCtcModel::Create(config.model_config)),
|
||||
sym_(config.model_config.tokens),
|
||||
endpoint_(config_.endpoint_config) {
|
||||
if (!config.model_config.tokens_buf.empty()) {
|
||||
sym_ = SymbolTable(config.model_config.tokens_buf, false);
|
||||
} else {
|
||||
/// assuming tokens_buf and tokens are guaranteed not to both be empty
|
||||
sym_ = SymbolTable(config.model_config.tokens, true);
|
||||
}
|
||||
|
||||
if (!config.model_config.wenet_ctc.model.empty()) {
|
||||
// WeNet CTC models assume input samples are in the range
|
||||
// [-32768, 32767], so we set normalize_samples to false
|
||||
|
||||
@@ -99,8 +99,14 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
|
||||
: OnlineRecognizerImpl(config),
|
||||
config_(config),
|
||||
model_(config.model_config),
|
||||
sym_(config.model_config.tokens),
|
||||
endpoint_(config_.endpoint_config) {
|
||||
if (!config.model_config.tokens_buf.empty()) {
|
||||
sym_ = SymbolTable(config.model_config.tokens_buf, false);
|
||||
} else {
|
||||
/// assuming tokens_buf and tokens are guaranteed not to both be empty
|
||||
sym_ = SymbolTable(config.model_config.tokens, true);
|
||||
}
|
||||
|
||||
if (config.decoding_method != "greedy_search") {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Unsupported decoding method: %s. Support only greedy_search at "
|
||||
|
||||
@@ -107,8 +107,8 @@ struct OnlineRecognizerConfig {
|
||||
std::string rule_fars;
|
||||
|
||||
/// used only for modified_beam_search, if hotwords_buf is non-empty,
|
||||
/// the hotwords will be loaded from the buffered string instead of from
|
||||
/// ${hotwords_file}
|
||||
/// the hotwords will be loaded from the buffered string instead of from the
|
||||
/// "hotwords_file"
|
||||
std::string hotwords_buf;
|
||||
|
||||
OnlineRecognizerConfig() = default;
|
||||
|
||||
Reference in New Issue
Block a user