Support replacing homophonic phrases (#2153)
This commit is contained in:
23
.github/scripts/test-offline-ctc.sh
vendored
23
.github/scripts/test-offline-ctc.sh
vendored
@@ -98,6 +98,29 @@ for m in model.onnx model.int8.onnx; do
|
|||||||
done
|
done
|
||||||
done
|
done
|
||||||
|
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/dict.tar.bz2
|
||||||
|
tar xf dict.tar.bz2
|
||||||
|
rm dict.tar.bz2
|
||||||
|
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/replace.fst
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/test-hr.wav
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/lexicon.txt
|
||||||
|
|
||||||
|
# Decode test-hr.wav with the homophone replacer enabled, once per model
# variant and once with/without inverse text normalization.
for m in model.onnx model.int8.onnx; do
  for use_itn in 0 1; do
    # Print the wave that is actually decoded here. "$w" is not set by this
    # loop, so echoing it would show a stale (or empty) value.
    echo "$m test-hr.wav $use_itn"
    time $EXE \
      --tokens=$repo/tokens.txt \
      --sense-voice-model=$repo/$m \
      --sense-voice-use-itn=$use_itn \
      --hr-lexicon=./lexicon.txt \
      --hr-dict-dir=./dict \
      --hr-rule-fsts=./replace.fst \
      ./test-hr.wav
  done
done
|
||||||
|
|
||||||
|
rm -rf dict replace.fst test-hr.wav lexicon.txt
|
||||||
|
|
||||||
# test wav reader for non-standard wav files
|
# test wav reader for non-standard wav files
|
||||||
waves=(
|
waves=(
|
||||||
|
|||||||
12
.github/scripts/test-python.sh
vendored
12
.github/scripts/test-python.sh
vendored
@@ -95,6 +95,18 @@ rm $name
|
|||||||
ls -lh $repo
|
ls -lh $repo
|
||||||
python3 ./python-api-examples/offline-sense-voice-ctc-decode-files.py
|
python3 ./python-api-examples/offline-sense-voice-ctc-decode-files.py
|
||||||
|
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/dict.tar.bz2
|
||||||
|
tar xf dict.tar.bz2
|
||||||
|
rm dict.tar.bz2
|
||||||
|
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/replace.fst
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/test-hr.wav
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/lexicon.txt
|
||||||
|
|
||||||
|
python3 ./python-api-examples/offline-sense-voice-ctc-decode-files-with-hr.py
|
||||||
|
|
||||||
|
rm -rf dict replace.fst test-hr.wav lexicon.txt
|
||||||
|
|
||||||
if [[ $(uname) == Linux ]]; then
|
if [[ $(uname) == Linux ]]; then
|
||||||
# It needs ffmpeg
|
# It needs ffmpeg
|
||||||
log "generate subtitles (Chinese)"
|
log "generate subtitles (Chinese)"
|
||||||
|
|||||||
@@ -1,18 +1,18 @@
|
|||||||
function(download_kaldifst)
|
function(download_kaldifst)
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
|
|
||||||
set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.11.tar.gz")
|
set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.13.tar.gz")
|
||||||
set(kaldifst_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldifst-1.7.11.tar.gz")
|
set(kaldifst_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldifst-1.7.13.tar.gz")
|
||||||
set(kaldifst_HASH "SHA256=b43b3332faa2961edc730e47995a58cd4e22ead21905d55b0c4a41375b4a525f")
|
set(kaldifst_HASH "SHA256=f8dc15fdaf314d7c9c3551ad8c11ed15da0f34de36446798bbd1b90fa7946eb2")
|
||||||
|
|
||||||
# If you don't have access to the Internet,
|
# If you don't have access to the Internet,
|
||||||
# please pre-download kaldifst
|
# please pre-download kaldifst
|
||||||
set(possible_file_locations
|
set(possible_file_locations
|
||||||
$ENV{HOME}/Downloads/kaldifst-1.7.11.tar.gz
|
$ENV{HOME}/Downloads/kaldifst-1.7.13.tar.gz
|
||||||
${CMAKE_SOURCE_DIR}/kaldifst-1.7.11.tar.gz
|
${CMAKE_SOURCE_DIR}/kaldifst-1.7.13.tar.gz
|
||||||
${CMAKE_BINARY_DIR}/kaldifst-1.7.11.tar.gz
|
${CMAKE_BINARY_DIR}/kaldifst-1.7.13.tar.gz
|
||||||
/tmp/kaldifst-1.7.11.tar.gz
|
/tmp/kaldifst-1.7.13.tar.gz
|
||||||
/star-fj/fangjun/download/github/kaldifst-1.7.11.tar.gz
|
/star-fj/fangjun/download/github/kaldifst-1.7.13.tar.gz
|
||||||
)
|
)
|
||||||
|
|
||||||
foreach(f IN LISTS possible_file_locations)
|
foreach(f IN LISTS possible_file_locations)
|
||||||
|
|||||||
75
python-api-examples/offline-sense-voice-ctc-decode-files-with-hr.py
Executable file
75
python-api-examples/offline-sense-voice-ctc-decode-files-with-hr.py
Executable file
@@ -0,0 +1,75 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file shows how to use a non-streaming SenseVoice CTC model from
|
||||||
|
https://github.com/FunAudioLLM/SenseVoice
|
||||||
|
to decode files.
|
||||||
|
|
||||||
|
Please download model files from
|
||||||
|
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||||
|
|
||||||
|
For instance,
|
||||||
|
|
||||||
|
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
|
||||||
|
tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
|
||||||
|
rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
|
||||||
|
|
||||||
|
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/dict.tar.bz2
|
||||||
|
tar xf dict.tar.bz2
|
||||||
|
|
||||||
|
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/replace.fst
|
||||||
|
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/test-hr.wav
|
||||||
|
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/lexicon.txt
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import sherpa_onnx
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
|
||||||
|
def create_recognizer():
    """Create a SenseVoice recognizer with homophone replacement enabled.

    Returns:
      A tuple ``(recognizer, test_wav)`` where ``recognizer`` is built by
      ``sherpa_onnx.OfflineRecognizer.from_sense_voice`` and ``test_wav`` is
      the path of the wave file to decode.

    Raises:
      ValueError: if any of the required model or homophone-replacer files
        is missing. The message lists the missing paths.
    """
    model = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.onnx"
    tokens = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt"
    test_wav = "./test-hr.wav"
    hr_lexicon = "./lexicon.txt"
    hr_dict_dir = "./dict"
    hr_rule_fsts = "./replace.fst"

    # Check every input up front so a partial download is reported here,
    # with the offending paths, instead of failing later during construction.
    missing = [
        f
        for f in (model, tokens, test_wav, hr_lexicon, hr_rule_fsts)
        if not Path(f).is_file()
    ]
    if not Path(hr_dict_dir).is_dir():
        missing.append(hr_dict_dir)

    if missing:
        raise ValueError(
            """Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
and
https://github.com/k2-fsa/sherpa-onnx/releases/tag/hr-files
"""
            + "Missing: "
            + ", ".join(missing)
        )

    return (
        sherpa_onnx.OfflineRecognizer.from_sense_voice(
            model=model,
            tokens=tokens,
            use_itn=True,
            debug=True,
            hr_lexicon=hr_lexicon,
            hr_dict_dir=hr_dict_dir,
            hr_rule_fsts=hr_rule_fsts,
        ),
        test_wav,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Decode the test wave with the recognizer and print the result."""
    recognizer, wave_filename = create_recognizer()

    # always_2d=True yields a (num_samples, num_channels) array; keep only
    # the first channel so a 1-D float32 array is passed to the recognizer.
    samples, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    first_channel = samples[:, 0]

    # Per the original example notes: samples are normalized to [-1, 1] and
    # sample_rate does not need to be 16000 Hz.
    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, first_channel)
    recognizer.decode_stream(stream)

    print(wave_filename)
    print(stream.result)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -20,7 +20,9 @@ set(sources
|
|||||||
features.cc
|
features.cc
|
||||||
file-utils.cc
|
file-utils.cc
|
||||||
fst-utils.cc
|
fst-utils.cc
|
||||||
|
homophone-replacer.cc
|
||||||
hypothesis.cc
|
hypothesis.cc
|
||||||
|
jieba.cc
|
||||||
keyword-spotter-impl.cc
|
keyword-spotter-impl.cc
|
||||||
keyword-spotter.cc
|
keyword-spotter.cc
|
||||||
offline-ctc-fst-decoder-config.cc
|
offline-ctc-fst-decoder-config.cc
|
||||||
|
|||||||
278
sherpa-onnx/csrc/homophone-replacer.cc
Normal file
278
sherpa-onnx/csrc/homophone-replacer.cc
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
// sherpa-onnx/csrc/homophone-replacer.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/homophone-replacer.h"
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <strstream>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if __OHOS__
|
||||||
|
#include "rawfile/raw_file_manager.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "kaldifst/csrc/text-normalizer.h"
|
||||||
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/jieba.h"
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Registers the three --hr-* command-line options with the given parser.
// Each option is bound directly to the corresponding member of this config.
void HomophoneReplacerConfig::Register(ParseOptions *po) {
  // Directory holding the jieba dictionary files.
  po->Register("hr-dict-dir", &dict_dir,
               "The dict directory for jieba used by HomophoneReplacer");

  // Word -> pronunciation lexicon.
  po->Register("hr-lexicon", &lexicon,
               "Path to lexicon.txt used by HomophoneReplacer.");

  // Replacement rules, given as a comma-separated list of FST files.
  po->Register(
      "hr-rule-fsts", &rule_fsts,
      "Fst files for HomophoneReplacer. If there are multiple, they "
      "are separated by a comma. E.g., a.fst,b.fst,c.fst");
}
|
||||||
|
|
||||||
|
bool HomophoneReplacerConfig::Validate() const {
|
||||||
|
if (!dict_dir.empty()) {
|
||||||
|
std::vector<std::string> required_files = {
|
||||||
|
"jieba.dict.utf8", "hmm_model.utf8", "user.dict.utf8",
|
||||||
|
"idf.utf8", "stop_words.utf8",
|
||||||
|
};
|
||||||
|
|
||||||
|
for (const auto &f : required_files) {
|
||||||
|
if (!FileExists(dict_dir + "/" + f)) {
|
||||||
|
SHERPA_ONNX_LOGE("'%s/%s' does not exist. Please check kokoro-dict-dir",
|
||||||
|
dict_dir.c_str(), f.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!lexicon.empty() && !FileExists(lexicon)) {
|
||||||
|
SHERPA_ONNX_LOGE("--hr-lexicon: '%s' does not exist", lexicon.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!rule_fsts.empty()) {
|
||||||
|
std::vector<std::string> files;
|
||||||
|
SplitStringToVector(rule_fsts, ",", false, &files);
|
||||||
|
|
||||||
|
if (files.size() > 1) {
|
||||||
|
SHERPA_ONNX_LOGE("Only 1 file is supported now.");
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto &f : files) {
|
||||||
|
if (!FileExists(f)) {
|
||||||
|
SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a human-readable description of the config listing dict_dir,
// lexicon and rule_fsts (the debug flag is not included).
std::string HomophoneReplacerConfig::ToString() const {
  std::ostringstream out;

  out << "HomophoneReplacerConfig("
      << "dict_dir=\"" << dict_dir << "\", "
      << "lexicon=\"" << lexicon << "\", "
      << "rule_fsts=\"" << rule_fsts << "\")";

  return out.str();
}
|
||||||
|
|
||||||
|
// Implementation of HomophoneReplacer.
//
// It segments the input text into words with jieba, maps every word to a
// pronunciation string using the lexicon, and passes both sequences to the
// rule FST (kaldifst::TextNormalizer) to obtain the output text.
class HomophoneReplacer::Impl {
 public:
  // Loads the lexicon and the rule FSTs from the local filesystem.
  explicit Impl(const HomophoneReplacerConfig &config) : config_(config) {
    jieba_ = InitJieba(config.dict_dir);

    {
      std::ifstream is(config.lexicon);
      InitLexicon(is);
    }

    if (!config.rule_fsts.empty()) {
      std::vector<std::string> files;
      SplitStringToVector(config.rule_fsts, ",", false, &files);
      replacer_list_.reserve(files.size());
      for (const auto &f : files) {
        if (config.debug) {
          SHERPA_ONNX_LOGE("hr rule fst: %s", f.c_str());
        }
        replacer_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(f));
      }
    }
  }

  // Same as above, but the lexicon and the rule FSTs are read through the
  // platform resource manager `mgr` via ReadFile(). Note that the jieba
  // dictionary is still loaded from config.dict_dir by InitJieba().
  template <typename Manager>
  Impl(Manager *mgr, const HomophoneReplacerConfig &config) : config_(config) {
    jieba_ = InitJieba(config.dict_dir);
    {
      auto buf = ReadFile(mgr, config.lexicon);

      std::istrstream is(buf.data(), buf.size());
      InitLexicon(is);
    }

    if (!config.rule_fsts.empty()) {
      std::vector<std::string> files;
      SplitStringToVector(config.rule_fsts, ",", false, &files);
      replacer_list_.reserve(files.size());
      for (const auto &f : files) {
        if (config.debug) {
          SHERPA_ONNX_LOGE("hr rule fst: %s", f.c_str());
        }
        auto buf = ReadFile(mgr, f);
        std::istrstream is(buf.data(), buf.size());
        replacer_list_.push_back(
            std::make_unique<kaldifst::TextNormalizer>(is));
      }
    }
  }

  // Applies homophone replacement to `text` and returns the result.
  //
  // If no jieba instance or no rule FST is available there is nothing to
  // apply, so the input is returned unchanged (rather than dereferencing a
  // null pointer or returning an empty string and losing the text).
  std::string Apply(const std::string &text) const {
    if (!jieba_ || replacer_list_.empty()) {
      return text;
    }

    bool is_hmm = true;

    std::vector<std::string> words;
    jieba_->Cut(text, words, is_hmm);
    if (config_.debug) {
      SHERPA_ONNX_LOGE("Input text: '%s'", text.c_str());
      std::ostringstream os;
      os << "After jieba: ";
      std::string sep;
      for (const auto &w : words) {
        os << sep << w;
        sep = "_";
      }
      SHERPA_ONNX_LOGE("%s", os.str().c_str());
    }

    // convert words to pronunciations
    std::vector<std::string> pronunciations;
    pronunciations.reserve(words.size());

    for (const auto &w : words) {
      auto p = ConvertWordToPronunciation(w);
      if (config_.debug) {
        SHERPA_ONNX_LOGE("%s %s", w.c_str(), p.c_str());
      }
      pronunciations.push_back(std::move(p));
    }

    // TODO(fangjun): We support only 1 rule fst at present.
    return replacer_list_.front()->Normalize(words, pronunciations);
  }

 private:
  // Returns the pronunciation string of `word`.
  //
  // Lookup order:
  //  1. the whole word in the lexicon;
  //  2. otherwise, for words longer than 3 bytes, each UTF-8 character is
  //     looked up separately and the results are concatenated; characters
  //     missing from the lexicon are copied through unchanged.
  std::string ConvertWordToPronunciation(const std::string &word) const {
    auto it = word2pron_.find(word);
    if (it != word2pron_.end()) {
      return it->second;
    }

    if (word.size() <= 3) {
      // At most 3 bytes and not in the lexicon: return it as is.
      // NOTE(review): presumably 3 is chosen because a CJK character takes
      // 3 bytes in UTF-8, i.e. this is at most one such character, so
      // splitting it further cannot help -- confirm.
      return word;
    }

    std::vector<std::string> chars = SplitUtf8(word);
    std::string ans;
    for (const auto &c : chars) {
      auto char_it = word2pron_.find(c);
      ans.append(char_it != word2pron_.end() ? char_it->second : c);
    }

    return ans;
  }

  // Reads the lexicon from `is` into word2pron_.
  //
  // Each line is "word token1 token2 ...". The word is lower-cased and the
  // tokens are concatenated (without separators) to form the pronunciation.
  // Blank lines are skipped silently; duplicated words and words without a
  // pronunciation are logged and ignored.
  void InitLexicon(std::istream &is) {
    std::string word;
    std::string pron;
    std::string p;

    std::string line;
    int32_t line_num = 0;
    int32_t num_warn = 0;
    while (std::getline(is, line)) {
      ++line_num;
      std::istringstream iss(line);

      word.clear();
      pron.clear();
      if (!(iss >> word)) {
        // Blank or whitespace-only line. Without this check `word` would
        // keep its moved-from value from the previous iteration and the
        // line would be reported as an "empty pronunciation" error.
        continue;
      }
      ToLowerCase(&word);

      if (word2pron_.count(word)) {
        num_warn += 1;
        // Limit the number of warnings so a bad lexicon does not flood
        // the log.
        if (num_warn < 10) {
          SHERPA_ONNX_LOGE("Duplicated word: %s at line %d:%s. Ignore it.",
                           word.c_str(), line_num, line.c_str());
        }
        continue;
      }

      while (iss >> p) {
        pron.append(p);
      }

      if (pron.empty()) {
        SHERPA_ONNX_LOGE(
            "Empty pronunciation for word '%s' at line %d:%s. Ignore it.",
            word.c_str(), line_num, line.c_str());
        continue;
      }

      word2pron_.insert({std::move(word), std::move(pron)});
    }
  }

 private:
  HomophoneReplacerConfig config_;
  // Word segmenter; null when config.dict_dir is empty (see InitJieba).
  std::unique_ptr<cppjieba::Jieba> jieba_;
  // One normalizer per rule FST; only the first one is used by Apply().
  std::vector<std::unique_ptr<kaldifst::TextNormalizer>> replacer_list_;
  // Lower-cased word -> concatenated pronunciation tokens.
  std::unordered_map<std::string, std::string> word2pron_;
};
|
||||||
|
|
||||||
|
HomophoneReplacer::HomophoneReplacer(const HomophoneReplacerConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
|
||||||
|
template <typename Manager>
|
||||||
|
HomophoneReplacer::HomophoneReplacer(Manager *mgr,
|
||||||
|
const HomophoneReplacerConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||||
|
|
||||||
|
HomophoneReplacer::~HomophoneReplacer() = default;
|
||||||
|
|
||||||
|
std::string HomophoneReplacer::Apply(const std::string &text) const {
|
||||||
|
return impl_->Apply(text);
|
||||||
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
template HomophoneReplacer::HomophoneReplacer(
|
||||||
|
AAssetManager *mgr, const HomophoneReplacerConfig &config);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if __OHOS__
|
||||||
|
template HomophoneReplacer::HomophoneReplacer(
|
||||||
|
NativeResourceManager *mgr, const HomophoneReplacerConfig &config);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
58
sherpa-onnx/csrc/homophone-replacer.h
Normal file
58
sherpa-onnx/csrc/homophone-replacer.h
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
// sherpa-onnx/csrc/homophone-replacer.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_HOMOPHONE_REPLACER_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_HOMOPHONE_REPLACER_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct HomophoneReplacerConfig {
|
||||||
|
std::string dict_dir;
|
||||||
|
std::string lexicon;
|
||||||
|
|
||||||
|
// comma separated fst files, e.g. a.fst,b.fst,c.fst
|
||||||
|
std::string rule_fsts;
|
||||||
|
|
||||||
|
bool debug;
|
||||||
|
|
||||||
|
HomophoneReplacerConfig() = default;
|
||||||
|
|
||||||
|
HomophoneReplacerConfig(const std::string &dict_dir,
|
||||||
|
const std::string &lexicon,
|
||||||
|
const std::string &rule_fsts, bool debug)
|
||||||
|
: dict_dir(dict_dir),
|
||||||
|
lexicon(lexicon),
|
||||||
|
rule_fsts(rule_fsts),
|
||||||
|
debug(debug) {}
|
||||||
|
|
||||||
|
void Register(ParseOptions *po);
|
||||||
|
bool Validate() const;
|
||||||
|
|
||||||
|
std::string ToString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class HomophoneReplacer {
|
||||||
|
public:
|
||||||
|
explicit HomophoneReplacer(const HomophoneReplacerConfig &config);
|
||||||
|
|
||||||
|
template <typename Manager>
|
||||||
|
HomophoneReplacer(Manager *mgr, const HomophoneReplacerConfig &config);
|
||||||
|
|
||||||
|
~HomophoneReplacer();
|
||||||
|
|
||||||
|
std::string Apply(const std::string &text) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Impl;
|
||||||
|
std::unique_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_HOMOPHONE_REPLACER_H_
|
||||||
@@ -19,8 +19,8 @@
|
|||||||
#include "rawfile/raw_file_manager.h"
|
#include "rawfile/raw_file_manager.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "cppjieba/Jieba.hpp"
|
|
||||||
#include "sherpa-onnx/csrc/file-utils.h"
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/jieba.h"
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
@@ -41,20 +41,7 @@ class JiebaLexicon::Impl {
|
|||||||
Impl(const std::string &lexicon, const std::string &tokens,
|
Impl(const std::string &lexicon, const std::string &tokens,
|
||||||
const std::string &dict_dir, bool debug)
|
const std::string &dict_dir, bool debug)
|
||||||
: debug_(debug) {
|
: debug_(debug) {
|
||||||
std::string dict = dict_dir + "/jieba.dict.utf8";
|
jieba_ = InitJieba(dict_dir);
|
||||||
std::string hmm = dict_dir + "/hmm_model.utf8";
|
|
||||||
std::string user_dict = dict_dir + "/user.dict.utf8";
|
|
||||||
std::string idf = dict_dir + "/idf.utf8";
|
|
||||||
std::string stop_word = dict_dir + "/stop_words.utf8";
|
|
||||||
|
|
||||||
AssertFileExists(dict);
|
|
||||||
AssertFileExists(hmm);
|
|
||||||
AssertFileExists(user_dict);
|
|
||||||
AssertFileExists(idf);
|
|
||||||
AssertFileExists(stop_word);
|
|
||||||
|
|
||||||
jieba_ =
|
|
||||||
std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
|
|
||||||
|
|
||||||
{
|
{
|
||||||
std::ifstream is(tokens);
|
std::ifstream is(tokens);
|
||||||
@@ -71,20 +58,7 @@ class JiebaLexicon::Impl {
|
|||||||
Impl(Manager *mgr, const std::string &lexicon, const std::string &tokens,
|
Impl(Manager *mgr, const std::string &lexicon, const std::string &tokens,
|
||||||
const std::string &dict_dir, bool debug)
|
const std::string &dict_dir, bool debug)
|
||||||
: debug_(debug) {
|
: debug_(debug) {
|
||||||
std::string dict = dict_dir + "/jieba.dict.utf8";
|
jieba_ = InitJieba(dict_dir);
|
||||||
std::string hmm = dict_dir + "/hmm_model.utf8";
|
|
||||||
std::string user_dict = dict_dir + "/user.dict.utf8";
|
|
||||||
std::string idf = dict_dir + "/idf.utf8";
|
|
||||||
std::string stop_word = dict_dir + "/stop_words.utf8";
|
|
||||||
|
|
||||||
AssertFileExists(dict);
|
|
||||||
AssertFileExists(hmm);
|
|
||||||
AssertFileExists(user_dict);
|
|
||||||
AssertFileExists(idf);
|
|
||||||
AssertFileExists(stop_word);
|
|
||||||
|
|
||||||
jieba_ =
|
|
||||||
std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
|
|
||||||
|
|
||||||
{
|
{
|
||||||
auto buf = ReadFile(mgr, tokens);
|
auto buf = ReadFile(mgr, tokens);
|
||||||
|
|||||||
32
sherpa-onnx/csrc/jieba.cc
Normal file
32
sherpa-onnx/csrc/jieba.cc
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
// sherpa-onnx/csrc/jieba.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/jieba.h"
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
std::unique_ptr<cppjieba::Jieba> InitJieba(const std::string &dict_dir) {
|
||||||
|
if (dict_dir.empty()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string dict = dict_dir + "/jieba.dict.utf8";
|
||||||
|
std::string hmm = dict_dir + "/hmm_model.utf8";
|
||||||
|
std::string user_dict = dict_dir + "/user.dict.utf8";
|
||||||
|
std::string idf = dict_dir + "/idf.utf8";
|
||||||
|
std::string stop_word = dict_dir + "/stop_words.utf8";
|
||||||
|
|
||||||
|
AssertFileExists(dict);
|
||||||
|
AssertFileExists(hmm);
|
||||||
|
AssertFileExists(user_dict);
|
||||||
|
AssertFileExists(idf);
|
||||||
|
AssertFileExists(stop_word);
|
||||||
|
|
||||||
|
return std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf,
|
||||||
|
stop_word);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
18
sherpa-onnx/csrc/jieba.h
Normal file
18
sherpa-onnx/csrc/jieba.h
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
// sherpa-onnx/csrc/jieba.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_JIEBA_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_JIEBA_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "cppjieba/Jieba.hpp"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
std::unique_ptr<cppjieba::Jieba> InitJieba(const std::string &dict_dir);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_JIEBA_H_
|
||||||
@@ -22,11 +22,11 @@
|
|||||||
|
|
||||||
#include <codecvt>
|
#include <codecvt>
|
||||||
|
|
||||||
#include "cppjieba/Jieba.hpp"
|
|
||||||
#include "espeak-ng/speak_lib.h"
|
#include "espeak-ng/speak_lib.h"
|
||||||
#include "phoneme_ids.hpp"
|
#include "phoneme_ids.hpp"
|
||||||
#include "phonemize.hpp"
|
#include "phonemize.hpp"
|
||||||
#include "sherpa-onnx/csrc/file-utils.h"
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/jieba.h"
|
||||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
#include "sherpa-onnx/csrc/text-utils.h"
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
@@ -47,7 +47,7 @@ class KokoroMultiLangLexicon::Impl {
|
|||||||
|
|
||||||
InitLexicon(lexicon);
|
InitLexicon(lexicon);
|
||||||
|
|
||||||
InitJieba(dict_dir);
|
jieba_ = InitJieba(dict_dir);
|
||||||
|
|
||||||
InitEspeak(data_dir); // See ./piper-phonemize-lexicon.cc
|
InitEspeak(data_dir); // See ./piper-phonemize-lexicon.cc
|
||||||
}
|
}
|
||||||
@@ -62,7 +62,7 @@ class KokoroMultiLangLexicon::Impl {
|
|||||||
InitLexicon(mgr, lexicon);
|
InitLexicon(mgr, lexicon);
|
||||||
|
|
||||||
// we assume you have copied dict_dir and data_dir from assets to some path
|
// we assume you have copied dict_dir and data_dir from assets to some path
|
||||||
InitJieba(dict_dir);
|
jieba_ = InitJieba(dict_dir);
|
||||||
|
|
||||||
InitEspeak(data_dir); // See ./piper-phonemize-lexicon.cc
|
InitEspeak(data_dir); // See ./piper-phonemize-lexicon.cc
|
||||||
}
|
}
|
||||||
@@ -456,23 +456,6 @@ class KokoroMultiLangLexicon::Impl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void InitJieba(const std::string &dict_dir) {
|
|
||||||
std::string dict = dict_dir + "/jieba.dict.utf8";
|
|
||||||
std::string hmm = dict_dir + "/hmm_model.utf8";
|
|
||||||
std::string user_dict = dict_dir + "/user.dict.utf8";
|
|
||||||
std::string idf = dict_dir + "/idf.utf8";
|
|
||||||
std::string stop_word = dict_dir + "/stop_words.utf8";
|
|
||||||
|
|
||||||
AssertFileExists(dict);
|
|
||||||
AssertFileExists(hmm);
|
|
||||||
AssertFileExists(user_dict);
|
|
||||||
AssertFileExists(idf);
|
|
||||||
AssertFileExists(stop_word);
|
|
||||||
|
|
||||||
jieba_ =
|
|
||||||
std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
OfflineTtsKokoroModelMetaData meta_data_;
|
OfflineTtsKokoroModelMetaData meta_data_;
|
||||||
|
|
||||||
|
|||||||
@@ -19,8 +19,8 @@
|
|||||||
#include "rawfile/raw_file_manager.h"
|
#include "rawfile/raw_file_manager.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "cppjieba/Jieba.hpp"
|
|
||||||
#include "sherpa-onnx/csrc/file-utils.h"
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/jieba.h"
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
@@ -34,20 +34,7 @@ class MeloTtsLexicon::Impl {
|
|||||||
const std::string &dict_dir,
|
const std::string &dict_dir,
|
||||||
const OfflineTtsVitsModelMetaData &meta_data, bool debug)
|
const OfflineTtsVitsModelMetaData &meta_data, bool debug)
|
||||||
: meta_data_(meta_data), debug_(debug) {
|
: meta_data_(meta_data), debug_(debug) {
|
||||||
std::string dict = dict_dir + "/jieba.dict.utf8";
|
jieba_ = InitJieba(dict_dir);
|
||||||
std::string hmm = dict_dir + "/hmm_model.utf8";
|
|
||||||
std::string user_dict = dict_dir + "/user.dict.utf8";
|
|
||||||
std::string idf = dict_dir + "/idf.utf8";
|
|
||||||
std::string stop_word = dict_dir + "/stop_words.utf8";
|
|
||||||
|
|
||||||
AssertFileExists(dict);
|
|
||||||
AssertFileExists(hmm);
|
|
||||||
AssertFileExists(user_dict);
|
|
||||||
AssertFileExists(idf);
|
|
||||||
AssertFileExists(stop_word);
|
|
||||||
|
|
||||||
jieba_ =
|
|
||||||
std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
|
|
||||||
|
|
||||||
{
|
{
|
||||||
std::ifstream is(tokens);
|
std::ifstream is(tokens);
|
||||||
@@ -79,20 +66,7 @@ class MeloTtsLexicon::Impl {
|
|||||||
const std::string &dict_dir,
|
const std::string &dict_dir,
|
||||||
const OfflineTtsVitsModelMetaData &meta_data, bool debug)
|
const OfflineTtsVitsModelMetaData &meta_data, bool debug)
|
||||||
: meta_data_(meta_data), debug_(debug) {
|
: meta_data_(meta_data), debug_(debug) {
|
||||||
std::string dict = dict_dir + "/jieba.dict.utf8";
|
jieba_ = InitJieba(dict_dir);
|
||||||
std::string hmm = dict_dir + "/hmm_model.utf8";
|
|
||||||
std::string user_dict = dict_dir + "/user.dict.utf8";
|
|
||||||
std::string idf = dict_dir + "/idf.utf8";
|
|
||||||
std::string stop_word = dict_dir + "/stop_words.utf8";
|
|
||||||
|
|
||||||
AssertFileExists(dict);
|
|
||||||
AssertFileExists(hmm);
|
|
||||||
AssertFileExists(user_dict);
|
|
||||||
AssertFileExists(idf);
|
|
||||||
AssertFileExists(stop_word);
|
|
||||||
|
|
||||||
jieba_ =
|
|
||||||
std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
|
|
||||||
|
|
||||||
{
|
{
|
||||||
auto buf = ReadFile(mgr, tokens);
|
auto buf = ReadFile(mgr, tokens);
|
||||||
|
|||||||
@@ -239,6 +239,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
|||||||
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
|
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
|
||||||
model_->SubsamplingFactor());
|
model_->SubsamplingFactor());
|
||||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||||
|
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||||
ss[i]->SetResult(r);
|
ss[i]->SetResult(r);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -277,6 +278,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
|||||||
auto r = Convert(results[0], symbol_table_, frame_shift_ms,
|
auto r = Convert(results[0], symbol_table_, frame_shift_ms,
|
||||||
model_->SubsamplingFactor());
|
model_->SubsamplingFactor());
|
||||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||||
|
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||||
s->SetResult(r);
|
s->SetResult(r);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -125,6 +125,7 @@ class OfflineRecognizerFireRedAsrImpl : public OfflineRecognizerImpl {
|
|||||||
auto r = Convert(results[0], symbol_table_);
|
auto r = Convert(results[0], symbol_table_);
|
||||||
|
|
||||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||||
|
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||||
s->SetResult(r);
|
s->SetResult(r);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -408,6 +408,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
OfflineRecognizerImpl::OfflineRecognizerImpl(
|
OfflineRecognizerImpl::OfflineRecognizerImpl(
|
||||||
const OfflineRecognizerConfig &config)
|
const OfflineRecognizerConfig &config)
|
||||||
: config_(config) {
|
: config_(config) {
|
||||||
|
// TODO(fangjun): Refactor this function
|
||||||
|
|
||||||
if (!config.rule_fsts.empty()) {
|
if (!config.rule_fsts.empty()) {
|
||||||
std::vector<std::string> files;
|
std::vector<std::string> files;
|
||||||
SplitStringToVector(config.rule_fsts, ",", false, &files);
|
SplitStringToVector(config.rule_fsts, ",", false, &files);
|
||||||
@@ -448,6 +450,13 @@ OfflineRecognizerImpl::OfflineRecognizerImpl(
|
|||||||
SHERPA_ONNX_LOGE("FST archives loaded!");
|
SHERPA_ONNX_LOGE("FST archives loaded!");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() &&
|
||||||
|
!config.hr.rule_fsts.empty()) {
|
||||||
|
auto hr_config = config.hr;
|
||||||
|
hr_config.debug = config.model_config.debug;
|
||||||
|
hr_ = std::make_unique<HomophoneReplacer>(hr_config);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Manager>
|
template <typename Manager>
|
||||||
@@ -495,6 +504,13 @@ OfflineRecognizerImpl::OfflineRecognizerImpl(
|
|||||||
} // for (; !reader->Done(); reader->Next())
|
} // for (; !reader->Done(); reader->Next())
|
||||||
} // for (const auto &f : files)
|
} // for (const auto &f : files)
|
||||||
} // if (!config.rule_fars.empty())
|
} // if (!config.rule_fars.empty())
|
||||||
|
|
||||||
|
if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() &&
|
||||||
|
!config.hr.rule_fsts.empty()) {
|
||||||
|
auto hr_config = config.hr;
|
||||||
|
hr_config.debug = config.model_config.debug;
|
||||||
|
hr_ = std::make_unique<HomophoneReplacer>(mgr, hr_config);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
|
std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
|
||||||
@@ -510,6 +526,15 @@ std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
|
|||||||
return text;
|
return text;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string OfflineRecognizerImpl::ApplyHomophoneReplacer(
|
||||||
|
std::string text) const {
|
||||||
|
if (hr_) {
|
||||||
|
text = hr_->Apply(text);
|
||||||
|
}
|
||||||
|
|
||||||
|
return text;
|
||||||
|
}
|
||||||
|
|
||||||
void OfflineRecognizerImpl::SetConfig(const OfflineRecognizerConfig &config) {
|
void OfflineRecognizerImpl::SetConfig(const OfflineRecognizerConfig &config) {
|
||||||
config_ = config;
|
config_ = config;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "kaldifst/csrc/text-normalizer.h"
|
#include "kaldifst/csrc/text-normalizer.h"
|
||||||
|
#include "sherpa-onnx/csrc/homophone-replacer.h"
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||||
@@ -48,12 +49,15 @@ class OfflineRecognizerImpl {
|
|||||||
|
|
||||||
std::string ApplyInverseTextNormalization(std::string text) const;
|
std::string ApplyInverseTextNormalization(std::string text) const;
|
||||||
|
|
||||||
|
std::string ApplyHomophoneReplacer(std::string text) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
OfflineRecognizerConfig config_;
|
OfflineRecognizerConfig config_;
|
||||||
// for inverse text normalization. Used only if
|
// for inverse text normalization. Used only if
|
||||||
// config.rule_fsts is not empty or
|
// config.rule_fsts is not empty or
|
||||||
// config.rule_fars is not empty
|
// config.rule_fars is not empty
|
||||||
std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_;
|
std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_;
|
||||||
|
std::unique_ptr<HomophoneReplacer> hr_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -121,6 +121,7 @@ class OfflineRecognizerMoonshineImpl : public OfflineRecognizerImpl {
|
|||||||
|
|
||||||
auto r = Convert(results[0], symbol_table_);
|
auto r = Convert(results[0], symbol_table_);
|
||||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||||
|
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||||
s->SetResult(r);
|
s->SetResult(r);
|
||||||
} catch (const Ort::Exception &ex) {
|
} catch (const Ort::Exception &ex) {
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
|
|||||||
@@ -197,6 +197,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
|
|||||||
for (int32_t i = 0; i != n; ++i) {
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
auto r = Convert(results[i], symbol_table_);
|
auto r = Convert(results[i], symbol_table_);
|
||||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||||
|
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||||
ss[i]->SetResult(r);
|
ss[i]->SetResult(r);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -222,6 +222,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl {
|
|||||||
auto r = ConvertSenseVoiceResult(results[i], symbol_table_,
|
auto r = ConvertSenseVoiceResult(results[i], symbol_table_,
|
||||||
frame_shift_ms, subsampling_factor);
|
frame_shift_ms, subsampling_factor);
|
||||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||||
|
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||||
ss[i]->SetResult(r);
|
ss[i]->SetResult(r);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -295,6 +296,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl {
|
|||||||
subsampling_factor);
|
subsampling_factor);
|
||||||
|
|
||||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||||
|
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||||
s->SetResult(r);
|
s->SetResult(r);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -239,6 +239,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
|||||||
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
|
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
|
||||||
model_->SubsamplingFactor());
|
model_->SubsamplingFactor());
|
||||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||||
|
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||||
|
|
||||||
ss[i]->SetResult(r);
|
ss[i]->SetResult(r);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -128,6 +128,7 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
|
|||||||
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
|
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
|
||||||
model_->SubsamplingFactor());
|
model_->SubsamplingFactor());
|
||||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||||
|
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||||
|
|
||||||
ss[i]->SetResult(r);
|
ss[i]->SetResult(r);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -160,6 +160,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
|||||||
|
|
||||||
std::string s = sym_table[i];
|
std::string s = sym_table[i];
|
||||||
s = ApplyInverseTextNormalization(s);
|
s = ApplyInverseTextNormalization(s);
|
||||||
|
s = ApplyHomophoneReplacer(std::move(s));
|
||||||
|
|
||||||
text += s;
|
text += s;
|
||||||
r.tokens.push_back(s);
|
r.tokens.push_back(s);
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
|
|||||||
model_config.Register(po);
|
model_config.Register(po);
|
||||||
lm_config.Register(po);
|
lm_config.Register(po);
|
||||||
ctc_fst_decoder_config.Register(po);
|
ctc_fst_decoder_config.Register(po);
|
||||||
|
hr.Register(po);
|
||||||
|
|
||||||
po->Register(
|
po->Register(
|
||||||
"decoding-method", &decoding_method,
|
"decoding-method", &decoding_method,
|
||||||
@@ -120,6 +121,11 @@ bool OfflineRecognizerConfig::Validate() const {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!hr.dict_dir.empty() && !hr.lexicon.empty() && !hr.rule_fsts.empty() &&
|
||||||
|
!hr.Validate()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
return model_config.Validate();
|
return model_config.Validate();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -137,7 +143,8 @@ std::string OfflineRecognizerConfig::ToString() const {
|
|||||||
os << "hotwords_score=" << hotwords_score << ", ";
|
os << "hotwords_score=" << hotwords_score << ", ";
|
||||||
os << "blank_penalty=" << blank_penalty << ", ";
|
os << "blank_penalty=" << blank_penalty << ", ";
|
||||||
os << "rule_fsts=\"" << rule_fsts << "\", ";
|
os << "rule_fsts=\"" << rule_fsts << "\", ";
|
||||||
os << "rule_fars=\"" << rule_fars << "\")";
|
os << "rule_fars=\"" << rule_fars << "\", ";
|
||||||
|
os << "hr=" << hr.ToString() << ")";
|
||||||
|
|
||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/features.h"
|
#include "sherpa-onnx/csrc/features.h"
|
||||||
|
#include "sherpa-onnx/csrc/homophone-replacer.h"
|
||||||
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
|
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
@@ -40,6 +41,7 @@ struct OfflineRecognizerConfig {
|
|||||||
|
|
||||||
// If there are multiple FST archives, they are applied from left to right.
|
// If there are multiple FST archives, they are applied from left to right.
|
||||||
std::string rule_fars;
|
std::string rule_fars;
|
||||||
|
HomophoneReplacerConfig hr;
|
||||||
|
|
||||||
// only greedy_search is implemented
|
// only greedy_search is implemented
|
||||||
// TODO(fangjun): Implement modified_beam_search
|
// TODO(fangjun): Implement modified_beam_search
|
||||||
@@ -52,7 +54,7 @@ struct OfflineRecognizerConfig {
|
|||||||
const std::string &decoding_method, int32_t max_active_paths,
|
const std::string &decoding_method, int32_t max_active_paths,
|
||||||
const std::string &hotwords_file, float hotwords_score,
|
const std::string &hotwords_file, float hotwords_score,
|
||||||
float blank_penalty, const std::string &rule_fsts,
|
float blank_penalty, const std::string &rule_fsts,
|
||||||
const std::string &rule_fars)
|
const std::string &rule_fars, const HomophoneReplacerConfig &hr)
|
||||||
: feat_config(feat_config),
|
: feat_config(feat_config),
|
||||||
model_config(model_config),
|
model_config(model_config),
|
||||||
lm_config(lm_config),
|
lm_config(lm_config),
|
||||||
@@ -63,7 +65,8 @@ struct OfflineRecognizerConfig {
|
|||||||
hotwords_score(hotwords_score),
|
hotwords_score(hotwords_score),
|
||||||
blank_penalty(blank_penalty),
|
blank_penalty(blank_penalty),
|
||||||
rule_fsts(rule_fsts),
|
rule_fsts(rule_fsts),
|
||||||
rule_fars(rule_fars) {}
|
rule_fars(rule_fars),
|
||||||
|
hr(hr) {}
|
||||||
|
|
||||||
void Register(ParseOptions *po);
|
void Register(ParseOptions *po);
|
||||||
bool Validate() const;
|
bool Validate() const;
|
||||||
|
|||||||
@@ -201,7 +201,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
|||||||
auto r =
|
auto r =
|
||||||
ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||||
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
||||||
r.text = ApplyInverseTextNormalization(r.text);
|
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||||
|
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -192,6 +192,13 @@ OnlineRecognizerImpl::OnlineRecognizerImpl(const OnlineRecognizerConfig &config)
|
|||||||
SHERPA_ONNX_LOGE("FST archives loaded!");
|
SHERPA_ONNX_LOGE("FST archives loaded!");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() &&
|
||||||
|
!config.hr.rule_fsts.empty()) {
|
||||||
|
auto hr_config = config.hr;
|
||||||
|
hr_config.debug = config.model_config.debug;
|
||||||
|
hr_ = std::make_unique<HomophoneReplacer>(hr_config);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Manager>
|
template <typename Manager>
|
||||||
@@ -239,6 +246,12 @@ OnlineRecognizerImpl::OnlineRecognizerImpl(Manager *mgr,
|
|||||||
} // for (; !reader->Done(); reader->Next())
|
} // for (; !reader->Done(); reader->Next())
|
||||||
} // for (const auto &f : files)
|
} // for (const auto &f : files)
|
||||||
} // if (!config.rule_fars.empty())
|
} // if (!config.rule_fars.empty())
|
||||||
|
if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() &&
|
||||||
|
!config.hr.rule_fsts.empty()) {
|
||||||
|
auto hr_config = config.hr;
|
||||||
|
hr_config.debug = config.model_config.debug;
|
||||||
|
hr_ = std::make_unique<HomophoneReplacer>(mgr, hr_config);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string OnlineRecognizerImpl::ApplyInverseTextNormalization(
|
std::string OnlineRecognizerImpl::ApplyInverseTextNormalization(
|
||||||
@@ -254,6 +267,15 @@ std::string OnlineRecognizerImpl::ApplyInverseTextNormalization(
|
|||||||
return text;
|
return text;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string OnlineRecognizerImpl::ApplyHomophoneReplacer(
|
||||||
|
std::string text) const {
|
||||||
|
if (hr_) {
|
||||||
|
text = hr_->Apply(text);
|
||||||
|
}
|
||||||
|
|
||||||
|
return text;
|
||||||
|
}
|
||||||
|
|
||||||
#if __ANDROID_API__ >= 9
|
#if __ANDROID_API__ >= 9
|
||||||
template OnlineRecognizerImpl::OnlineRecognizerImpl(
|
template OnlineRecognizerImpl::OnlineRecognizerImpl(
|
||||||
AAssetManager *mgr, const OnlineRecognizerConfig &config);
|
AAssetManager *mgr, const OnlineRecognizerConfig &config);
|
||||||
|
|||||||
@@ -10,6 +10,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "kaldifst/csrc/text-normalizer.h"
|
#include "kaldifst/csrc/text-normalizer.h"
|
||||||
|
#include "sherpa-onnx/csrc/homophone-replacer.h"
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
#include "sherpa-onnx/csrc/online-recognizer.h"
|
#include "sherpa-onnx/csrc/online-recognizer.h"
|
||||||
#include "sherpa-onnx/csrc/online-stream.h"
|
#include "sherpa-onnx/csrc/online-stream.h"
|
||||||
@@ -57,6 +58,7 @@ class OnlineRecognizerImpl {
|
|||||||
virtual void Reset(OnlineStream *s) const = 0;
|
virtual void Reset(OnlineStream *s) const = 0;
|
||||||
|
|
||||||
std::string ApplyInverseTextNormalization(std::string text) const;
|
std::string ApplyInverseTextNormalization(std::string text) const;
|
||||||
|
std::string ApplyHomophoneReplacer(std::string text) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
OnlineRecognizerConfig config_;
|
OnlineRecognizerConfig config_;
|
||||||
@@ -64,6 +66,7 @@ class OnlineRecognizerImpl {
|
|||||||
// config.rule_fsts is not empty or
|
// config.rule_fsts is not empty or
|
||||||
// config.rule_fars is not empty
|
// config.rule_fars is not empty
|
||||||
std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_;
|
std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_;
|
||||||
|
std::unique_ptr<HomophoneReplacer> hr_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -169,7 +169,8 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
|
|||||||
auto decoder_result = s->GetParaformerResult();
|
auto decoder_result = s->GetParaformerResult();
|
||||||
|
|
||||||
auto r = Convert(decoder_result, sym_);
|
auto r = Convert(decoder_result, sym_);
|
||||||
r.text = ApplyInverseTextNormalization(r.text);
|
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||||
|
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -349,6 +349,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||||
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
||||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||||
|
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -391,15 +392,14 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
// (the encoder state buffers are kept)
|
// (the encoder state buffers are kept)
|
||||||
for (const auto &it : last_result.hyps) {
|
for (const auto &it : last_result.hyps) {
|
||||||
auto h = it.second;
|
auto h = it.second;
|
||||||
r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size,
|
r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size, h.ys.end()),
|
||||||
h.ys.end()),
|
|
||||||
h.log_prob});
|
h.log_prob});
|
||||||
}
|
}
|
||||||
|
|
||||||
r.tokens = std::vector<int64_t> (last_result.tokens.end() - context_size,
|
r.tokens = std::vector<int64_t>(last_result.tokens.end() - context_size,
|
||||||
last_result.tokens.end());
|
last_result.tokens.end());
|
||||||
} else {
|
} else {
|
||||||
if(config_.reset_encoder) {
|
if (config_.reset_encoder) {
|
||||||
// reset encoder states, use blanks as 'ys' context
|
// reset encoder states, use blanks as 'ys' context
|
||||||
s->SetStates(model_->GetEncoderInitStates());
|
s->SetStates(model_->GetEncoderInitStates());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -100,6 +100,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
|
|||||||
subsampling_factor, s->GetCurrentSegment(),
|
subsampling_factor, s->GetCurrentSegment(),
|
||||||
s->GetNumFramesSinceStart());
|
s->GetNumFramesSinceStart());
|
||||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||||
|
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -88,6 +88,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
|||||||
endpoint_config.Register(po);
|
endpoint_config.Register(po);
|
||||||
lm_config.Register(po);
|
lm_config.Register(po);
|
||||||
ctc_fst_decoder_config.Register(po);
|
ctc_fst_decoder_config.Register(po);
|
||||||
|
hr.Register(po);
|
||||||
|
|
||||||
po->Register("enable-endpoint", &enable_endpoint,
|
po->Register("enable-endpoint", &enable_endpoint,
|
||||||
"True to enable endpoint detection. False to disable it.");
|
"True to enable endpoint detection. False to disable it.");
|
||||||
@@ -182,6 +183,11 @@ bool OnlineRecognizerConfig::Validate() const {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!hr.dict_dir.empty() && !hr.lexicon.empty() && !hr.rule_fsts.empty() &&
|
||||||
|
!hr.Validate()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
return model_config.Validate();
|
return model_config.Validate();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -203,7 +209,8 @@ std::string OnlineRecognizerConfig::ToString() const {
|
|||||||
os << "temperature_scale=" << temperature_scale << ", ";
|
os << "temperature_scale=" << temperature_scale << ", ";
|
||||||
os << "rule_fsts=\"" << rule_fsts << "\", ";
|
os << "rule_fsts=\"" << rule_fsts << "\", ";
|
||||||
os << "rule_fars=\"" << rule_fars << "\", ";
|
os << "rule_fars=\"" << rule_fars << "\", ";
|
||||||
os << "reset_encoder=\"" << (reset_encoder ? "True" : "False") << "\")";
|
os << "reset_encoder=" << (reset_encoder ? "True" : "False") << ", ";
|
||||||
|
os << "hr=" << hr.ToString() << ")";
|
||||||
|
|
||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
|
|
||||||
#include "sherpa-onnx/csrc/endpoint.h"
|
#include "sherpa-onnx/csrc/endpoint.h"
|
||||||
#include "sherpa-onnx/csrc/features.h"
|
#include "sherpa-onnx/csrc/features.h"
|
||||||
|
#include "sherpa-onnx/csrc/homophone-replacer.h"
|
||||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
|
#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
|
||||||
#include "sherpa-onnx/csrc/online-lm-config.h"
|
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||||
#include "sherpa-onnx/csrc/online-model-config.h"
|
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||||
@@ -107,6 +108,8 @@ struct OnlineRecognizerConfig {
|
|||||||
// currently only in `OnlineRecognizerTransducerImpl`.
|
// currently only in `OnlineRecognizerTransducerImpl`.
|
||||||
bool reset_encoder = false;
|
bool reset_encoder = false;
|
||||||
|
|
||||||
|
HomophoneReplacerConfig hr;
|
||||||
|
|
||||||
/// used only for modified_beam_search, if hotwords_buf is non-empty,
|
/// used only for modified_beam_search, if hotwords_buf is non-empty,
|
||||||
/// the hotwords will be loaded from the buffered string instead of from the
|
/// the hotwords will be loaded from the buffered string instead of from the
|
||||||
/// "hotwords_file"
|
/// "hotwords_file"
|
||||||
@@ -123,7 +126,7 @@ struct OnlineRecognizerConfig {
|
|||||||
int32_t max_active_paths, const std::string &hotwords_file,
|
int32_t max_active_paths, const std::string &hotwords_file,
|
||||||
float hotwords_score, float blank_penalty, float temperature_scale,
|
float hotwords_score, float blank_penalty, float temperature_scale,
|
||||||
const std::string &rule_fsts, const std::string &rule_fars,
|
const std::string &rule_fsts, const std::string &rule_fars,
|
||||||
bool reset_encoder)
|
bool reset_encoder, const HomophoneReplacerConfig &hr)
|
||||||
: feat_config(feat_config),
|
: feat_config(feat_config),
|
||||||
model_config(model_config),
|
model_config(model_config),
|
||||||
lm_config(lm_config),
|
lm_config(lm_config),
|
||||||
@@ -138,7 +141,8 @@ struct OnlineRecognizerConfig {
|
|||||||
temperature_scale(temperature_scale),
|
temperature_scale(temperature_scale),
|
||||||
rule_fsts(rule_fsts),
|
rule_fsts(rule_fsts),
|
||||||
rule_fars(rule_fars),
|
rule_fars(rule_fars),
|
||||||
reset_encoder(reset_encoder) {}
|
reset_encoder(reset_encoder),
|
||||||
|
hr(hr) {}
|
||||||
|
|
||||||
void Register(ParseOptions *po);
|
void Register(ParseOptions *po);
|
||||||
bool Validate() const;
|
bool Validate() const;
|
||||||
|
|||||||
@@ -89,7 +89,8 @@ class OnlineRecognizerCtcRknnImpl : public OnlineRecognizerImpl {
|
|||||||
auto r =
|
auto r =
|
||||||
ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||||
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
||||||
r.text = ApplyInverseTextNormalization(r.text);
|
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||||
|
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -177,6 +177,7 @@ class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl {
|
|||||||
auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||||
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
||||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||||
|
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ set(srcs
|
|||||||
display.cc
|
display.cc
|
||||||
endpoint.cc
|
endpoint.cc
|
||||||
features.cc
|
features.cc
|
||||||
|
homophone-replacer.cc
|
||||||
keyword-spotter.cc
|
keyword-spotter.cc
|
||||||
offline-ctc-fst-decoder-config.cc
|
offline-ctc-fst-decoder-config.cc
|
||||||
offline-dolphin-model-config.cc
|
offline-dolphin-model-config.cc
|
||||||
|
|||||||
28
sherpa-onnx/python/csrc/homophone-replacer.cc
Normal file
28
sherpa-onnx/python/csrc/homophone-replacer.cc
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
// sherpa-onnx/python/csrc/homophone-replacer.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/homophone-replacer.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/homophone-replacer.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindHomophoneReplacer(py::module *m) {
|
||||||
|
using PyClass = HomophoneReplacerConfig;
|
||||||
|
py::class_<PyClass>(*m, "HomophoneReplacerConfig")
|
||||||
|
.def(py::init<>())
|
||||||
|
.def(py::init<const std::string &, const std::string &,
|
||||||
|
const std::string &, bool>(),
|
||||||
|
py::arg("dict_dir"), py::arg("lexicon"), py::arg("rule_fsts"),
|
||||||
|
py::arg("debug") = false)
|
||||||
|
.def_readwrite("dict_dir", &PyClass::dict_dir)
|
||||||
|
.def_readwrite("lexicon", &PyClass::lexicon)
|
||||||
|
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
|
||||||
|
.def_readwrite("debug", &PyClass::debug)
|
||||||
|
.def("__str__", &PyClass::ToString);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
16
sherpa-onnx/python/csrc/homophone-replacer.h
Normal file
16
sherpa-onnx/python/csrc/homophone-replacer.h
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
// sherpa-onnx/python/csrc/homophone-replacer.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_
|
||||||
|
#define SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindHomophoneReplacer(py::module *m);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_
|
||||||
@@ -17,14 +17,16 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
|
|||||||
.def(py::init<const FeatureExtractorConfig &, const OfflineModelConfig &,
|
.def(py::init<const FeatureExtractorConfig &, const OfflineModelConfig &,
|
||||||
const OfflineLMConfig &, const OfflineCtcFstDecoderConfig &,
|
const OfflineLMConfig &, const OfflineCtcFstDecoderConfig &,
|
||||||
const std::string &, int32_t, const std::string &, float,
|
const std::string &, int32_t, const std::string &, float,
|
||||||
float, const std::string &, const std::string &>(),
|
float, const std::string &, const std::string &,
|
||||||
|
const HomophoneReplacerConfig &>(),
|
||||||
py::arg("feat_config"), py::arg("model_config"),
|
py::arg("feat_config"), py::arg("model_config"),
|
||||||
py::arg("lm_config") = OfflineLMConfig(),
|
py::arg("lm_config") = OfflineLMConfig(),
|
||||||
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
|
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
|
||||||
py::arg("decoding_method") = "greedy_search",
|
py::arg("decoding_method") = "greedy_search",
|
||||||
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
||||||
py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0,
|
py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0,
|
||||||
py::arg("rule_fsts") = "", py::arg("rule_fars") = "")
|
py::arg("rule_fsts") = "", py::arg("rule_fars") = "",
|
||||||
|
py::arg("hr") = HomophoneReplacerConfig{})
|
||||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||||
.def_readwrite("model_config", &PyClass::model_config)
|
.def_readwrite("model_config", &PyClass::model_config)
|
||||||
.def_readwrite("lm_config", &PyClass::lm_config)
|
.def_readwrite("lm_config", &PyClass::lm_config)
|
||||||
@@ -36,6 +38,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
|
|||||||
.def_readwrite("blank_penalty", &PyClass::blank_penalty)
|
.def_readwrite("blank_penalty", &PyClass::blank_penalty)
|
||||||
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
|
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
|
||||||
.def_readwrite("rule_fars", &PyClass::rule_fars)
|
.def_readwrite("rule_fars", &PyClass::rule_fars)
|
||||||
|
.def_readwrite("hr", &PyClass::hr)
|
||||||
.def("__str__", &PyClass::ToString);
|
.def("__str__", &PyClass::ToString);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -58,7 +58,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
|||||||
const OnlineLMConfig &, const EndpointConfig &,
|
const OnlineLMConfig &, const EndpointConfig &,
|
||||||
const OnlineCtcFstDecoderConfig &, bool,
|
const OnlineCtcFstDecoderConfig &, bool,
|
||||||
const std::string &, int32_t, const std::string &, float,
|
const std::string &, int32_t, const std::string &, float,
|
||||||
float, float, const std::string &, const std::string &, bool>(),
|
float, float, const std::string &, const std::string &,
|
||||||
|
bool, const HomophoneReplacerConfig &>(),
|
||||||
py::arg("feat_config"), py::arg("model_config"),
|
py::arg("feat_config"), py::arg("model_config"),
|
||||||
py::arg("lm_config") = OnlineLMConfig(),
|
py::arg("lm_config") = OnlineLMConfig(),
|
||||||
py::arg("endpoint_config") = EndpointConfig(),
|
py::arg("endpoint_config") = EndpointConfig(),
|
||||||
@@ -67,7 +68,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
|||||||
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
||||||
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0,
|
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0,
|
||||||
py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "",
|
py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "",
|
||||||
py::arg("rule_fars") = "", py::arg("reset_encoder") = false)
|
py::arg("rule_fars") = "", py::arg("reset_encoder") = false,
|
||||||
|
py::arg("hr") = HomophoneReplacerConfig{})
|
||||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||||
.def_readwrite("model_config", &PyClass::model_config)
|
.def_readwrite("model_config", &PyClass::model_config)
|
||||||
.def_readwrite("lm_config", &PyClass::lm_config)
|
.def_readwrite("lm_config", &PyClass::lm_config)
|
||||||
@@ -83,6 +85,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
|||||||
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
|
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
|
||||||
.def_readwrite("rule_fars", &PyClass::rule_fars)
|
.def_readwrite("rule_fars", &PyClass::rule_fars)
|
||||||
.def_readwrite("reset_encoder", &PyClass::reset_encoder)
|
.def_readwrite("reset_encoder", &PyClass::reset_encoder)
|
||||||
|
.def_readwrite("hr", &PyClass::hr)
|
||||||
.def("__str__", &PyClass::ToString);
|
.def("__str__", &PyClass::ToString);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@
|
|||||||
#include "sherpa-onnx/python/csrc/display.h"
|
#include "sherpa-onnx/python/csrc/display.h"
|
||||||
#include "sherpa-onnx/python/csrc/endpoint.h"
|
#include "sherpa-onnx/python/csrc/endpoint.h"
|
||||||
#include "sherpa-onnx/python/csrc/features.h"
|
#include "sherpa-onnx/python/csrc/features.h"
|
||||||
|
#include "sherpa-onnx/python/csrc/homophone-replacer.h"
|
||||||
#include "sherpa-onnx/python/csrc/keyword-spotter.h"
|
#include "sherpa-onnx/python/csrc/keyword-spotter.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h"
|
#include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-lm-config.h"
|
#include "sherpa-onnx/python/csrc/offline-lm-config.h"
|
||||||
@@ -51,6 +52,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
|||||||
PybindAudioTagging(&m);
|
PybindAudioTagging(&m);
|
||||||
PybindOfflinePunctuation(&m);
|
PybindOfflinePunctuation(&m);
|
||||||
PybindOnlinePunctuation(&m);
|
PybindOnlinePunctuation(&m);
|
||||||
|
PybindHomophoneReplacer(&m);
|
||||||
|
|
||||||
PybindFeatures(&m);
|
PybindFeatures(&m);
|
||||||
PybindOnlineCtcFstDecoderConfig(&m);
|
PybindOnlineCtcFstDecoderConfig(&m);
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from typing import List, Optional
|
|||||||
|
|
||||||
from _sherpa_onnx import (
|
from _sherpa_onnx import (
|
||||||
FeatureExtractorConfig,
|
FeatureExtractorConfig,
|
||||||
|
HomophoneReplacerConfig,
|
||||||
OfflineCtcFstDecoderConfig,
|
OfflineCtcFstDecoderConfig,
|
||||||
OfflineDolphinModelConfig,
|
OfflineDolphinModelConfig,
|
||||||
OfflineFireRedAsrModelConfig,
|
OfflineFireRedAsrModelConfig,
|
||||||
@@ -64,6 +65,9 @@ class OfflineRecognizer(object):
|
|||||||
rule_fars: str = "",
|
rule_fars: str = "",
|
||||||
lm: str = "",
|
lm: str = "",
|
||||||
lm_scale: float = 0.1,
|
lm_scale: float = 0.1,
|
||||||
|
hr_dict_dir: str = "",
|
||||||
|
hr_rule_fsts: str = "",
|
||||||
|
hr_lexicon: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -181,6 +185,11 @@ class OfflineRecognizer(object):
|
|||||||
blank_penalty=blank_penalty,
|
blank_penalty=blank_penalty,
|
||||||
rule_fsts=rule_fsts,
|
rule_fsts=rule_fsts,
|
||||||
rule_fars=rule_fars,
|
rule_fars=rule_fars,
|
||||||
|
hr=HomophoneReplacerConfig(
|
||||||
|
dict_dir=hr_dict_dir,
|
||||||
|
lexicon=hr_lexicon,
|
||||||
|
rule_fsts=hr_rule_fsts,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
self.config = recognizer_config
|
self.config = recognizer_config
|
||||||
@@ -201,6 +210,9 @@ class OfflineRecognizer(object):
|
|||||||
use_itn: bool = False,
|
use_itn: bool = False,
|
||||||
rule_fsts: str = "",
|
rule_fsts: str = "",
|
||||||
rule_fars: str = "",
|
rule_fars: str = "",
|
||||||
|
hr_dict_dir: str = "",
|
||||||
|
hr_rule_fsts: str = "",
|
||||||
|
hr_lexicon: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -263,6 +275,11 @@ class OfflineRecognizer(object):
|
|||||||
decoding_method=decoding_method,
|
decoding_method=decoding_method,
|
||||||
rule_fsts=rule_fsts,
|
rule_fsts=rule_fsts,
|
||||||
rule_fars=rule_fars,
|
rule_fars=rule_fars,
|
||||||
|
hr=HomophoneReplacerConfig(
|
||||||
|
dict_dir=hr_dict_dir,
|
||||||
|
lexicon=hr_lexicon,
|
||||||
|
rule_fsts=hr_rule_fsts,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
self.config = recognizer_config
|
self.config = recognizer_config
|
||||||
@@ -281,6 +298,9 @@ class OfflineRecognizer(object):
|
|||||||
provider: str = "cpu",
|
provider: str = "cpu",
|
||||||
rule_fsts: str = "",
|
rule_fsts: str = "",
|
||||||
rule_fars: str = "",
|
rule_fars: str = "",
|
||||||
|
hr_dict_dir: str = "",
|
||||||
|
hr_rule_fsts: str = "",
|
||||||
|
hr_lexicon: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -336,6 +356,11 @@ class OfflineRecognizer(object):
|
|||||||
decoding_method=decoding_method,
|
decoding_method=decoding_method,
|
||||||
rule_fsts=rule_fsts,
|
rule_fsts=rule_fsts,
|
||||||
rule_fars=rule_fars,
|
rule_fars=rule_fars,
|
||||||
|
hr=HomophoneReplacerConfig(
|
||||||
|
dict_dir=hr_dict_dir,
|
||||||
|
lexicon=hr_lexicon,
|
||||||
|
rule_fsts=hr_rule_fsts,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
self.config = recognizer_config
|
self.config = recognizer_config
|
||||||
@@ -354,6 +379,9 @@ class OfflineRecognizer(object):
|
|||||||
provider: str = "cpu",
|
provider: str = "cpu",
|
||||||
rule_fsts: str = "",
|
rule_fsts: str = "",
|
||||||
rule_fars: str = "",
|
rule_fars: str = "",
|
||||||
|
hr_dict_dir: str = "",
|
||||||
|
hr_rule_fsts: str = "",
|
||||||
|
hr_lexicon: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -411,6 +439,9 @@ class OfflineRecognizer(object):
|
|||||||
decoding_method=decoding_method,
|
decoding_method=decoding_method,
|
||||||
rule_fsts=rule_fsts,
|
rule_fsts=rule_fsts,
|
||||||
rule_fars=rule_fars,
|
rule_fars=rule_fars,
|
||||||
|
hr=HomophoneReplacerConfig(
|
||||||
|
dict_dir=hr_dict_dir, lexicon=hr_lexicon, rule_fsts=hr_rule_fsts
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
self.config = recognizer_config
|
self.config = recognizer_config
|
||||||
@@ -429,6 +460,9 @@ class OfflineRecognizer(object):
|
|||||||
provider: str = "cpu",
|
provider: str = "cpu",
|
||||||
rule_fsts: str = "",
|
rule_fsts: str = "",
|
||||||
rule_fars: str = "",
|
rule_fars: str = "",
|
||||||
|
hr_dict_dir: str = "",
|
||||||
|
hr_rule_fsts: str = "",
|
||||||
|
hr_lexicon: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -483,6 +517,11 @@ class OfflineRecognizer(object):
|
|||||||
decoding_method=decoding_method,
|
decoding_method=decoding_method,
|
||||||
rule_fsts=rule_fsts,
|
rule_fsts=rule_fsts,
|
||||||
rule_fars=rule_fars,
|
rule_fars=rule_fars,
|
||||||
|
hr=HomophoneReplacerConfig(
|
||||||
|
dict_dir=hr_dict_dir,
|
||||||
|
lexicon=hr_lexicon,
|
||||||
|
rule_fsts=hr_rule_fsts,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
self.config = recognizer_config
|
self.config = recognizer_config
|
||||||
@@ -501,6 +540,9 @@ class OfflineRecognizer(object):
|
|||||||
provider: str = "cpu",
|
provider: str = "cpu",
|
||||||
rule_fsts: str = "",
|
rule_fsts: str = "",
|
||||||
rule_fars: str = "",
|
rule_fars: str = "",
|
||||||
|
hr_dict_dir: str = "",
|
||||||
|
hr_rule_fsts: str = "",
|
||||||
|
hr_lexicon: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -557,6 +599,11 @@ class OfflineRecognizer(object):
|
|||||||
decoding_method=decoding_method,
|
decoding_method=decoding_method,
|
||||||
rule_fsts=rule_fsts,
|
rule_fsts=rule_fsts,
|
||||||
rule_fars=rule_fars,
|
rule_fars=rule_fars,
|
||||||
|
hr=HomophoneReplacerConfig(
|
||||||
|
dict_dir=hr_dict_dir,
|
||||||
|
lexicon=hr_lexicon,
|
||||||
|
rule_fsts=hr_rule_fsts,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
self.config = recognizer_config
|
self.config = recognizer_config
|
||||||
@@ -577,6 +624,9 @@ class OfflineRecognizer(object):
|
|||||||
tail_paddings: int = -1,
|
tail_paddings: int = -1,
|
||||||
rule_fsts: str = "",
|
rule_fsts: str = "",
|
||||||
rule_fars: str = "",
|
rule_fars: str = "",
|
||||||
|
hr_dict_dir: str = "",
|
||||||
|
hr_rule_fsts: str = "",
|
||||||
|
hr_lexicon: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -647,6 +697,11 @@ class OfflineRecognizer(object):
|
|||||||
decoding_method=decoding_method,
|
decoding_method=decoding_method,
|
||||||
rule_fsts=rule_fsts,
|
rule_fsts=rule_fsts,
|
||||||
rule_fars=rule_fars,
|
rule_fars=rule_fars,
|
||||||
|
hr=HomophoneReplacerConfig(
|
||||||
|
dict_dir=hr_dict_dir,
|
||||||
|
lexicon=hr_lexicon,
|
||||||
|
rule_fsts=hr_rule_fsts,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
self.config = recognizer_config
|
self.config = recognizer_config
|
||||||
@@ -664,6 +719,9 @@ class OfflineRecognizer(object):
|
|||||||
provider: str = "cpu",
|
provider: str = "cpu",
|
||||||
rule_fsts: str = "",
|
rule_fsts: str = "",
|
||||||
rule_fars: str = "",
|
rule_fars: str = "",
|
||||||
|
hr_dict_dir: str = "",
|
||||||
|
hr_rule_fsts: str = "",
|
||||||
|
hr_lexicon: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -719,6 +777,11 @@ class OfflineRecognizer(object):
|
|||||||
decoding_method=decoding_method,
|
decoding_method=decoding_method,
|
||||||
rule_fsts=rule_fsts,
|
rule_fsts=rule_fsts,
|
||||||
rule_fars=rule_fars,
|
rule_fars=rule_fars,
|
||||||
|
hr=HomophoneReplacerConfig(
|
||||||
|
dict_dir=hr_dict_dir,
|
||||||
|
lexicon=hr_lexicon,
|
||||||
|
rule_fsts=hr_rule_fsts,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
self.config = recognizer_config
|
self.config = recognizer_config
|
||||||
@@ -738,6 +801,9 @@ class OfflineRecognizer(object):
|
|||||||
provider: str = "cpu",
|
provider: str = "cpu",
|
||||||
rule_fsts: str = "",
|
rule_fsts: str = "",
|
||||||
rule_fars: str = "",
|
rule_fars: str = "",
|
||||||
|
hr_dict_dir: str = "",
|
||||||
|
hr_rule_fsts: str = "",
|
||||||
|
hr_lexicon: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -800,6 +866,11 @@ class OfflineRecognizer(object):
|
|||||||
decoding_method=decoding_method,
|
decoding_method=decoding_method,
|
||||||
rule_fsts=rule_fsts,
|
rule_fsts=rule_fsts,
|
||||||
rule_fars=rule_fars,
|
rule_fars=rule_fars,
|
||||||
|
hr=HomophoneReplacerConfig(
|
||||||
|
dict_dir=hr_dict_dir,
|
||||||
|
lexicon=hr_lexicon,
|
||||||
|
rule_fsts=hr_rule_fsts,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
self.config = recognizer_config
|
self.config = recognizer_config
|
||||||
@@ -818,6 +889,9 @@ class OfflineRecognizer(object):
|
|||||||
provider: str = "cpu",
|
provider: str = "cpu",
|
||||||
rule_fsts: str = "",
|
rule_fsts: str = "",
|
||||||
rule_fars: str = "",
|
rule_fars: str = "",
|
||||||
|
hr_dict_dir: str = "",
|
||||||
|
hr_rule_fsts: str = "",
|
||||||
|
hr_lexicon: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -873,6 +947,11 @@ class OfflineRecognizer(object):
|
|||||||
decoding_method=decoding_method,
|
decoding_method=decoding_method,
|
||||||
rule_fsts=rule_fsts,
|
rule_fsts=rule_fsts,
|
||||||
rule_fars=rule_fars,
|
rule_fars=rule_fars,
|
||||||
|
hr=HomophoneReplacerConfig(
|
||||||
|
dict_dir=hr_dict_dir,
|
||||||
|
lexicon=hr_lexicon,
|
||||||
|
rule_fsts=hr_rule_fsts,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
self.config = recognizer_config
|
self.config = recognizer_config
|
||||||
@@ -891,6 +970,9 @@ class OfflineRecognizer(object):
|
|||||||
provider: str = "cpu",
|
provider: str = "cpu",
|
||||||
rule_fsts: str = "",
|
rule_fsts: str = "",
|
||||||
rule_fars: str = "",
|
rule_fars: str = "",
|
||||||
|
hr_dict_dir: str = "",
|
||||||
|
hr_rule_fsts: str = "",
|
||||||
|
hr_lexicon: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -947,6 +1029,11 @@ class OfflineRecognizer(object):
|
|||||||
decoding_method=decoding_method,
|
decoding_method=decoding_method,
|
||||||
rule_fsts=rule_fsts,
|
rule_fsts=rule_fsts,
|
||||||
rule_fars=rule_fars,
|
rule_fars=rule_fars,
|
||||||
|
hr=HomophoneReplacerConfig(
|
||||||
|
dict_dir=hr_dict_dir,
|
||||||
|
lexicon=hr_lexicon,
|
||||||
|
rule_fsts=hr_rule_fsts,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
self.config = recognizer_config
|
self.config = recognizer_config
|
||||||
|
|||||||
@@ -3,25 +3,26 @@ from pathlib import Path
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from _sherpa_onnx import (
|
from _sherpa_onnx import (
|
||||||
|
CudaConfig,
|
||||||
EndpointConfig,
|
EndpointConfig,
|
||||||
FeatureExtractorConfig,
|
FeatureExtractorConfig,
|
||||||
|
HomophoneReplacerConfig,
|
||||||
|
OnlineCtcFstDecoderConfig,
|
||||||
OnlineLMConfig,
|
OnlineLMConfig,
|
||||||
OnlineModelConfig,
|
OnlineModelConfig,
|
||||||
|
OnlineNeMoCtcModelConfig,
|
||||||
OnlineParaformerModelConfig,
|
OnlineParaformerModelConfig,
|
||||||
)
|
)
|
||||||
from _sherpa_onnx import OnlineRecognizer as _Recognizer
|
from _sherpa_onnx import OnlineRecognizer as _Recognizer
|
||||||
from _sherpa_onnx import (
|
from _sherpa_onnx import (
|
||||||
CudaConfig,
|
|
||||||
TensorrtConfig,
|
|
||||||
ProviderConfig,
|
|
||||||
OnlineRecognizerConfig,
|
OnlineRecognizerConfig,
|
||||||
OnlineRecognizerResult,
|
OnlineRecognizerResult,
|
||||||
OnlineStream,
|
OnlineStream,
|
||||||
OnlineTransducerModelConfig,
|
OnlineTransducerModelConfig,
|
||||||
OnlineWenetCtcModelConfig,
|
OnlineWenetCtcModelConfig,
|
||||||
OnlineNeMoCtcModelConfig,
|
|
||||||
OnlineZipformer2CtcModelConfig,
|
OnlineZipformer2CtcModelConfig,
|
||||||
OnlineCtcFstDecoderConfig,
|
ProviderConfig,
|
||||||
|
TensorrtConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -82,9 +83,12 @@ class OnlineRecognizer(object):
|
|||||||
trt_detailed_build_log: bool = False,
|
trt_detailed_build_log: bool = False,
|
||||||
trt_engine_cache_enable: bool = True,
|
trt_engine_cache_enable: bool = True,
|
||||||
trt_timing_cache_enable: bool = True,
|
trt_timing_cache_enable: bool = True,
|
||||||
trt_engine_cache_path: str ="",
|
trt_engine_cache_path: str = "",
|
||||||
trt_timing_cache_path: str ="",
|
trt_timing_cache_path: str = "",
|
||||||
trt_dump_subgraphs: bool = False,
|
trt_dump_subgraphs: bool = False,
|
||||||
|
hr_dict_dir: str = "",
|
||||||
|
hr_rule_fsts: str = "",
|
||||||
|
hr_lexicon: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -228,27 +232,27 @@ class OnlineRecognizer(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
cuda_config = CudaConfig(
|
cuda_config = CudaConfig(
|
||||||
cudnn_conv_algo_search=cudnn_conv_algo_search,
|
cudnn_conv_algo_search=cudnn_conv_algo_search,
|
||||||
)
|
)
|
||||||
|
|
||||||
trt_config = TensorrtConfig(
|
trt_config = TensorrtConfig(
|
||||||
trt_max_workspace_size=trt_max_workspace_size,
|
trt_max_workspace_size=trt_max_workspace_size,
|
||||||
trt_max_partition_iterations=trt_max_partition_iterations,
|
trt_max_partition_iterations=trt_max_partition_iterations,
|
||||||
trt_min_subgraph_size=trt_min_subgraph_size,
|
trt_min_subgraph_size=trt_min_subgraph_size,
|
||||||
trt_fp16_enable=trt_fp16_enable,
|
trt_fp16_enable=trt_fp16_enable,
|
||||||
trt_detailed_build_log=trt_detailed_build_log,
|
trt_detailed_build_log=trt_detailed_build_log,
|
||||||
trt_engine_cache_enable=trt_engine_cache_enable,
|
trt_engine_cache_enable=trt_engine_cache_enable,
|
||||||
trt_timing_cache_enable=trt_timing_cache_enable,
|
trt_timing_cache_enable=trt_timing_cache_enable,
|
||||||
trt_engine_cache_path=trt_engine_cache_path,
|
trt_engine_cache_path=trt_engine_cache_path,
|
||||||
trt_timing_cache_path=trt_timing_cache_path,
|
trt_timing_cache_path=trt_timing_cache_path,
|
||||||
trt_dump_subgraphs=trt_dump_subgraphs,
|
trt_dump_subgraphs=trt_dump_subgraphs,
|
||||||
)
|
)
|
||||||
|
|
||||||
provider_config = ProviderConfig(
|
provider_config = ProviderConfig(
|
||||||
trt_config=trt_config,
|
trt_config=trt_config,
|
||||||
cuda_config=cuda_config,
|
cuda_config=cuda_config,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_config = OnlineModelConfig(
|
model_config = OnlineModelConfig(
|
||||||
@@ -311,6 +315,11 @@ class OnlineRecognizer(object):
|
|||||||
rule_fsts=rule_fsts,
|
rule_fsts=rule_fsts,
|
||||||
rule_fars=rule_fars,
|
rule_fars=rule_fars,
|
||||||
reset_encoder=reset_encoder,
|
reset_encoder=reset_encoder,
|
||||||
|
hr=HomophoneReplacerConfig(
|
||||||
|
dict_dir=hr_dict_dir,
|
||||||
|
lexicon=hr_lexicon,
|
||||||
|
rule_fsts=hr_rule_fsts,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
@@ -336,6 +345,9 @@ class OnlineRecognizer(object):
|
|||||||
rule_fsts: str = "",
|
rule_fsts: str = "",
|
||||||
rule_fars: str = "",
|
rule_fars: str = "",
|
||||||
device: int = 0,
|
device: int = 0,
|
||||||
|
hr_dict_dir: str = "",
|
||||||
|
hr_rule_fsts: str = "",
|
||||||
|
hr_lexicon: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -402,8 +414,8 @@ class OnlineRecognizer(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
provider_config = ProviderConfig(
|
provider_config = ProviderConfig(
|
||||||
provider=provider,
|
provider=provider,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_config = OnlineModelConfig(
|
model_config = OnlineModelConfig(
|
||||||
@@ -434,6 +446,11 @@ class OnlineRecognizer(object):
|
|||||||
decoding_method=decoding_method,
|
decoding_method=decoding_method,
|
||||||
rule_fsts=rule_fsts,
|
rule_fsts=rule_fsts,
|
||||||
rule_fars=rule_fars,
|
rule_fars=rule_fars,
|
||||||
|
hr=HomophoneReplacerConfig(
|
||||||
|
dict_dir=hr_dict_dir,
|
||||||
|
lexicon=hr_lexicon,
|
||||||
|
rule_fsts=hr_rule_fsts,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
@@ -460,6 +477,9 @@ class OnlineRecognizer(object):
|
|||||||
rule_fsts: str = "",
|
rule_fsts: str = "",
|
||||||
rule_fars: str = "",
|
rule_fars: str = "",
|
||||||
device: int = 0,
|
device: int = 0,
|
||||||
|
hr_dict_dir: str = "",
|
||||||
|
hr_rule_fsts: str = "",
|
||||||
|
hr_lexicon: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -526,8 +546,8 @@ class OnlineRecognizer(object):
|
|||||||
zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model)
|
zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model)
|
||||||
|
|
||||||
provider_config = ProviderConfig(
|
provider_config = ProviderConfig(
|
||||||
provider=provider,
|
provider=provider,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_config = OnlineModelConfig(
|
model_config = OnlineModelConfig(
|
||||||
@@ -563,6 +583,11 @@ class OnlineRecognizer(object):
|
|||||||
decoding_method=decoding_method,
|
decoding_method=decoding_method,
|
||||||
rule_fsts=rule_fsts,
|
rule_fsts=rule_fsts,
|
||||||
rule_fars=rule_fars,
|
rule_fars=rule_fars,
|
||||||
|
hr=HomophoneReplacerConfig(
|
||||||
|
dict_dir=hr_dict_dir,
|
||||||
|
lexicon=hr_lexicon,
|
||||||
|
rule_fsts=hr_rule_fsts,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
@@ -587,6 +612,9 @@ class OnlineRecognizer(object):
|
|||||||
rule_fsts: str = "",
|
rule_fsts: str = "",
|
||||||
rule_fars: str = "",
|
rule_fars: str = "",
|
||||||
device: int = 0,
|
device: int = 0,
|
||||||
|
hr_dict_dir: str = "",
|
||||||
|
hr_rule_fsts: str = "",
|
||||||
|
hr_lexicon: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -650,8 +678,8 @@ class OnlineRecognizer(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
provider_config = ProviderConfig(
|
provider_config = ProviderConfig(
|
||||||
provider=provider,
|
provider=provider,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_config = OnlineModelConfig(
|
model_config = OnlineModelConfig(
|
||||||
@@ -681,6 +709,11 @@ class OnlineRecognizer(object):
|
|||||||
decoding_method=decoding_method,
|
decoding_method=decoding_method,
|
||||||
rule_fsts=rule_fsts,
|
rule_fsts=rule_fsts,
|
||||||
rule_fars=rule_fars,
|
rule_fars=rule_fars,
|
||||||
|
hr=HomophoneReplacerConfig(
|
||||||
|
dict_dir=hr_dict_dir,
|
||||||
|
lexicon=hr_lexicon,
|
||||||
|
rule_fsts=hr_rule_fsts,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
@@ -707,6 +740,9 @@ class OnlineRecognizer(object):
|
|||||||
rule_fsts: str = "",
|
rule_fsts: str = "",
|
||||||
rule_fars: str = "",
|
rule_fars: str = "",
|
||||||
device: int = 0,
|
device: int = 0,
|
||||||
|
hr_dict_dir: str = "",
|
||||||
|
hr_rule_fsts: str = "",
|
||||||
|
hr_lexicon: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -775,8 +811,8 @@ class OnlineRecognizer(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
provider_config = ProviderConfig(
|
provider_config = ProviderConfig(
|
||||||
provider=provider,
|
provider=provider,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_config = OnlineModelConfig(
|
model_config = OnlineModelConfig(
|
||||||
@@ -806,6 +842,11 @@ class OnlineRecognizer(object):
|
|||||||
decoding_method=decoding_method,
|
decoding_method=decoding_method,
|
||||||
rule_fsts=rule_fsts,
|
rule_fsts=rule_fsts,
|
||||||
rule_fars=rule_fars,
|
rule_fars=rule_fars,
|
||||||
|
hr=HomophoneReplacerConfig(
|
||||||
|
dict_dir=hr_dict_dir,
|
||||||
|
lexicon=hr_lexicon,
|
||||||
|
rule_fsts=hr_rule_fsts,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
|
|||||||
Reference in New Issue
Block a user