Encode hotwords in C++ side (#828)

* Encode hotwords in C++ side
This commit is contained in:
Wei Kang
2024-05-20 19:41:36 +08:00
committed by GitHub
parent 8af2af8466
commit b012b78ceb
43 changed files with 714 additions and 102 deletions

View File

@@ -165,6 +165,7 @@ endif()
target_link_libraries(sherpa-onnx-core
kaldi-native-fbank-core
kaldi-decoder-core
ssentencepiece_core
)
if(SHERPA_ONNX_ENABLE_GPU)
@@ -491,6 +492,7 @@ if(SHERPA_ONNX_ENABLE_TESTS)
pad-sequence-test.cc
slice-test.cc
stack-test.cc
text2token-test.cc
transpose-test.cc
unbind-test.cc
utfcpp-test.cc

View File

@@ -35,6 +35,17 @@ void OfflineModelConfig::Register(ParseOptions *po) {
"Valid values are: transducer, paraformer, nemo_ctc, whisper, "
"tdnn, zipformer2_ctc"
"All other values lead to loading the model twice.");
po->Register("modeling-unit", &modeling_unit,
"The modeling unit of the model, commonly used units are bpe, "
"cjkchar, cjkchar+bpe, etc. Currently, it is needed only when "
"hotwords are provided, we need it to encode the hotwords into "
"token sequence.");
po->Register("bpe-vocab", &bpe_vocab,
"The vocabulary generated by google's sentencepiece program. "
"It is a file has two columns, one is the token, the other is "
"the log probability, you can get it from the directory where "
"your bpe model is generated. Only used when hotwords provided "
"and the modeling unit is bpe or cjkchar+bpe");
}
bool OfflineModelConfig::Validate() const {
@@ -48,6 +59,14 @@ bool OfflineModelConfig::Validate() const {
return false;
}
if (!modeling_unit.empty() &&
(modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) {
if (!FileExists(bpe_vocab)) {
SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str());
return false;
}
}
if (!paraformer.model.empty()) {
return paraformer.Validate();
}
@@ -90,7 +109,9 @@ std::string OfflineModelConfig::ToString() const {
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
os << "provider=\"" << provider << "\", ";
os << "model_type=\"" << model_type << "\")";
os << "model_type=\"" << model_type << "\", ";
os << "modeling_unit=\"" << modeling_unit << "\", ";
os << "bpe_vocab=\"" << bpe_vocab << "\")";
return os.str();
}

View File

@@ -41,6 +41,9 @@ struct OfflineModelConfig {
// All other values are invalid and lead to loading the model twice.
std::string model_type;
std::string modeling_unit = "cjkchar";
std::string bpe_vocab;
OfflineModelConfig() = default;
OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
const OfflineParaformerModelConfig &paraformer,
@@ -50,7 +53,9 @@ struct OfflineModelConfig {
const OfflineZipformerCtcModelConfig &zipformer_ctc,
const OfflineWenetCtcModelConfig &wenet_ctc,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type)
const std::string &provider, const std::string &model_type,
const std::string &modeling_unit,
const std::string &bpe_vocab)
: transducer(transducer),
paraformer(paraformer),
nemo_ctc(nemo_ctc),
@@ -62,7 +67,9 @@ struct OfflineModelConfig {
num_threads(num_threads),
debug(debug),
provider(provider),
model_type(model_type) {}
model_type(model_type),
modeling_unit(modeling_unit),
bpe_vocab(bpe_vocab) {}
void Register(ParseOptions *po);
bool Validate() const;

View File

@@ -31,6 +31,7 @@
#include "sherpa-onnx/csrc/pad-sequence.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/utils.h"
#include "ssentencepiece/csrc/ssentencepiece.h"
namespace sherpa_onnx {
@@ -76,9 +77,6 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
: config_(config),
symbol_table_(config_.model_config.tokens),
model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
if (!config_.hotwords_file.empty()) {
InitHotwords();
}
if (config_.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
model_.get(), config_.blank_penalty);
@@ -87,6 +85,15 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
lm_ = OfflineLM::Create(config.lm_config);
}
if (!config_.model_config.bpe_vocab.empty()) {
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
config_.model_config.bpe_vocab);
}
if (!config_.hotwords_file.empty()) {
InitHotwords();
}
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, config_.blank_penalty);
@@ -112,6 +119,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
lm_ = OfflineLM::Create(mgr, config.lm_config);
}
if (!config_.model_config.bpe_vocab.empty()) {
auto buf = ReadFile(mgr, config_.model_config.bpe_vocab);
std::istringstream iss(std::string(buf.begin(), buf.end()));
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss);
}
if (!config_.hotwords_file.empty()) {
InitHotwords(mgr);
}
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, config_.blank_penalty);
@@ -128,7 +145,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
std::istringstream is(hws);
std::vector<std::vector<int32_t>> current;
if (!EncodeHotwords(is, symbol_table_, &current)) {
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
bpe_encoder_.get(), &current)) {
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
hotwords.c_str());
}
@@ -207,19 +225,47 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
exit(-1);
}
if (!EncodeHotwords(is, symbol_table_, &hotwords_)) {
SHERPA_ONNX_LOGE("Encode hotwords failed.");
exit(-1);
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
bpe_encoder_.get(), &hotwords_)) {
SHERPA_ONNX_LOGE(
"Failed to encode some hotwords, skip them already, see logs above "
"for details.");
}
hotwords_graph_ =
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
}
#if __ANDROID_API__ >= 9
// Load the hotwords file via the Android asset manager, encode each
// phrase into token IDs and build the context graph used for boosting.
void InitHotwords(AAssetManager *mgr) {
  // Each line in hotwords_file contains space-separated words.
  auto content = ReadFile(mgr, config_.hotwords_file);
  std::istringstream stream(std::string(content.begin(), content.end()));
  if (!stream) {
    SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
                     config_.hotwords_file.c_str());
    exit(-1);
  }
  // Phrases that fail to encode are dropped; the remaining ones are kept.
  if (!EncodeHotwords(stream, config_.model_config.modeling_unit,
                      symbol_table_, bpe_encoder_.get(), &hotwords_)) {
    SHERPA_ONNX_LOGE(
        "Failed to encode some hotwords, skip them already, see logs above "
        "for details.");
  }
  hotwords_graph_ =
      std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
}
#endif
private:
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;
std::vector<std::vector<int32_t>> hotwords_;
ContextGraphPtr hotwords_graph_;
std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
std::unique_ptr<OfflineTransducerModel> model_;
std::unique_ptr<OfflineTransducerDecoder> decoder_;
std::unique_ptr<OfflineLM> lm_;

View File

@@ -37,10 +37,9 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
po->Register(
"hotwords-file", &hotwords_file,
"The file containing hotwords, one words/phrases per line, and for each"
"phrase the bpe/cjkchar are separated by a space. For example: "
"▁HE LL O ▁WORLD"
"你 好 世 界");
"The file containing hotwords, one words/phrases per line, For example: "
"HELLO WORLD"
"你好世界");
po->Register("hotwords-score", &hotwords_score,
"The bonus score for each token in context word/phrase. "

View File

@@ -32,6 +32,19 @@ void OnlineModelConfig::Register(ParseOptions *po) {
po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml");
po->Register("modeling-unit", &modeling_unit,
"The modeling unit of the model, commonly used units are bpe, "
"cjkchar, cjkchar+bpe, etc. Currently, it is needed only when "
"hotwords are provided, we need it to encode the hotwords into "
"token sequence.");
po->Register("bpe-vocab", &bpe_vocab,
"The vocabulary generated by google's sentencepiece program. "
"It is a file has two columns, one is the token, the other is "
"the log probability, you can get it from the directory where "
"your bpe model is generated. Only used when hotwords provided "
"and the modeling unit is bpe or cjkchar+bpe");
po->Register("model-type", &model_type,
"Specify it to reduce model initialization time. "
"Valid values are: conformer, lstm, zipformer, zipformer2, "
@@ -50,6 +63,14 @@ bool OnlineModelConfig::Validate() const {
return false;
}
if (!modeling_unit.empty() &&
(modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) {
if (!FileExists(bpe_vocab)) {
SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str());
return false;
}
}
if (!paraformer.encoder.empty()) {
return paraformer.Validate();
}
@@ -83,7 +104,9 @@ std::string OnlineModelConfig::ToString() const {
os << "warm_up=" << warm_up << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
os << "provider=\"" << provider << "\", ";
os << "model_type=\"" << model_type << "\")";
os << "model_type=\"" << model_type << "\", ";
os << "modeling_unit=\"" << modeling_unit << "\", ";
os << "bpe_vocab=\"" << bpe_vocab << "\")";
return os.str();
}

View File

@@ -37,6 +37,13 @@ struct OnlineModelConfig {
// All other values are invalid and lead to loading the model twice.
std::string model_type;
// Valid values:
// - cjkchar
// - bpe
// - cjkchar+bpe
std::string modeling_unit = "cjkchar";
std::string bpe_vocab;
OnlineModelConfig() = default;
OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
const OnlineParaformerModelConfig &paraformer,
@@ -45,7 +52,9 @@ struct OnlineModelConfig {
const OnlineNeMoCtcModelConfig &nemo_ctc,
const std::string &tokens, int32_t num_threads,
int32_t warm_up, bool debug, const std::string &provider,
const std::string &model_type)
const std::string &model_type,
const std::string &modeling_unit,
const std::string &bpe_vocab)
: transducer(transducer),
paraformer(paraformer),
wenet_ctc(wenet_ctc),
@@ -56,7 +65,9 @@ struct OnlineModelConfig {
warm_up(warm_up),
debug(debug),
provider(provider),
model_type(model_type) {}
model_type(model_type),
modeling_unit(modeling_unit),
bpe_vocab(bpe_vocab) {}
void Register(ParseOptions *po);
bool Validate() const;

View File

@@ -15,8 +15,6 @@
#include <vector>
#if __ANDROID_API__ >= 9
#include <strstream>
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
@@ -33,6 +31,7 @@
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/utils.h"
#include "ssentencepiece/csrc/ssentencepiece.h"
namespace sherpa_onnx {
@@ -94,6 +93,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_->SetFeatureDim(config.feat_config.feature_dim);
if (config.decoding_method == "modified_beam_search") {
if (!config_.model_config.bpe_vocab.empty()) {
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
config_.model_config.bpe_vocab);
}
if (!config_.hotwords_file.empty()) {
InitHotwords();
}
@@ -140,6 +144,12 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
#endif
if (!config_.model_config.bpe_vocab.empty()) {
auto buf = ReadFile(mgr, config_.model_config.bpe_vocab);
std::istringstream iss(std::string(buf.begin(), buf.end()));
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss);
}
if (!config_.hotwords_file.empty()) {
InitHotwords(mgr);
}
@@ -174,7 +184,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
std::istringstream is(hws);
std::vector<std::vector<int32_t>> current;
if (!EncodeHotwords(is, sym_, &current)) {
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), &current)) {
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
hotwords.c_str());
}
@@ -363,9 +374,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
exit(-1);
}
if (!EncodeHotwords(is, sym_, &hotwords_)) {
SHERPA_ONNX_LOGE("Encode hotwords failed.");
exit(-1);
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), &hotwords_)) {
SHERPA_ONNX_LOGE(
"Failed to encode some hotwords, skip them already, see logs above "
"for details.");
}
hotwords_graph_ =
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
@@ -377,7 +390,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
auto buf = ReadFile(mgr, config_.hotwords_file);
std::istrstream is(buf.data(), buf.size());
std::istringstream is(std::string(buf.begin(), buf.end()));
if (!is) {
SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
@@ -385,9 +398,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
exit(-1);
}
if (!EncodeHotwords(is, sym_, &hotwords_)) {
SHERPA_ONNX_LOGE("Encode hotwords failed.");
exit(-1);
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), &hotwords_)) {
SHERPA_ONNX_LOGE(
"Failed to encode some hotwords, skip them already, see logs above "
"for details.");
}
hotwords_graph_ =
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
@@ -413,6 +428,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
OnlineRecognizerConfig config_;
std::vector<std::vector<int32_t>> hotwords_;
ContextGraphPtr hotwords_graph_;
std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
std::unique_ptr<OnlineTransducerModel> model_;
std::unique_ptr<OnlineLM> lm_;
std::unique_ptr<OnlineTransducerDecoder> decoder_;

View File

@@ -51,9 +51,7 @@ std::string VecToString<std::string>(const std::vector<std::string> &vec,
std::string OnlineRecognizerResult::AsJsonString() const {
std::ostringstream os;
os << "{ ";
os << "\"text\": "
<< "\"" << text << "\""
<< ", ";
os << "\"text\": " << "\"" << text << "\"" << ", ";
os << "\"tokens\": " << VecToString(tokens) << ", ";
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
@@ -89,10 +87,9 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
"Used only when decoding_method is modified_beam_search");
po->Register(
"hotwords-file", &hotwords_file,
"The file containing hotwords, one words/phrases per line, and for each"
"phrase the bpe/cjkchar are separated by a space. For example: "
"▁HE LL O ▁WORLD"
"你 好 世 界");
"The file containing hotwords, one words/phrases per line, For example: "
"HELLO WORLD"
"你好世界");
po->Register("decoding-method", &decoding_method,
"decoding method,"
"now support greedy_search and modified_beam_search.");

View File

@@ -38,35 +38,6 @@ void SymbolTable::Init(std::istream &is) {
std::string sym;
int32_t id;
while (is >> sym >> id) {
if (sym.size() >= 3) {
// For BPE-based models, we replace ▁ with a space
// Unicode 9601, hex 0x2581, utf8 0xe29681
const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
sym = sym.replace(0, 3, " ");
}
}
// for byte-level BPE
// id 0 is blank, id 1 is sos/eos, id 2 is unk
if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' &&
sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') {
std::ostringstream os;
os << std::hex << std::uppercase << (id - 3);
if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) {
uint8_t i = id - 3;
sym = std::string(&i, &i + 1);
}
}
assert(!sym.empty());
// for byte bpe, after replacing ▁ with a space, whose ascii is also 0x20,
// there is a conflict between the real byte 0x20 and ▁, so we disable
// the following check.
//
// Note: Only id2sym_ matters as we use it to convert ID to symbols.
#if 0
// we disable the test here since for some multi-lingual BPE models
// from NeMo, the same symbol can appear multiple times with different IDs.
@@ -92,8 +63,30 @@ std::string SymbolTable::ToString() const {
return os.str();
}
const std::string &SymbolTable::operator[](int32_t id) const {
return id2sym_.at(id);
const std::string SymbolTable::operator[](int32_t id) const {
std::string sym = id2sym_.at(id);
if (sym.size() >= 3) {
// For BPE-based models, we replace ▁ with a space
// Unicode 9601, hex 0x2581, utf8 0xe29681
const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
sym = sym.replace(0, 3, " ");
}
}
// for byte-level BPE
// id 0 is blank, id 1 is sos/eos, id 2 is unk
if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' &&
sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') {
std::ostringstream os;
os << std::hex << std::uppercase << (id - 3);
if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) {
uint8_t i = id - 3;
sym = std::string(&i, &i + 1);
}
}
return sym;
}
int32_t SymbolTable::operator[](const std::string &sym) const {

View File

@@ -35,7 +35,7 @@ class SymbolTable {
std::string ToString() const;
/// Return the symbol corresponding to the given ID.
const std::string &operator[](int32_t id) const;
const std::string operator[](int32_t id) const;
/// Return the ID corresponding to the given symbol.
int32_t operator[](const std::string &sym) const;

View File

@@ -0,0 +1,152 @@
// sherpa-onnx/csrc/text2token-test.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include <fstream>
#include <sstream>
#include <string>
#include "gtest/gtest.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/utils.h"
#include "ssentencepiece/csrc/ssentencepiece.h"
namespace sherpa_onnx {
// Please refer to
// https://github.com/pkufool/sherpa-test-data
// to download test data for testing
static const char dir[] = "/tmp/sherpa-test-data";
// Verify that cjkchar hotword encoding splits CJK text into per-character
// token IDs (no BPE encoder required).
TEST(TEXT2TOKEN, TEST_cjkchar) {
  std::ostringstream oss;
  oss << dir << "/text2token/tokens_cn.txt";
  std::string tokens = oss.str();

  if (!std::ifstream(tokens).good()) {
    SHERPA_ONNX_LOGE(
        "No test data found, skipping TEST_cjkchar()."
        "You can download the test data by: "
        "git clone https://github.com/pkufool/sherpa-test-data.git "
        "/tmp/sherpa-test-data");
    return;
  }

  auto sym_table = SymbolTable(tokens);

  std::string text = "世界人民大团结\n中国 V S 美国";

  std::istringstream iss(text);

  std::vector<std::vector<int32_t>> ids;
  // cjkchar does not need a BPE encoder, hence nullptr.
  bool r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids);
  // Previously the result was computed into `r` and silently ignored;
  // assert it so an encoding failure is reported directly.
  EXPECT_TRUE(r);

  std::vector<std::vector<int32_t>> expected_ids(
      {{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}});
  EXPECT_EQ(ids, expected_ids);
}
// Verify that bpe hotword encoding tokenizes English words with the
// sentencepiece vocabulary.
TEST(TEXT2TOKEN, TEST_bpe) {
  std::ostringstream oss;
  oss << dir << "/text2token/tokens_en.txt";
  std::string tokens = oss.str();
  oss.clear();
  oss.str("");
  oss << dir << "/text2token/bpe_en.vocab";
  std::string bpe = oss.str();

  if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) {
    SHERPA_ONNX_LOGE(
        "No test data found, skipping TEST_bpe()."
        "You can download the test data by: "
        "git clone https://github.com/pkufool/sherpa-test-data.git "
        "/tmp/sherpa-test-data");
    return;
  }

  auto sym_table = SymbolTable(tokens);
  auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);

  std::string text = "HELLO WORLD\nI LOVE YOU";

  std::istringstream iss(text);

  std::vector<std::vector<int32_t>> ids;
  bool r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids);
  // Previously the result was computed into `r` and silently ignored;
  // assert it so an encoding failure is reported directly.
  EXPECT_TRUE(r);

  std::vector<std::vector<int32_t>> expected_ids(
      {{22, 58, 24, 425}, {19, 370, 47}});
  EXPECT_EQ(ids, expected_ids);
}
// Verify the mixed cjkchar+bpe unit: CJK characters are split per
// character while Latin words go through the BPE encoder.
TEST(TEXT2TOKEN, TEST_cjkchar_bpe) {
  std::ostringstream oss;
  oss << dir << "/text2token/tokens_mix.txt";
  std::string tokens = oss.str();
  oss.clear();
  oss.str("");
  oss << dir << "/text2token/bpe_mix.vocab";
  std::string bpe = oss.str();

  if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) {
    SHERPA_ONNX_LOGE(
        "No test data found, skipping TEST_cjkchar_bpe()."
        "You can download the test data by: "
        "git clone https://github.com/pkufool/sherpa-test-data.git "
        "/tmp/sherpa-test-data");
    return;
  }

  auto sym_table = SymbolTable(tokens);
  auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);

  std::string text = "世界人民 GOES TOGETHER\n中国 GOES WITH 美国";

  std::istringstream iss(text);

  std::vector<std::vector<int32_t>> ids;
  bool r =
      EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(), &ids);
  // Previously the result was computed into `r` and silently ignored;
  // assert it so an encoding failure is reported directly.
  EXPECT_TRUE(r);

  std::vector<std::vector<int32_t>> expected_ids(
      {{1368, 1392, 557, 680, 275, 178, 475},
       {685, 736, 275, 178, 179, 921, 736}});
  EXPECT_EQ(ids, expected_ids);
}
// Verify byte-level BPE: CJK input is encoded into byte-fallback tokens.
TEST(TEXT2TOKEN, TEST_bbpe) {
  std::ostringstream oss;
  oss << dir << "/text2token/tokens_bbpe.txt";
  std::string tokens = oss.str();
  oss.clear();
  oss.str("");
  oss << dir << "/text2token/bbpe.vocab";
  std::string bpe = oss.str();

  if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) {
    SHERPA_ONNX_LOGE(
        "No test data found, skipping TEST_bbpe()."
        "You can download the test data by: "
        "git clone https://github.com/pkufool/sherpa-test-data.git "
        "/tmp/sherpa-test-data");
    return;
  }

  auto sym_table = SymbolTable(tokens);
  auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);

  std::string text = "频繁\n李鞑靼";

  std::istringstream iss(text);

  std::vector<std::vector<int32_t>> ids;
  bool r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids);
  // Previously the result was computed into `r` and silently ignored;
  // assert it so an encoding failure is reported directly.
  EXPECT_TRUE(r);

  std::vector<std::vector<int32_t>> expected_ids(
      {{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}});
  EXPECT_EQ(ids, expected_ids);
}
} // namespace sherpa_onnx

View File

@@ -4,6 +4,7 @@
#include "sherpa-onnx/csrc/utils.h"
#include <cassert>
#include <iostream>
#include <sstream>
#include <string>
@@ -12,15 +13,16 @@
#include "sherpa-onnx/csrc/log.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
static bool EncodeBase(const std::vector<std::string> &lines,
const SymbolTable &symbol_table,
std::vector<std::vector<int32_t>> *ids,
std::vector<std::string> *phrases,
std::vector<float> *scores,
std::vector<float> *thresholds) {
SHERPA_ONNX_CHECK(ids != nullptr);
ids->clear();
std::vector<int32_t> tmp_ids;
@@ -33,22 +35,15 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
bool has_scores = false;
bool has_thresholds = false;
bool has_phrases = false;
bool has_oov = false;
while (std::getline(is, line)) {
for (const auto &line : lines) {
float score = 0;
float threshold = 0;
std::string phrase = "";
std::istringstream iss(line);
while (iss >> word) {
if (word.size() >= 3) {
// For BPE-based models, we replace ▁ with a space
// Unicode 9601, hex 0x2581, utf8 0xe29681
const uint8_t *p = reinterpret_cast<const uint8_t *>(word.c_str());
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
word = word.replace(0, 3, " ");
}
}
if (symbol_table.Contains(word)) {
int32_t id = symbol_table[word];
tmp_ids.push_back(id);
@@ -71,7 +66,8 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
"Cannot find ID for token %s at line: %s. (Hint: words on "
"the same line are separated by spaces)",
word.c_str(), line.c_str());
return false;
has_oov = true;
break;
}
}
}
@@ -101,12 +97,87 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
thresholds->clear();
}
}
return true;
return !has_oov;
}
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
const SymbolTable &symbol_table,
const ssentencepiece::Ssentencepiece *bpe_encoder,
std::vector<std::vector<int32_t>> *hotwords) {
return EncodeBase(is, symbol_table, hotwords, nullptr, nullptr, nullptr);
std::vector<std::string> lines;
std::string line;
std::string word;
while (std::getline(is, line)) {
std::string score;
std::string phrase;
std::ostringstream oss;
std::istringstream iss(line);
while (iss >> word) {
switch (word[0]) {
case ':': // boosting score for current keyword
score = word;
break;
default:
if (!score.empty()) {
SHERPA_ONNX_LOGE(
"Boosting score should be put after the words/phrase, given "
"%s.",
line.c_str());
return false;
}
oss << " " << word;
break;
}
}
phrase = oss.str().substr(1);
std::istringstream piss(phrase);
oss.clear();
oss.str("");
while (piss >> word) {
if (modeling_unit == "cjkchar") {
for (const auto &w : SplitUtf8(word)) {
oss << " " << w;
}
} else if (modeling_unit == "bpe") {
std::vector<std::string> bpes;
bpe_encoder->Encode(word, &bpes);
for (const auto &bpe : bpes) {
oss << " " << bpe;
}
} else {
if (modeling_unit != "cjkchar+bpe") {
SHERPA_ONNX_LOGE(
"modeling_unit should be one of bpe, cjkchar or cjkchar+bpe, "
"given "
"%s",
modeling_unit.c_str());
exit(-1);
}
for (const auto &w : SplitUtf8(word)) {
if (isalpha(w[0])) {
std::vector<std::string> bpes;
bpe_encoder->Encode(w, &bpes);
for (const auto &bpe : bpes) {
oss << " " << bpe;
}
} else {
oss << " " << w;
}
}
}
}
std::string encoded_phrase = oss.str().substr(1);
oss.clear();
oss.str("");
oss << encoded_phrase;
if (!score.empty()) {
oss << " " << score;
}
lines.push_back(oss.str());
}
return EncodeBase(lines, symbol_table, hotwords, nullptr, nullptr, nullptr);
}
bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
@@ -114,7 +185,12 @@ bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
std::vector<std::string> *keywords,
std::vector<float> *boost_scores,
std::vector<float> *threshold) {
return EncodeBase(is, symbol_table, keywords_id, keywords, boost_scores,
std::vector<std::string> lines;
std::string line;
while (std::getline(is, line)) {
lines.push_back(line);
}
return EncodeBase(lines, symbol_table, keywords_id, keywords, boost_scores,
threshold);
}

View File

@@ -8,6 +8,7 @@
#include <vector>
#include "sherpa-onnx/csrc/symbol-table.h"
#include "ssentencepiece/csrc/ssentencepiece.h"
namespace sherpa_onnx {
@@ -25,7 +26,9 @@ namespace sherpa_onnx {
* @return If all the symbols from ``is`` are in the symbol_table, returns true
* otherwise returns false.
*/
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
const SymbolTable &symbol_table,
const ssentencepiece::Ssentencepiece *bpe_encoder_,
std::vector<std::vector<int32_t>> *hotwords_id);
/* Encode the keywords in an input stream to be tokens ids.