This repository has been archived on 2025-08-26. You can view files and clone it, but cannot push or open issues or pull requests.
Files
enginex_bi_series-sherpa-onnx/sherpa-onnx/csrc/homophone-replacer.cc
2025-05-14 10:58:35 +08:00

313 lines
8.2 KiB
C++

// sherpa-onnx/csrc/homophone-replacer.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/homophone-replacer.h"
#include <fstream>
#include <sstream>
#include <string>
#include <strstream>
#include <unordered_map>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
#include "kaldifst/csrc/text-normalizer.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/jieba.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
// Registers the command-line options understood by HomophoneReplacerConfig
// with the given option parser. Each option writes into the corresponding
// string member of this config.
void HomophoneReplacerConfig::Register(ParseOptions *po) {
  struct StringOption {
    const char *name;
    std::string *target;
    const char *help;
  };

  // Registered in this order: dict dir, lexicon, rule fsts.
  const StringOption options[] = {
      {"hr-dict-dir", &dict_dir,
       "The dict directory for jieba used by HomophoneReplacer"},
      {"hr-lexicon", &lexicon,
       "Path to lexicon.txt used by HomophoneReplacer."},
      {"hr-rule-fsts", &rule_fsts,
       "Fst files for HomophoneReplacer. If there are multiple, they "
       "are separated by a comma. E.g., a.fst,b.fst,c.fst"},
  };

  for (const auto &opt : options) {
    po->Register(opt.name, opt.target, opt.help);
  }
}
// Checks that every file referenced by this config exists.
//
// - If dict_dir is set, it must contain all the jieba dictionary files listed
//   below.
// - If lexicon is set, the file must exist.
// - If rule_fsts is set, it must name exactly one existing file.
//
// Returns true when the config is usable and false otherwise; a message is
// logged for the first problem found.
bool HomophoneReplacerConfig::Validate() const {
  if (!dict_dir.empty()) {
    static const char *const kRequiredFiles[] = {
        "jieba.dict.utf8", "hmm_model.utf8",  "user.dict.utf8",
        "idf.utf8",        "stop_words.utf8",
    };

    for (const char *f : kRequiredFiles) {
      if (!FileExists(dict_dir + "/" + f)) {
        // The option registered for dict_dir is --hr-dict-dir, so point the
        // user at that one.
        SHERPA_ONNX_LOGE("'%s/%s' does not exist. Please check --hr-dict-dir",
                         dict_dir.c_str(), f);
        return false;
      }
    }
  }

  if (!lexicon.empty() && !FileExists(lexicon)) {
    SHERPA_ONNX_LOGE("--hr-lexicon: '%s' does not exist", lexicon.c_str());
    return false;
  }

  if (!rule_fsts.empty()) {
    std::vector<std::string> files;
    SplitStringToVector(rule_fsts, ",", false, &files);

    if (files.size() > 1) {
      // NOTE(review): this terminates the process instead of returning false
      // like the other checks; kept as is because callers may rely on it.
      SHERPA_ONNX_LOGE("Only 1 file is supported now.");
      SHERPA_ONNX_EXIT(-1);
    }

    for (const auto &f : files) {
      if (!FileExists(f)) {
        SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str());
        return false;
      }
    }
  }

  return true;
}
// Returns a human-readable, single-line description of this config in the
// form HomophoneReplacerConfig(dict_dir="...", lexicon="...", rule_fsts="...").
std::string HomophoneReplacerConfig::ToString() const {
  std::string repr = "HomophoneReplacerConfig(";

  repr += "dict_dir=\"";
  repr += dict_dir;
  repr += "\", ";

  repr += "lexicon=\"";
  repr += lexicon;
  repr += "\", ";

  repr += "rule_fsts=\"";
  repr += rule_fsts;
  repr += "\")";

  return repr;
}
// Private implementation of HomophoneReplacer.
//
// Processing done by Apply():
//   1. The input text is segmented into words with jieba.
//   2. Words that are at least 3 bytes long and start with a non-ASCII byte
//      are converted to a pronunciation string using the lexicon.
//   3. Each consecutive run of such words, together with the pronunciations,
//      is handed to the rule FST, whose output replaces the run.
//   4. All other words are copied to the output unchanged.
class HomophoneReplacer::Impl {
 public:
  // Loads the jieba dictionary, the lexicon and the rule FSTs from the local
  // file system using the paths in `config`.
  explicit Impl(const HomophoneReplacerConfig &config) : config_(config) {
    jieba_ = InitJieba(config.dict_dir);

    {
      std::ifstream is(config.lexicon);
      if (!config.lexicon.empty() && !is) {
        // Without a lexicon every word is used as its own pronunciation, so
        // make the failure visible instead of ignoring it.
        SHERPA_ONNX_LOGE("Failed to open lexicon '%s'",
                         config.lexicon.c_str());
      }
      InitLexicon(is);
    }

    if (!config.rule_fsts.empty()) {
      std::vector<std::string> files;
      SplitStringToVector(config.rule_fsts, ",", false, &files);
      replacer_list_.reserve(files.size());

      for (const auto &f : files) {
        if (config.debug) {
          SHERPA_ONNX_LOGE("hr rule fst: %s", f.c_str());
        }

        replacer_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(f));
      }
    }
  }

  // Same as the constructor above, but the lexicon and the rule FSTs are read
  // through ReadFile(mgr, path) instead of being opened directly.
  template <typename Manager>
  Impl(Manager *mgr, const HomophoneReplacerConfig &config) : config_(config) {
    jieba_ = InitJieba(config.dict_dir);

    {
      auto buf = ReadFile(mgr, config.lexicon);
      std::istrstream is(buf.data(), buf.size());
      InitLexicon(is);
    }

    if (!config.rule_fsts.empty()) {
      std::vector<std::string> files;
      SplitStringToVector(config.rule_fsts, ",", false, &files);
      replacer_list_.reserve(files.size());

      for (const auto &f : files) {
        if (config.debug) {
          SHERPA_ONNX_LOGE("hr rule fst: %s", f.c_str());
        }

        auto buf = ReadFile(mgr, f);
        std::istrstream is(buf.data(), buf.size());
        replacer_list_.push_back(
            std::make_unique<kaldifst::TextNormalizer>(is));
      }
    }
  }

  // Returns `text` with homophones replaced according to the rule FST.
  // An empty input yields an empty output.
  std::string Apply(const std::string &text) const {
    std::string ans;
    if (text.empty()) {
      return ans;
    }

    if (!jieba_) {
      // InitJieba() is defined elsewhere; guard against it having returned a
      // null pointer (presumably possible for an unusable dict_dir - TODO
      // confirm) rather than dereferencing it below.
      return text;
    }

    bool is_hmm = true;
    std::vector<std::string> words;
    jieba_->Cut(text, words, is_hmm);

    if (config_.debug) {
      SHERPA_ONNX_LOGE("Input text: '%s'", text.c_str());

      std::ostringstream os;
      os << "After jieba: ";
      std::string sep;
      for (const auto &w : words) {
        os << sep << w;
        sep = "_";
      }
      SHERPA_ONNX_LOGE("%s", os.str().c_str());
    }

    // Convert words to pronunciations. Words are collected into a run until a
    // word that is not handled by the replacer shows up; the run is then
    // flushed through ApplyImpl().
    std::vector<std::string> current_words;
    std::vector<std::string> current_pronunciations;

    for (const auto &w : words) {
      // Fewer than 3 bytes, or a leading ASCII byte: not treated as a Chinese
      // word, so it is copied verbatim.
      const bool pass_through =
          w.size() < 3 || static_cast<uint8_t>(w[0]) < 128;

      if (pass_through) {
        if (!current_words.empty()) {
          ans += ApplyImpl(current_words, current_pronunciations);
          current_words.clear();
          current_pronunciations.clear();
        }

        ans += w;
        continue;
      }

      auto p = ConvertWordToPronunciation(w);
      if (config_.debug) {
        SHERPA_ONNX_LOGE("%s %s", w.c_str(), p.c_str());
      }

      current_words.push_back(w);
      current_pronunciations.push_back(std::move(p));
    }

    if (!current_words.empty()) {
      ans += ApplyImpl(current_words, current_pronunciations);
    }

    if (config_.debug) {
      SHERPA_ONNX_LOGE("Output text: '%s'", ans.c_str());
    }

    return ans;
  }

 private:
  // Runs one run of words (and their pronunciations) through the rule FST and
  // returns the resulting text.
  //
  // When no rule FST has been loaded the words are returned concatenated and
  // unchanged, so that the text is preserved rather than dropped.
  std::string ApplyImpl(const std::vector<std::string> &words,
                        const std::vector<std::string> &pronunciations) const {
    if (replacer_list_.empty()) {
      std::string unchanged;
      for (const auto &w : words) {
        unchanged += w;
      }
      return unchanged;
    }

    // TODO(fangjun): We support only 1 rule fst at present.
    return replacer_list_.front()->Normalize(words, pronunciations);
  }

  // Returns the pronunciation of `word`.
  //
  // The whole word is looked up first. If it is not in the lexicon and is
  // longer than 3 bytes, it is split with SplitUtf8() and the pronunciations
  // of the pieces are concatenated; a piece that is not in the lexicon is
  // used as is.
  std::string ConvertWordToPronunciation(const std::string &word) const {
    const auto whole = word2pron_.find(word);
    if (whole != word2pron_.end()) {
      return whole->second;
    }

    if (word.size() <= 3) {
      // At most 3 bytes and not in the lexicon: there is nothing to split
      // further, so the word stands for itself.
      return word;
    }

    const std::vector<std::string> pieces = SplitUtf8(word);

    std::string ans;
    for (const auto &piece : pieces) {
      const auto it = word2pron_.find(piece);
      ans.append(it != word2pron_.end() ? it->second : piece);
    }

    return ans;
  }

  // Fills word2pron_ from a lexicon stream.
  //
  // Each non-blank line has the form "word token1 token2 ...". The word is
  // lower-cased and the tokens are concatenated to form its pronunciation.
  // Blank lines are skipped. A duplicated word or a word without any token is
  // reported and ignored (only the first few duplicates are reported).
  void InitLexicon(std::istream &is) {
    std::string word;
    std::string pron;
    std::string p;
    std::string line;

    int32_t line_num = 0;
    int32_t num_warn = 0;

    while (std::getline(is, line)) {
      ++line_num;

      std::istringstream iss(line);

      word.clear();
      pron.clear();

      if (!(iss >> word)) {
        // Blank (or whitespace-only) line: there is no word to register.
        continue;
      }

      ToLowerCase(&word);

      if (word2pron_.count(word)) {
        num_warn += 1;
        if (num_warn < 10) {
          SHERPA_ONNX_LOGE("Duplicated word: %s at line %d:%s. Ignore it.",
                           word.c_str(), line_num, line.c_str());
        }
        continue;
      }

      while (iss >> p) {
        // A token whose last character compares greater than '4' gets the
        // digit '1' appended (presumably a default tone for tokens written
        // without a tone digit - TODO confirm).
        if (p.back() > '4') {
          p.push_back('1');
        }
        pron.append(p);
      }

      if (pron.empty()) {
        SHERPA_ONNX_LOGE(
            "Empty pronunciation for word '%s' at line %d:%s. Ignore it.",
            word.c_str(), line_num, line.c_str());
        continue;
      }

      word2pron_.emplace(std::move(word), std::move(pron));
    }
  }

 private:
  HomophoneReplacerConfig config_;
  std::unique_ptr<cppjieba::Jieba> jieba_;
  // Loaded rule FSTs; only the first one is used by ApplyImpl().
  std::vector<std::unique_ptr<kaldifst::TextNormalizer>> replacer_list_;
  // Lower-cased word -> concatenated pronunciation tokens.
  std::unordered_map<std::string, std::string> word2pron_;
};
// Creates a replacer from `config`. All work is delegated to Impl, whose
// single-argument constructor opens the lexicon and the rule FSTs by path.
HomophoneReplacer::HomophoneReplacer(const HomophoneReplacerConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
// Creates a replacer from `config`, forwarding `mgr` to Impl, which reads the
// lexicon and the rule FSTs through ReadFile(mgr, path) instead of opening
// them directly. Only the Manager types explicitly instantiated in this file
// are available to callers.
template <typename Manager>
HomophoneReplacer::HomophoneReplacer(Manager *mgr,
const HomophoneReplacerConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
// Defaulted here, after Impl has been fully defined in this file.
// NOTE(review): presumably needed because the header only forward-declares
// Impl while impl_ owns it - confirm against homophone-replacer.h.
HomophoneReplacer::~HomophoneReplacer() = default;
// Applies homophone replacement to `text` and returns the result after
// passing it through RemoveInvalidUtf8Sequences().
std::string HomophoneReplacer::Apply(const std::string &text) const {
  const std::string replaced = impl_->Apply(text);
  return RemoveInvalidUtf8Sequences(replaced);
}
// Explicit instantiations of the templated constructor. Its definition lives
// in this file, so each supported Manager type is instantiated here.

// Android: resources are accessed through an AAssetManager.
#if __ANDROID_API__ >= 9
template HomophoneReplacer::HomophoneReplacer(
AAssetManager *mgr, const HomophoneReplacerConfig &config);
#endif
// OHOS: resources are accessed through a NativeResourceManager.
#if __OHOS__
template HomophoneReplacer::HomophoneReplacer(
NativeResourceManager *mgr, const HomophoneReplacerConfig &config);
#endif
} // namespace sherpa_onnx