Support customize scores for hotwords (#926)
* Support customize scores for hotwords * Skip blank lines
This commit is contained in:
@@ -35,17 +35,21 @@ TEST(TEXT2TOKEN, TEST_cjkchar) {
|
||||
|
||||
auto sym_table = SymbolTable(tokens);
|
||||
|
||||
std::string text = "世界人民大团结\n中国 V S 美国";
|
||||
std::string text =
|
||||
"世界人民大团结\n中国 V S 美国\n\n"; // Test blank lines also
|
||||
|
||||
std::istringstream iss(text);
|
||||
|
||||
std::vector<std::vector<int32_t>> ids;
|
||||
std::vector<float> scores;
|
||||
|
||||
auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids);
|
||||
auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids, &scores);
|
||||
|
||||
std::vector<std::vector<int32_t>> expected_ids(
|
||||
{{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}});
|
||||
EXPECT_EQ(ids, expected_ids);
|
||||
|
||||
EXPECT_EQ(scores.size(), 0);
|
||||
}
|
||||
|
||||
TEST(TEXT2TOKEN, TEST_bpe) {
|
||||
@@ -68,17 +72,22 @@ TEST(TEXT2TOKEN, TEST_bpe) {
|
||||
auto sym_table = SymbolTable(tokens);
|
||||
auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
|
||||
|
||||
std::string text = "HELLO WORLD\nI LOVE YOU";
|
||||
std::string text = "HELLO WORLD\nI LOVE YOU :2.0";
|
||||
|
||||
std::istringstream iss(text);
|
||||
|
||||
std::vector<std::vector<int32_t>> ids;
|
||||
std::vector<float> scores;
|
||||
|
||||
auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids);
|
||||
auto r =
|
||||
EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores);
|
||||
|
||||
std::vector<std::vector<int32_t>> expected_ids(
|
||||
{{22, 58, 24, 425}, {19, 370, 47}});
|
||||
EXPECT_EQ(ids, expected_ids);
|
||||
|
||||
std::vector<float> expected_scores({0, 2.0});
|
||||
EXPECT_EQ(scores, expected_scores);
|
||||
}
|
||||
|
||||
TEST(TEXT2TOKEN, TEST_cjkchar_bpe) {
|
||||
@@ -101,19 +110,23 @@ TEST(TEXT2TOKEN, TEST_cjkchar_bpe) {
|
||||
auto sym_table = SymbolTable(tokens);
|
||||
auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
|
||||
|
||||
std::string text = "世界人民 GOES TOGETHER\n中国 GOES WITH 美国";
|
||||
std::string text = "世界人民 GOES TOGETHER :1.5\n中国 GOES WITH 美国 :0.5";
|
||||
|
||||
std::istringstream iss(text);
|
||||
|
||||
std::vector<std::vector<int32_t>> ids;
|
||||
std::vector<float> scores;
|
||||
|
||||
auto r =
|
||||
EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(), &ids);
|
||||
auto r = EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(),
|
||||
&ids, &scores);
|
||||
|
||||
std::vector<std::vector<int32_t>> expected_ids(
|
||||
{{1368, 1392, 557, 680, 275, 178, 475},
|
||||
{685, 736, 275, 178, 179, 921, 736}});
|
||||
EXPECT_EQ(ids, expected_ids);
|
||||
|
||||
std::vector<float> expected_scores({1.5, 0.5});
|
||||
EXPECT_EQ(scores, expected_scores);
|
||||
}
|
||||
|
||||
TEST(TEXT2TOKEN, TEST_bbpe) {
|
||||
@@ -136,17 +149,22 @@ TEST(TEXT2TOKEN, TEST_bbpe) {
|
||||
auto sym_table = SymbolTable(tokens);
|
||||
auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
|
||||
|
||||
std::string text = "频繁\n李鞑靼";
|
||||
std::string text = "频繁 :1.0\n李鞑靼";
|
||||
|
||||
std::istringstream iss(text);
|
||||
|
||||
std::vector<std::vector<int32_t>> ids;
|
||||
std::vector<float> scores;
|
||||
|
||||
auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids);
|
||||
auto r =
|
||||
EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores);
|
||||
|
||||
std::vector<std::vector<int32_t>> expected_ids(
|
||||
{{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}});
|
||||
EXPECT_EQ(ids, expected_ids);
|
||||
|
||||
std::vector<float> expected_scores({1.0, 0});
|
||||
EXPECT_EQ(scores, expected_scores);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
Reference in New Issue
Block a user