@@ -165,6 +165,7 @@ endif()
|
||||
target_link_libraries(sherpa-onnx-core
|
||||
kaldi-native-fbank-core
|
||||
kaldi-decoder-core
|
||||
ssentencepiece_core
|
||||
)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_GPU)
|
||||
@@ -491,6 +492,7 @@ if(SHERPA_ONNX_ENABLE_TESTS)
|
||||
pad-sequence-test.cc
|
||||
slice-test.cc
|
||||
stack-test.cc
|
||||
text2token-test.cc
|
||||
transpose-test.cc
|
||||
unbind-test.cc
|
||||
utfcpp-test.cc
|
||||
|
||||
@@ -35,6 +35,17 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
||||
"Valid values are: transducer, paraformer, nemo_ctc, whisper, "
|
||||
"tdnn, zipformer2_ctc"
|
||||
"All other values lead to loading the model twice.");
|
||||
po->Register("modeling-unit", &modeling_unit,
|
||||
"The modeling unit of the model, commonly used units are bpe, "
|
||||
"cjkchar, cjkchar+bpe, etc. Currently, it is needed only when "
|
||||
"hotwords are provided, we need it to encode the hotwords into "
|
||||
"token sequence.");
|
||||
po->Register("bpe-vocab", &bpe_vocab,
|
||||
"The vocabulary generated by google's sentencepiece program. "
|
||||
"It is a file has two columns, one is the token, the other is "
|
||||
"the log probability, you can get it from the directory where "
|
||||
"your bpe model is generated. Only used when hotwords provided "
|
||||
"and the modeling unit is bpe or cjkchar+bpe");
|
||||
}
|
||||
|
||||
bool OfflineModelConfig::Validate() const {
|
||||
@@ -48,6 +59,14 @@ bool OfflineModelConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!modeling_unit.empty() &&
|
||||
(modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) {
|
||||
if (!FileExists(bpe_vocab)) {
|
||||
SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!paraformer.model.empty()) {
|
||||
return paraformer.Validate();
|
||||
}
|
||||
@@ -90,7 +109,9 @@ std::string OfflineModelConfig::ToString() const {
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
os << "provider=\"" << provider << "\", ";
|
||||
os << "model_type=\"" << model_type << "\")";
|
||||
os << "model_type=\"" << model_type << "\", ";
|
||||
os << "modeling_unit=\"" << modeling_unit << "\", ";
|
||||
os << "bpe_vocab=\"" << bpe_vocab << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -41,6 +41,9 @@ struct OfflineModelConfig {
|
||||
// All other values are invalid and lead to loading the model twice.
|
||||
std::string model_type;
|
||||
|
||||
std::string modeling_unit = "cjkchar";
|
||||
std::string bpe_vocab;
|
||||
|
||||
OfflineModelConfig() = default;
|
||||
OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
|
||||
const OfflineParaformerModelConfig ¶former,
|
||||
@@ -50,7 +53,9 @@ struct OfflineModelConfig {
|
||||
const OfflineZipformerCtcModelConfig &zipformer_ctc,
|
||||
const OfflineWenetCtcModelConfig &wenet_ctc,
|
||||
const std::string &tokens, int32_t num_threads, bool debug,
|
||||
const std::string &provider, const std::string &model_type)
|
||||
const std::string &provider, const std::string &model_type,
|
||||
const std::string &modeling_unit,
|
||||
const std::string &bpe_vocab)
|
||||
: transducer(transducer),
|
||||
paraformer(paraformer),
|
||||
nemo_ctc(nemo_ctc),
|
||||
@@ -62,7 +67,9 @@ struct OfflineModelConfig {
|
||||
num_threads(num_threads),
|
||||
debug(debug),
|
||||
provider(provider),
|
||||
model_type(model_type) {}
|
||||
model_type(model_type),
|
||||
modeling_unit(modeling_unit),
|
||||
bpe_vocab(bpe_vocab) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -31,6 +31,7 @@
|
||||
#include "sherpa-onnx/csrc/pad-sequence.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/utils.h"
|
||||
#include "ssentencepiece/csrc/ssentencepiece.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -76,9 +77,6 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
: config_(config),
|
||||
symbol_table_(config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
|
||||
if (!config_.hotwords_file.empty()) {
|
||||
InitHotwords();
|
||||
}
|
||||
if (config_.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
|
||||
model_.get(), config_.blank_penalty);
|
||||
@@ -87,6 +85,15 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
lm_ = OfflineLM::Create(config.lm_config);
|
||||
}
|
||||
|
||||
if (!config_.model_config.bpe_vocab.empty()) {
|
||||
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
|
||||
config_.model_config.bpe_vocab);
|
||||
}
|
||||
|
||||
if (!config_.hotwords_file.empty()) {
|
||||
InitHotwords();
|
||||
}
|
||||
|
||||
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale, config_.blank_penalty);
|
||||
@@ -112,6 +119,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
lm_ = OfflineLM::Create(mgr, config.lm_config);
|
||||
}
|
||||
|
||||
if (!config_.model_config.bpe_vocab.empty()) {
|
||||
auto buf = ReadFile(mgr, config_.model_config.bpe_vocab);
|
||||
std::istringstream iss(std::string(buf.begin(), buf.end()));
|
||||
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss);
|
||||
}
|
||||
|
||||
if (!config_.hotwords_file.empty()) {
|
||||
InitHotwords(mgr);
|
||||
}
|
||||
|
||||
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale, config_.blank_penalty);
|
||||
@@ -128,7 +145,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
|
||||
std::istringstream is(hws);
|
||||
std::vector<std::vector<int32_t>> current;
|
||||
if (!EncodeHotwords(is, symbol_table_, ¤t)) {
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
|
||||
bpe_encoder_.get(), ¤t)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
|
||||
hotwords.c_str());
|
||||
}
|
||||
@@ -207,19 +225,47 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (!EncodeHotwords(is, symbol_table_, &hotwords_)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed.");
|
||||
exit(-1);
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
|
||||
bpe_encoder_.get(), &hotwords_)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Failed to encode some hotwords, skip them already, see logs above "
|
||||
"for details.");
|
||||
}
|
||||
hotwords_graph_ =
|
||||
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
void InitHotwords(AAssetManager *mgr) {
|
||||
// each line in hotwords_file contains space-separated words
|
||||
|
||||
auto buf = ReadFile(mgr, config_.hotwords_file);
|
||||
|
||||
std::istringstream is(std::string(buf.begin(), buf.end()));
|
||||
|
||||
if (!is) {
|
||||
SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
|
||||
config_.hotwords_file.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
|
||||
bpe_encoder_.get(), &hotwords_)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Failed to encode some hotwords, skip them already, see logs above "
|
||||
"for details.");
|
||||
}
|
||||
hotwords_graph_ =
|
||||
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
||||
}
|
||||
#endif
|
||||
|
||||
private:
|
||||
OfflineRecognizerConfig config_;
|
||||
SymbolTable symbol_table_;
|
||||
std::vector<std::vector<int32_t>> hotwords_;
|
||||
ContextGraphPtr hotwords_graph_;
|
||||
std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
|
||||
std::unique_ptr<OfflineTransducerModel> model_;
|
||||
std::unique_ptr<OfflineTransducerDecoder> decoder_;
|
||||
std::unique_ptr<OfflineLM> lm_;
|
||||
|
||||
@@ -37,10 +37,9 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
|
||||
|
||||
po->Register(
|
||||
"hotwords-file", &hotwords_file,
|
||||
"The file containing hotwords, one words/phrases per line, and for each"
|
||||
"phrase the bpe/cjkchar are separated by a space. For example: "
|
||||
"▁HE LL O ▁WORLD"
|
||||
"你 好 世 界");
|
||||
"The file containing hotwords, one words/phrases per line, For example: "
|
||||
"HELLO WORLD"
|
||||
"你好世界");
|
||||
|
||||
po->Register("hotwords-score", &hotwords_score,
|
||||
"The bonus score for each token in context word/phrase. "
|
||||
|
||||
@@ -32,6 +32,19 @@ void OnlineModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("provider", &provider,
|
||||
"Specify a provider to use: cpu, cuda, coreml");
|
||||
|
||||
po->Register("modeling-unit", &modeling_unit,
|
||||
"The modeling unit of the model, commonly used units are bpe, "
|
||||
"cjkchar, cjkchar+bpe, etc. Currently, it is needed only when "
|
||||
"hotwords are provided, we need it to encode the hotwords into "
|
||||
"token sequence.");
|
||||
|
||||
po->Register("bpe-vocab", &bpe_vocab,
|
||||
"The vocabulary generated by google's sentencepiece program. "
|
||||
"It is a file has two columns, one is the token, the other is "
|
||||
"the log probability, you can get it from the directory where "
|
||||
"your bpe model is generated. Only used when hotwords provided "
|
||||
"and the modeling unit is bpe or cjkchar+bpe");
|
||||
|
||||
po->Register("model-type", &model_type,
|
||||
"Specify it to reduce model initialization time. "
|
||||
"Valid values are: conformer, lstm, zipformer, zipformer2, "
|
||||
@@ -50,6 +63,14 @@ bool OnlineModelConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!modeling_unit.empty() &&
|
||||
(modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) {
|
||||
if (!FileExists(bpe_vocab)) {
|
||||
SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!paraformer.encoder.empty()) {
|
||||
return paraformer.Validate();
|
||||
}
|
||||
@@ -83,7 +104,9 @@ std::string OnlineModelConfig::ToString() const {
|
||||
os << "warm_up=" << warm_up << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
os << "provider=\"" << provider << "\", ";
|
||||
os << "model_type=\"" << model_type << "\")";
|
||||
os << "model_type=\"" << model_type << "\", ";
|
||||
os << "modeling_unit=\"" << modeling_unit << "\", ";
|
||||
os << "bpe_vocab=\"" << bpe_vocab << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -37,6 +37,13 @@ struct OnlineModelConfig {
|
||||
// All other values are invalid and lead to loading the model twice.
|
||||
std::string model_type;
|
||||
|
||||
// Valid values:
|
||||
// - cjkchar
|
||||
// - bpe
|
||||
// - cjkchar+bpe
|
||||
std::string modeling_unit = "cjkchar";
|
||||
std::string bpe_vocab;
|
||||
|
||||
OnlineModelConfig() = default;
|
||||
OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
|
||||
const OnlineParaformerModelConfig ¶former,
|
||||
@@ -45,7 +52,9 @@ struct OnlineModelConfig {
|
||||
const OnlineNeMoCtcModelConfig &nemo_ctc,
|
||||
const std::string &tokens, int32_t num_threads,
|
||||
int32_t warm_up, bool debug, const std::string &provider,
|
||||
const std::string &model_type)
|
||||
const std::string &model_type,
|
||||
const std::string &modeling_unit,
|
||||
const std::string &bpe_vocab)
|
||||
: transducer(transducer),
|
||||
paraformer(paraformer),
|
||||
wenet_ctc(wenet_ctc),
|
||||
@@ -56,7 +65,9 @@ struct OnlineModelConfig {
|
||||
warm_up(warm_up),
|
||||
debug(debug),
|
||||
provider(provider),
|
||||
model_type(model_type) {}
|
||||
model_type(model_type),
|
||||
modeling_unit(modeling_unit),
|
||||
bpe_vocab(bpe_vocab) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include <strstream>
|
||||
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
@@ -33,6 +31,7 @@
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/utils.h"
|
||||
#include "ssentencepiece/csrc/ssentencepiece.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -94,6 +93,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
model_->SetFeatureDim(config.feat_config.feature_dim);
|
||||
|
||||
if (config.decoding_method == "modified_beam_search") {
|
||||
if (!config_.model_config.bpe_vocab.empty()) {
|
||||
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
|
||||
config_.model_config.bpe_vocab);
|
||||
}
|
||||
|
||||
if (!config_.hotwords_file.empty()) {
|
||||
InitHotwords();
|
||||
}
|
||||
@@ -140,6 +144,12 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
}
|
||||
#endif
|
||||
|
||||
if (!config_.model_config.bpe_vocab.empty()) {
|
||||
auto buf = ReadFile(mgr, config_.model_config.bpe_vocab);
|
||||
std::istringstream iss(std::string(buf.begin(), buf.end()));
|
||||
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss);
|
||||
}
|
||||
|
||||
if (!config_.hotwords_file.empty()) {
|
||||
InitHotwords(mgr);
|
||||
}
|
||||
@@ -174,7 +184,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
|
||||
std::istringstream is(hws);
|
||||
std::vector<std::vector<int32_t>> current;
|
||||
if (!EncodeHotwords(is, sym_, ¤t)) {
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
|
||||
bpe_encoder_.get(), ¤t)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
|
||||
hotwords.c_str());
|
||||
}
|
||||
@@ -363,9 +374,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (!EncodeHotwords(is, sym_, &hotwords_)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed.");
|
||||
exit(-1);
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
|
||||
bpe_encoder_.get(), &hotwords_)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Failed to encode some hotwords, skip them already, see logs above "
|
||||
"for details.");
|
||||
}
|
||||
hotwords_graph_ =
|
||||
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
||||
@@ -377,7 +390,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
|
||||
auto buf = ReadFile(mgr, config_.hotwords_file);
|
||||
|
||||
std::istrstream is(buf.data(), buf.size());
|
||||
std::istringstream is(std::string(buf.begin(), buf.end()));
|
||||
|
||||
if (!is) {
|
||||
SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
|
||||
@@ -385,9 +398,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (!EncodeHotwords(is, sym_, &hotwords_)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed.");
|
||||
exit(-1);
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
|
||||
bpe_encoder_.get(), &hotwords_)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Failed to encode some hotwords, skip them already, see logs above "
|
||||
"for details.");
|
||||
}
|
||||
hotwords_graph_ =
|
||||
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
||||
@@ -413,6 +428,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
OnlineRecognizerConfig config_;
|
||||
std::vector<std::vector<int32_t>> hotwords_;
|
||||
ContextGraphPtr hotwords_graph_;
|
||||
std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
|
||||
std::unique_ptr<OnlineTransducerModel> model_;
|
||||
std::unique_ptr<OnlineLM> lm_;
|
||||
std::unique_ptr<OnlineTransducerDecoder> decoder_;
|
||||
|
||||
@@ -51,9 +51,7 @@ std::string VecToString<std::string>(const std::vector<std::string> &vec,
|
||||
std::string OnlineRecognizerResult::AsJsonString() const {
|
||||
std::ostringstream os;
|
||||
os << "{ ";
|
||||
os << "\"text\": "
|
||||
<< "\"" << text << "\""
|
||||
<< ", ";
|
||||
os << "\"text\": " << "\"" << text << "\"" << ", ";
|
||||
os << "\"tokens\": " << VecToString(tokens) << ", ";
|
||||
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
|
||||
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
|
||||
@@ -89,10 +87,9 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
"Used only when decoding_method is modified_beam_search");
|
||||
po->Register(
|
||||
"hotwords-file", &hotwords_file,
|
||||
"The file containing hotwords, one words/phrases per line, and for each"
|
||||
"phrase the bpe/cjkchar are separated by a space. For example: "
|
||||
"▁HE LL O ▁WORLD"
|
||||
"你 好 世 界");
|
||||
"The file containing hotwords, one words/phrases per line, For example: "
|
||||
"HELLO WORLD"
|
||||
"你好世界");
|
||||
po->Register("decoding-method", &decoding_method,
|
||||
"decoding method,"
|
||||
"now support greedy_search and modified_beam_search.");
|
||||
|
||||
@@ -38,35 +38,6 @@ void SymbolTable::Init(std::istream &is) {
|
||||
std::string sym;
|
||||
int32_t id;
|
||||
while (is >> sym >> id) {
|
||||
if (sym.size() >= 3) {
|
||||
// For BPE-based models, we replace ▁ with a space
|
||||
// Unicode 9601, hex 0x2581, utf8 0xe29681
|
||||
const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
|
||||
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
|
||||
sym = sym.replace(0, 3, " ");
|
||||
}
|
||||
}
|
||||
|
||||
// for byte-level BPE
|
||||
// id 0 is blank, id 1 is sos/eos, id 2 is unk
|
||||
if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' &&
|
||||
sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') {
|
||||
std::ostringstream os;
|
||||
os << std::hex << std::uppercase << (id - 3);
|
||||
|
||||
if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) {
|
||||
uint8_t i = id - 3;
|
||||
sym = std::string(&i, &i + 1);
|
||||
}
|
||||
}
|
||||
|
||||
assert(!sym.empty());
|
||||
|
||||
// for byte bpe, after replacing ▁ with a space, whose ascii is also 0x20,
|
||||
// there is a conflict between the real byte 0x20 and ▁, so we disable
|
||||
// the following check.
|
||||
//
|
||||
// Note: Only id2sym_ matters as we use it to convert ID to symbols.
|
||||
#if 0
|
||||
// we disable the test here since for some multi-lingual BPE models
|
||||
// from NeMo, the same symbol can appear multiple times with different IDs.
|
||||
@@ -92,8 +63,30 @@ std::string SymbolTable::ToString() const {
|
||||
return os.str();
|
||||
}
|
||||
|
||||
const std::string &SymbolTable::operator[](int32_t id) const {
|
||||
return id2sym_.at(id);
|
||||
const std::string SymbolTable::operator[](int32_t id) const {
|
||||
std::string sym = id2sym_.at(id);
|
||||
if (sym.size() >= 3) {
|
||||
// For BPE-based models, we replace ▁ with a space
|
||||
// Unicode 9601, hex 0x2581, utf8 0xe29681
|
||||
const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
|
||||
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
|
||||
sym = sym.replace(0, 3, " ");
|
||||
}
|
||||
}
|
||||
|
||||
// for byte-level BPE
|
||||
// id 0 is blank, id 1 is sos/eos, id 2 is unk
|
||||
if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' &&
|
||||
sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') {
|
||||
std::ostringstream os;
|
||||
os << std::hex << std::uppercase << (id - 3);
|
||||
|
||||
if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) {
|
||||
uint8_t i = id - 3;
|
||||
sym = std::string(&i, &i + 1);
|
||||
}
|
||||
}
|
||||
return sym;
|
||||
}
|
||||
|
||||
int32_t SymbolTable::operator[](const std::string &sym) const {
|
||||
|
||||
@@ -35,7 +35,7 @@ class SymbolTable {
|
||||
std::string ToString() const;
|
||||
|
||||
/// Return the symbol corresponding to the given ID.
|
||||
const std::string &operator[](int32_t id) const;
|
||||
const std::string operator[](int32_t id) const;
|
||||
/// Return the ID corresponding to the given symbol.
|
||||
int32_t operator[](const std::string &sym) const;
|
||||
|
||||
|
||||
152
sherpa-onnx/csrc/text2token-test.cc
Normal file
152
sherpa-onnx/csrc/text2token-test.cc
Normal file
@@ -0,0 +1,152 @@
|
||||
// sherpa-onnx/csrc/text2token-test.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/utils.h"
|
||||
#include "ssentencepiece/csrc/ssentencepiece.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Please refer to
|
||||
// https://github.com/pkufool/sherpa-test-data
|
||||
// to download test data for testing
|
||||
static const char dir[] = "/tmp/sherpa-test-data";
|
||||
|
||||
TEST(TEXT2TOKEN, TEST_cjkchar) {
|
||||
std::ostringstream oss;
|
||||
oss << dir << "/text2token/tokens_cn.txt";
|
||||
|
||||
std::string tokens = oss.str();
|
||||
|
||||
if (!std::ifstream(tokens).good()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No test data found, skipping TEST_cjkchar()."
|
||||
"You can download the test data by: "
|
||||
"git clone https://github.com/pkufool/sherpa-test-data.git "
|
||||
"/tmp/sherpa-test-data");
|
||||
return;
|
||||
}
|
||||
|
||||
auto sym_table = SymbolTable(tokens);
|
||||
|
||||
std::string text = "世界人民大团结\n中国 V S 美国";
|
||||
|
||||
std::istringstream iss(text);
|
||||
|
||||
std::vector<std::vector<int32_t>> ids;
|
||||
|
||||
auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids);
|
||||
|
||||
std::vector<std::vector<int32_t>> expected_ids(
|
||||
{{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}});
|
||||
EXPECT_EQ(ids, expected_ids);
|
||||
}
|
||||
|
||||
TEST(TEXT2TOKEN, TEST_bpe) {
|
||||
std::ostringstream oss;
|
||||
oss << dir << "/text2token/tokens_en.txt";
|
||||
std::string tokens = oss.str();
|
||||
oss.clear();
|
||||
oss.str("");
|
||||
oss << dir << "/text2token/bpe_en.vocab";
|
||||
std::string bpe = oss.str();
|
||||
if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No test data found, skipping TEST_bpe()."
|
||||
"You can download the test data by: "
|
||||
"git clone https://github.com/pkufool/sherpa-test-data.git "
|
||||
"/tmp/sherpa-test-data");
|
||||
return;
|
||||
}
|
||||
|
||||
auto sym_table = SymbolTable(tokens);
|
||||
auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
|
||||
|
||||
std::string text = "HELLO WORLD\nI LOVE YOU";
|
||||
|
||||
std::istringstream iss(text);
|
||||
|
||||
std::vector<std::vector<int32_t>> ids;
|
||||
|
||||
auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids);
|
||||
|
||||
std::vector<std::vector<int32_t>> expected_ids(
|
||||
{{22, 58, 24, 425}, {19, 370, 47}});
|
||||
EXPECT_EQ(ids, expected_ids);
|
||||
}
|
||||
|
||||
TEST(TEXT2TOKEN, TEST_cjkchar_bpe) {
|
||||
std::ostringstream oss;
|
||||
oss << dir << "/text2token/tokens_mix.txt";
|
||||
std::string tokens = oss.str();
|
||||
oss.clear();
|
||||
oss.str("");
|
||||
oss << dir << "/text2token/bpe_mix.vocab";
|
||||
std::string bpe = oss.str();
|
||||
if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No test data found, skipping TEST_cjkchar_bpe()."
|
||||
"You can download the test data by: "
|
||||
"git clone https://github.com/pkufool/sherpa-test-data.git "
|
||||
"/tmp/sherpa-test-data");
|
||||
return;
|
||||
}
|
||||
|
||||
auto sym_table = SymbolTable(tokens);
|
||||
auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
|
||||
|
||||
std::string text = "世界人民 GOES TOGETHER\n中国 GOES WITH 美国";
|
||||
|
||||
std::istringstream iss(text);
|
||||
|
||||
std::vector<std::vector<int32_t>> ids;
|
||||
|
||||
auto r =
|
||||
EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(), &ids);
|
||||
|
||||
std::vector<std::vector<int32_t>> expected_ids(
|
||||
{{1368, 1392, 557, 680, 275, 178, 475},
|
||||
{685, 736, 275, 178, 179, 921, 736}});
|
||||
EXPECT_EQ(ids, expected_ids);
|
||||
}
|
||||
|
||||
TEST(TEXT2TOKEN, TEST_bbpe) {
|
||||
std::ostringstream oss;
|
||||
oss << dir << "/text2token/tokens_bbpe.txt";
|
||||
std::string tokens = oss.str();
|
||||
oss.clear();
|
||||
oss.str("");
|
||||
oss << dir << "/text2token/bbpe.vocab";
|
||||
std::string bpe = oss.str();
|
||||
if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No test data found, skipping TEST_bbpe()."
|
||||
"You can download the test data by: "
|
||||
"git clone https://github.com/pkufool/sherpa-test-data.git "
|
||||
"/tmp/sherpa-test-data");
|
||||
return;
|
||||
}
|
||||
|
||||
auto sym_table = SymbolTable(tokens);
|
||||
auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
|
||||
|
||||
std::string text = "频繁\n李鞑靼";
|
||||
|
||||
std::istringstream iss(text);
|
||||
|
||||
std::vector<std::vector<int32_t>> ids;
|
||||
|
||||
auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids);
|
||||
|
||||
std::vector<std::vector<int32_t>> expected_ids(
|
||||
{{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}});
|
||||
EXPECT_EQ(ids, expected_ids);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/utils.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
@@ -12,15 +13,16 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
|
||||
static bool EncodeBase(const std::vector<std::string> &lines,
|
||||
const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *ids,
|
||||
std::vector<std::string> *phrases,
|
||||
std::vector<float> *scores,
|
||||
std::vector<float> *thresholds) {
|
||||
SHERPA_ONNX_CHECK(ids != nullptr);
|
||||
ids->clear();
|
||||
|
||||
std::vector<int32_t> tmp_ids;
|
||||
@@ -33,22 +35,15 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
|
||||
bool has_scores = false;
|
||||
bool has_thresholds = false;
|
||||
bool has_phrases = false;
|
||||
bool has_oov = false;
|
||||
|
||||
while (std::getline(is, line)) {
|
||||
for (const auto &line : lines) {
|
||||
float score = 0;
|
||||
float threshold = 0;
|
||||
std::string phrase = "";
|
||||
|
||||
std::istringstream iss(line);
|
||||
while (iss >> word) {
|
||||
if (word.size() >= 3) {
|
||||
// For BPE-based models, we replace ▁ with a space
|
||||
// Unicode 9601, hex 0x2581, utf8 0xe29681
|
||||
const uint8_t *p = reinterpret_cast<const uint8_t *>(word.c_str());
|
||||
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
|
||||
word = word.replace(0, 3, " ");
|
||||
}
|
||||
}
|
||||
if (symbol_table.Contains(word)) {
|
||||
int32_t id = symbol_table[word];
|
||||
tmp_ids.push_back(id);
|
||||
@@ -71,7 +66,8 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
|
||||
"Cannot find ID for token %s at line: %s. (Hint: words on "
|
||||
"the same line are separated by spaces)",
|
||||
word.c_str(), line.c_str());
|
||||
return false;
|
||||
has_oov = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -101,12 +97,87 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
|
||||
thresholds->clear();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return !has_oov;
|
||||
}
|
||||
|
||||
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
||||
bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
|
||||
const SymbolTable &symbol_table,
|
||||
const ssentencepiece::Ssentencepiece *bpe_encoder,
|
||||
std::vector<std::vector<int32_t>> *hotwords) {
|
||||
return EncodeBase(is, symbol_table, hotwords, nullptr, nullptr, nullptr);
|
||||
std::vector<std::string> lines;
|
||||
std::string line;
|
||||
std::string word;
|
||||
|
||||
while (std::getline(is, line)) {
|
||||
std::string score;
|
||||
std::string phrase;
|
||||
|
||||
std::ostringstream oss;
|
||||
std::istringstream iss(line);
|
||||
while (iss >> word) {
|
||||
switch (word[0]) {
|
||||
case ':': // boosting score for current keyword
|
||||
score = word;
|
||||
break;
|
||||
default:
|
||||
if (!score.empty()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Boosting score should be put after the words/phrase, given "
|
||||
"%s.",
|
||||
line.c_str());
|
||||
return false;
|
||||
}
|
||||
oss << " " << word;
|
||||
break;
|
||||
}
|
||||
}
|
||||
phrase = oss.str().substr(1);
|
||||
std::istringstream piss(phrase);
|
||||
oss.clear();
|
||||
oss.str("");
|
||||
while (piss >> word) {
|
||||
if (modeling_unit == "cjkchar") {
|
||||
for (const auto &w : SplitUtf8(word)) {
|
||||
oss << " " << w;
|
||||
}
|
||||
} else if (modeling_unit == "bpe") {
|
||||
std::vector<std::string> bpes;
|
||||
bpe_encoder->Encode(word, &bpes);
|
||||
for (const auto &bpe : bpes) {
|
||||
oss << " " << bpe;
|
||||
}
|
||||
} else {
|
||||
if (modeling_unit != "cjkchar+bpe") {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"modeling_unit should be one of bpe, cjkchar or cjkchar+bpe, "
|
||||
"given "
|
||||
"%s",
|
||||
modeling_unit.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
for (const auto &w : SplitUtf8(word)) {
|
||||
if (isalpha(w[0])) {
|
||||
std::vector<std::string> bpes;
|
||||
bpe_encoder->Encode(w, &bpes);
|
||||
for (const auto &bpe : bpes) {
|
||||
oss << " " << bpe;
|
||||
}
|
||||
} else {
|
||||
oss << " " << w;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
std::string encoded_phrase = oss.str().substr(1);
|
||||
oss.clear();
|
||||
oss.str("");
|
||||
oss << encoded_phrase;
|
||||
if (!score.empty()) {
|
||||
oss << " " << score;
|
||||
}
|
||||
lines.push_back(oss.str());
|
||||
}
|
||||
return EncodeBase(lines, symbol_table, hotwords, nullptr, nullptr, nullptr);
|
||||
}
|
||||
|
||||
bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
|
||||
@@ -114,7 +185,12 @@ bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::string> *keywords,
|
||||
std::vector<float> *boost_scores,
|
||||
std::vector<float> *threshold) {
|
||||
return EncodeBase(is, symbol_table, keywords_id, keywords, boost_scores,
|
||||
std::vector<std::string> lines;
|
||||
std::string line;
|
||||
while (std::getline(is, line)) {
|
||||
lines.push_back(line);
|
||||
}
|
||||
return EncodeBase(lines, symbol_table, keywords_id, keywords, boost_scores,
|
||||
threshold);
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "ssentencepiece/csrc/ssentencepiece.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -25,7 +26,9 @@ namespace sherpa_onnx {
|
||||
* @return If all the symbols from ``is`` are in the symbol_table, returns true
|
||||
* otherwise returns false.
|
||||
*/
|
||||
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
||||
bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
|
||||
const SymbolTable &symbol_table,
|
||||
const ssentencepiece::Ssentencepiece *bpe_encoder_,
|
||||
std::vector<std::vector<int32_t>> *hotwords_id);
|
||||
|
||||
/* Encode the keywords in an input stream to be tokens ids.
|
||||
|
||||
Reference in New Issue
Block a user