decoder for open vocabulary keyword spotting (#505)
* various fixes to ContextGraph to support open vocabulary keywords decoder * Add keyword spotter runtime * Add binary * First version works * Minor fixes * update text2token * default values * Add jni for kws * add kws android project * Minor fixes * Remove unused interface * Minor fixes * Add workflow * handle extra info in texts * Minor fixes * Add more comments * Fix ci * fix cpp style * Add input box in android demo so that users can specify their keywords * Fix cpp style * Fix comments * Minor fixes * Minor fixes * minor fixes * Minor fixes * Minor fixes * Add CI * Fix code style * cpplint * Fix comments * Fix error
This commit is contained in:
@@ -15,16 +15,31 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *hotwords) {
|
||||
hotwords->clear();
|
||||
std::vector<int32_t> tmp;
|
||||
static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *ids,
|
||||
std::vector<std::string> *phrases,
|
||||
std::vector<float> *scores,
|
||||
std::vector<float> *thresholds) {
|
||||
SHERPA_ONNX_CHECK(ids != nullptr);
|
||||
ids->clear();
|
||||
|
||||
std::vector<int32_t> tmp_ids;
|
||||
std::vector<float> tmp_scores;
|
||||
std::vector<float> tmp_thresholds;
|
||||
std::vector<std::string> tmp_phrases;
|
||||
|
||||
std::string line;
|
||||
std::string word;
|
||||
bool has_scores = false;
|
||||
bool has_thresholds = false;
|
||||
bool has_phrases = false;
|
||||
|
||||
while (std::getline(is, line)) {
|
||||
float score = 0;
|
||||
float threshold = 0;
|
||||
std::string phrase = "";
|
||||
|
||||
std::istringstream iss(line);
|
||||
std::vector<std::string> syms;
|
||||
while (iss >> word) {
|
||||
if (word.size() >= 3) {
|
||||
// For BPE-based models, we replace ▁ with a space
|
||||
@@ -35,20 +50,72 @@ bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
||||
}
|
||||
}
|
||||
if (symbol_table.contains(word)) {
|
||||
int32_t number = symbol_table[word];
|
||||
tmp.push_back(number);
|
||||
int32_t id = symbol_table[word];
|
||||
tmp_ids.push_back(id);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Cannot find ID for hotword %s at line: %s. (Hint: words on "
|
||||
"the "
|
||||
"same line are separated by spaces)",
|
||||
word.c_str(), line.c_str());
|
||||
return false;
|
||||
switch (word[0]) {
|
||||
case ':': // boosting score for current keyword
|
||||
score = std::stof(word.substr(1));
|
||||
has_scores = true;
|
||||
break;
|
||||
case '#': // triggering threshold (probability) for current keyword
|
||||
threshold = std::stof(word.substr(1));
|
||||
has_thresholds = true;
|
||||
break;
|
||||
case '@': // the original keyword string
|
||||
phrase = word.substr(1);
|
||||
has_phrases = true;
|
||||
break;
|
||||
default:
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Cannot find ID for token %s at line: %s. (Hint: words on "
|
||||
"the same line are separated by spaces)",
|
||||
word.c_str(), line.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
hotwords->push_back(std::move(tmp));
|
||||
ids->push_back(std::move(tmp_ids));
|
||||
tmp_scores.push_back(score);
|
||||
tmp_phrases.push_back(phrase);
|
||||
tmp_thresholds.push_back(threshold);
|
||||
}
|
||||
if (scores != nullptr) {
|
||||
if (has_scores) {
|
||||
scores->swap(tmp_scores);
|
||||
} else {
|
||||
scores->clear();
|
||||
}
|
||||
}
|
||||
if (phrases != nullptr) {
|
||||
if (has_phrases) {
|
||||
*phrases = std::move(tmp_phrases);
|
||||
} else {
|
||||
phrases->clear();
|
||||
}
|
||||
}
|
||||
if (thresholds != nullptr) {
|
||||
if (has_thresholds) {
|
||||
thresholds->swap(tmp_thresholds);
|
||||
} else {
|
||||
thresholds->clear();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *hotwords) {
|
||||
return EncodeBase(is, symbol_table, hotwords, nullptr, nullptr, nullptr);
|
||||
}
|
||||
|
||||
bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *keywords_id,
|
||||
std::vector<std::string> *keywords,
|
||||
std::vector<float> *boost_scores,
|
||||
std::vector<float> *threshold) {
|
||||
return EncodeBase(is, symbol_table, keywords_id, keywords, boost_scores,
|
||||
threshold);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
Reference in New Issue
Block a user