@@ -4,6 +4,7 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/utils.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
@@ -12,15 +13,16 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
|
||||
static bool EncodeBase(const std::vector<std::string> &lines,
|
||||
const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *ids,
|
||||
std::vector<std::string> *phrases,
|
||||
std::vector<float> *scores,
|
||||
std::vector<float> *thresholds) {
|
||||
SHERPA_ONNX_CHECK(ids != nullptr);
|
||||
ids->clear();
|
||||
|
||||
std::vector<int32_t> tmp_ids;
|
||||
@@ -33,22 +35,15 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
|
||||
bool has_scores = false;
|
||||
bool has_thresholds = false;
|
||||
bool has_phrases = false;
|
||||
bool has_oov = false;
|
||||
|
||||
while (std::getline(is, line)) {
|
||||
for (const auto &line : lines) {
|
||||
float score = 0;
|
||||
float threshold = 0;
|
||||
std::string phrase = "";
|
||||
|
||||
std::istringstream iss(line);
|
||||
while (iss >> word) {
|
||||
if (word.size() >= 3) {
|
||||
// For BPE-based models, we replace ▁ with a space
|
||||
// Unicode 9601, hex 0x2581, utf8 0xe29681
|
||||
const uint8_t *p = reinterpret_cast<const uint8_t *>(word.c_str());
|
||||
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
|
||||
word = word.replace(0, 3, " ");
|
||||
}
|
||||
}
|
||||
if (symbol_table.Contains(word)) {
|
||||
int32_t id = symbol_table[word];
|
||||
tmp_ids.push_back(id);
|
||||
@@ -71,7 +66,8 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
|
||||
"Cannot find ID for token %s at line: %s. (Hint: words on "
|
||||
"the same line are separated by spaces)",
|
||||
word.c_str(), line.c_str());
|
||||
return false;
|
||||
has_oov = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -101,12 +97,87 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
|
||||
thresholds->clear();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return !has_oov;
|
||||
}
|
||||
|
||||
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
||||
bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
|
||||
const SymbolTable &symbol_table,
|
||||
const ssentencepiece::Ssentencepiece *bpe_encoder,
|
||||
std::vector<std::vector<int32_t>> *hotwords) {
|
||||
return EncodeBase(is, symbol_table, hotwords, nullptr, nullptr, nullptr);
|
||||
std::vector<std::string> lines;
|
||||
std::string line;
|
||||
std::string word;
|
||||
|
||||
while (std::getline(is, line)) {
|
||||
std::string score;
|
||||
std::string phrase;
|
||||
|
||||
std::ostringstream oss;
|
||||
std::istringstream iss(line);
|
||||
while (iss >> word) {
|
||||
switch (word[0]) {
|
||||
case ':': // boosting score for current keyword
|
||||
score = word;
|
||||
break;
|
||||
default:
|
||||
if (!score.empty()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Boosting score should be put after the words/phrase, given "
|
||||
"%s.",
|
||||
line.c_str());
|
||||
return false;
|
||||
}
|
||||
oss << " " << word;
|
||||
break;
|
||||
}
|
||||
}
|
||||
phrase = oss.str().substr(1);
|
||||
std::istringstream piss(phrase);
|
||||
oss.clear();
|
||||
oss.str("");
|
||||
while (piss >> word) {
|
||||
if (modeling_unit == "cjkchar") {
|
||||
for (const auto &w : SplitUtf8(word)) {
|
||||
oss << " " << w;
|
||||
}
|
||||
} else if (modeling_unit == "bpe") {
|
||||
std::vector<std::string> bpes;
|
||||
bpe_encoder->Encode(word, &bpes);
|
||||
for (const auto &bpe : bpes) {
|
||||
oss << " " << bpe;
|
||||
}
|
||||
} else {
|
||||
if (modeling_unit != "cjkchar+bpe") {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"modeling_unit should be one of bpe, cjkchar or cjkchar+bpe, "
|
||||
"given "
|
||||
"%s",
|
||||
modeling_unit.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
for (const auto &w : SplitUtf8(word)) {
|
||||
if (isalpha(w[0])) {
|
||||
std::vector<std::string> bpes;
|
||||
bpe_encoder->Encode(w, &bpes);
|
||||
for (const auto &bpe : bpes) {
|
||||
oss << " " << bpe;
|
||||
}
|
||||
} else {
|
||||
oss << " " << w;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
std::string encoded_phrase = oss.str().substr(1);
|
||||
oss.clear();
|
||||
oss.str("");
|
||||
oss << encoded_phrase;
|
||||
if (!score.empty()) {
|
||||
oss << " " << score;
|
||||
}
|
||||
lines.push_back(oss.str());
|
||||
}
|
||||
return EncodeBase(lines, symbol_table, hotwords, nullptr, nullptr, nullptr);
|
||||
}
|
||||
|
||||
bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
|
||||
@@ -114,7 +185,12 @@ bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::string> *keywords,
|
||||
std::vector<float> *boost_scores,
|
||||
std::vector<float> *threshold) {
|
||||
return EncodeBase(is, symbol_table, keywords_id, keywords, boost_scores,
|
||||
std::vector<std::string> lines;
|
||||
std::string line;
|
||||
while (std::getline(is, line)) {
|
||||
lines.push_back(line);
|
||||
}
|
||||
return EncodeBase(lines, symbol_table, keywords_id, keywords, boost_scores,
|
||||
threshold);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user