Support customize scores for hotwords (#926)
* Support customize scores for hotwords * Skip blank lines
This commit is contained in:
@@ -61,10 +61,9 @@ class ContextGraph {
|
||||
}
|
||||
|
||||
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
|
||||
float context_score, const std::vector<float> &scores = {},
|
||||
const std::vector<std::string> &phrases = {})
|
||||
: ContextGraph(token_ids, context_score, 0.0f, scores, phrases,
|
||||
std::vector<float>()) {}
|
||||
float context_score, const std::vector<float> &scores = {})
|
||||
: ContextGraph(token_ids, context_score, 0.0f, scores,
|
||||
std::vector<std::string>(), std::vector<float>()) {}
|
||||
|
||||
std::tuple<float, const ContextState *, const ContextState *> ForwardOneStep(
|
||||
const ContextState *state, int32_t token_id,
|
||||
|
||||
@@ -145,15 +145,35 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
|
||||
std::istringstream is(hws);
|
||||
std::vector<std::vector<int32_t>> current;
|
||||
std::vector<float> current_scores;
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
|
||||
bpe_encoder_.get(), ¤t)) {
|
||||
bpe_encoder_.get(), ¤t, ¤t_scores)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
|
||||
hotwords.c_str());
|
||||
}
|
||||
|
||||
int32_t num_default_hws = hotwords_.size();
|
||||
int32_t num_hws = current.size();
|
||||
|
||||
current.insert(current.end(), hotwords_.begin(), hotwords_.end());
|
||||
|
||||
auto context_graph =
|
||||
std::make_shared<ContextGraph>(current, config_.hotwords_score);
|
||||
if (!current_scores.empty() && !boost_scores_.empty()) {
|
||||
current_scores.insert(current_scores.end(), boost_scores_.begin(),
|
||||
boost_scores_.end());
|
||||
} else if (!current_scores.empty() && boost_scores_.empty()) {
|
||||
current_scores.insert(current_scores.end(), num_default_hws,
|
||||
config_.hotwords_score);
|
||||
} else if (current_scores.empty() && !boost_scores_.empty()) {
|
||||
current_scores.insert(current_scores.end(), num_hws,
|
||||
config_.hotwords_score);
|
||||
current_scores.insert(current_scores.end(), boost_scores_.begin(),
|
||||
boost_scores_.end());
|
||||
} else {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
auto context_graph = std::make_shared<ContextGraph>(
|
||||
current, config_.hotwords_score, current_scores);
|
||||
return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
|
||||
}
|
||||
|
||||
@@ -226,13 +246,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
}
|
||||
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
|
||||
bpe_encoder_.get(), &hotwords_)) {
|
||||
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Failed to encode some hotwords, skip them already, see logs above "
|
||||
"for details.");
|
||||
}
|
||||
hotwords_graph_ =
|
||||
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
||||
hotwords_graph_ = std::make_shared<ContextGraph>(
|
||||
hotwords_, config_.hotwords_score, boost_scores_);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
@@ -250,13 +270,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
}
|
||||
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
|
||||
bpe_encoder_.get(), &hotwords_)) {
|
||||
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Failed to encode some hotwords, skip them already, see logs above "
|
||||
"for details.");
|
||||
}
|
||||
hotwords_graph_ =
|
||||
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
||||
hotwords_graph_ = std::make_shared<ContextGraph>(
|
||||
hotwords_, config_.hotwords_score, boost_scores_);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -264,6 +284,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
OfflineRecognizerConfig config_;
|
||||
SymbolTable symbol_table_;
|
||||
std::vector<std::vector<int32_t>> hotwords_;
|
||||
std::vector<float> boost_scores_;
|
||||
ContextGraphPtr hotwords_graph_;
|
||||
std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
|
||||
std::unique_ptr<OfflineTransducerModel> model_;
|
||||
|
||||
@@ -182,14 +182,35 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
|
||||
std::istringstream is(hws);
|
||||
std::vector<std::vector<int32_t>> current;
|
||||
std::vector<float> current_scores;
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
|
||||
bpe_encoder_.get(), ¤t)) {
|
||||
bpe_encoder_.get(), ¤t, ¤t_scores)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
|
||||
hotwords.c_str());
|
||||
}
|
||||
|
||||
int32_t num_default_hws = hotwords_.size();
|
||||
int32_t num_hws = current.size();
|
||||
|
||||
current.insert(current.end(), hotwords_.begin(), hotwords_.end());
|
||||
auto context_graph =
|
||||
std::make_shared<ContextGraph>(current, config_.hotwords_score);
|
||||
|
||||
if (!current_scores.empty() && !boost_scores_.empty()) {
|
||||
current_scores.insert(current_scores.end(), boost_scores_.begin(),
|
||||
boost_scores_.end());
|
||||
} else if (!current_scores.empty() && boost_scores_.empty()) {
|
||||
current_scores.insert(current_scores.end(), num_default_hws,
|
||||
config_.hotwords_score);
|
||||
} else if (current_scores.empty() && !boost_scores_.empty()) {
|
||||
current_scores.insert(current_scores.end(), num_hws,
|
||||
config_.hotwords_score);
|
||||
current_scores.insert(current_scores.end(), boost_scores_.begin(),
|
||||
boost_scores_.end());
|
||||
} else {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
auto context_graph = std::make_shared<ContextGraph>(
|
||||
current, config_.hotwords_score, current_scores);
|
||||
auto stream =
|
||||
std::make_unique<OnlineStream>(config_.feat_config, context_graph);
|
||||
InitOnlineStream(stream.get());
|
||||
@@ -376,13 +397,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
}
|
||||
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
|
||||
bpe_encoder_.get(), &hotwords_)) {
|
||||
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Failed to encode some hotwords, skip them already, see logs above "
|
||||
"for details.");
|
||||
}
|
||||
hotwords_graph_ =
|
||||
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
||||
hotwords_graph_ = std::make_shared<ContextGraph>(
|
||||
hotwords_, config_.hotwords_score, boost_scores_);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
@@ -400,13 +421,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
}
|
||||
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
|
||||
bpe_encoder_.get(), &hotwords_)) {
|
||||
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Failed to encode some hotwords, skip them already, see logs above "
|
||||
"for details.");
|
||||
}
|
||||
hotwords_graph_ =
|
||||
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
||||
hotwords_graph_ = std::make_shared<ContextGraph>(
|
||||
hotwords_, config_.hotwords_score, boost_scores_);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -428,6 +449,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
private:
|
||||
OnlineRecognizerConfig config_;
|
||||
std::vector<std::vector<int32_t>> hotwords_;
|
||||
std::vector<float> boost_scores_;
|
||||
ContextGraphPtr hotwords_graph_;
|
||||
std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
|
||||
std::unique_ptr<OnlineTransducerModel> model_;
|
||||
|
||||
@@ -35,17 +35,21 @@ TEST(TEXT2TOKEN, TEST_cjkchar) {
|
||||
|
||||
auto sym_table = SymbolTable(tokens);
|
||||
|
||||
std::string text = "世界人民大团结\n中国 V S 美国";
|
||||
std::string text =
|
||||
"世界人民大团结\n中国 V S 美国\n\n"; // Test blank lines also
|
||||
|
||||
std::istringstream iss(text);
|
||||
|
||||
std::vector<std::vector<int32_t>> ids;
|
||||
std::vector<float> scores;
|
||||
|
||||
auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids);
|
||||
auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids, &scores);
|
||||
|
||||
std::vector<std::vector<int32_t>> expected_ids(
|
||||
{{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}});
|
||||
EXPECT_EQ(ids, expected_ids);
|
||||
|
||||
EXPECT_EQ(scores.size(), 0);
|
||||
}
|
||||
|
||||
TEST(TEXT2TOKEN, TEST_bpe) {
|
||||
@@ -68,17 +72,22 @@ TEST(TEXT2TOKEN, TEST_bpe) {
|
||||
auto sym_table = SymbolTable(tokens);
|
||||
auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
|
||||
|
||||
std::string text = "HELLO WORLD\nI LOVE YOU";
|
||||
std::string text = "HELLO WORLD\nI LOVE YOU :2.0";
|
||||
|
||||
std::istringstream iss(text);
|
||||
|
||||
std::vector<std::vector<int32_t>> ids;
|
||||
std::vector<float> scores;
|
||||
|
||||
auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids);
|
||||
auto r =
|
||||
EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores);
|
||||
|
||||
std::vector<std::vector<int32_t>> expected_ids(
|
||||
{{22, 58, 24, 425}, {19, 370, 47}});
|
||||
EXPECT_EQ(ids, expected_ids);
|
||||
|
||||
std::vector<float> expected_scores({0, 2.0});
|
||||
EXPECT_EQ(scores, expected_scores);
|
||||
}
|
||||
|
||||
TEST(TEXT2TOKEN, TEST_cjkchar_bpe) {
|
||||
@@ -101,19 +110,23 @@ TEST(TEXT2TOKEN, TEST_cjkchar_bpe) {
|
||||
auto sym_table = SymbolTable(tokens);
|
||||
auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
|
||||
|
||||
std::string text = "世界人民 GOES TOGETHER\n中国 GOES WITH 美国";
|
||||
std::string text = "世界人民 GOES TOGETHER :1.5\n中国 GOES WITH 美国 :0.5";
|
||||
|
||||
std::istringstream iss(text);
|
||||
|
||||
std::vector<std::vector<int32_t>> ids;
|
||||
std::vector<float> scores;
|
||||
|
||||
auto r =
|
||||
EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(), &ids);
|
||||
auto r = EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(),
|
||||
&ids, &scores);
|
||||
|
||||
std::vector<std::vector<int32_t>> expected_ids(
|
||||
{{1368, 1392, 557, 680, 275, 178, 475},
|
||||
{685, 736, 275, 178, 179, 921, 736}});
|
||||
EXPECT_EQ(ids, expected_ids);
|
||||
|
||||
std::vector<float> expected_scores({1.5, 0.5});
|
||||
EXPECT_EQ(scores, expected_scores);
|
||||
}
|
||||
|
||||
TEST(TEXT2TOKEN, TEST_bbpe) {
|
||||
@@ -136,17 +149,22 @@ TEST(TEXT2TOKEN, TEST_bbpe) {
|
||||
auto sym_table = SymbolTable(tokens);
|
||||
auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
|
||||
|
||||
std::string text = "频繁\n李鞑靼";
|
||||
std::string text = "频繁 :1.0\n李鞑靼";
|
||||
|
||||
std::istringstream iss(text);
|
||||
|
||||
std::vector<std::vector<int32_t>> ids;
|
||||
std::vector<float> scores;
|
||||
|
||||
auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids);
|
||||
auto r =
|
||||
EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores);
|
||||
|
||||
std::vector<std::vector<int32_t>> expected_ids(
|
||||
{{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}});
|
||||
EXPECT_EQ(ids, expected_ids);
|
||||
|
||||
std::vector<float> expected_scores({1.0, 0});
|
||||
EXPECT_EQ(scores, expected_scores);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -103,7 +103,8 @@ static bool EncodeBase(const std::vector<std::string> &lines,
|
||||
bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
|
||||
const SymbolTable &symbol_table,
|
||||
const ssentencepiece::Ssentencepiece *bpe_encoder,
|
||||
std::vector<std::vector<int32_t>> *hotwords) {
|
||||
std::vector<std::vector<int32_t>> *hotwords,
|
||||
std::vector<float> *boost_scores) {
|
||||
std::vector<std::string> lines;
|
||||
std::string line;
|
||||
std::string word;
|
||||
@@ -131,7 +132,12 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
|
||||
break;
|
||||
}
|
||||
}
|
||||
phrase = oss.str().substr(1);
|
||||
phrase = oss.str();
|
||||
if (phrase.empty()) {
|
||||
continue;
|
||||
} else {
|
||||
phrase = phrase.substr(1);
|
||||
}
|
||||
std::istringstream piss(phrase);
|
||||
oss.clear();
|
||||
oss.str("");
|
||||
@@ -177,7 +183,8 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
|
||||
}
|
||||
lines.push_back(oss.str());
|
||||
}
|
||||
return EncodeBase(lines, symbol_table, hotwords, nullptr, nullptr, nullptr);
|
||||
return EncodeBase(lines, symbol_table, hotwords, nullptr, boost_scores,
|
||||
nullptr);
|
||||
}
|
||||
|
||||
bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
|
||||
|
||||
@@ -29,7 +29,8 @@ namespace sherpa_onnx {
|
||||
bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
|
||||
const SymbolTable &symbol_table,
|
||||
const ssentencepiece::Ssentencepiece *bpe_encoder_,
|
||||
std::vector<std::vector<int32_t>> *hotwords_id);
|
||||
std::vector<std::vector<int32_t>> *hotwords_id,
|
||||
std::vector<float> *boost_scores);
|
||||
|
||||
/* Encode the keywords in an input stream to be tokens ids.
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user