Support replacing homonphonic phrases (#2153)
This commit is contained in:
23
.github/scripts/test-offline-ctc.sh
vendored
23
.github/scripts/test-offline-ctc.sh
vendored
@@ -98,6 +98,29 @@ for m in model.onnx model.int8.onnx; do
|
||||
done
|
||||
done
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/dict.tar.bz2
|
||||
tar xf dict.tar.bz2
|
||||
rm dict.tar.bz2
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/replace.fst
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/test-hr.wav
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/lexicon.txt
|
||||
|
||||
for m in model.onnx model.int8.onnx; do
|
||||
for use_itn in 0 1; do
|
||||
echo "$m $w $use_itn"
|
||||
time $EXE \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--sense-voice-model=$repo/$m \
|
||||
--sense-voice-use-itn=$use_itn \
|
||||
--hr-lexicon=./lexicon.txt \
|
||||
--hr-dict-dir=./dict \
|
||||
--hr-rule-fsts=./replace.fst \
|
||||
./test-hr.wav
|
||||
done
|
||||
done
|
||||
|
||||
rm -rf dict replace.fst test-hr.wav lexicon.txt
|
||||
|
||||
# test wav reader for non-standard wav files
|
||||
waves=(
|
||||
|
||||
12
.github/scripts/test-python.sh
vendored
12
.github/scripts/test-python.sh
vendored
@@ -95,6 +95,18 @@ rm $name
|
||||
ls -lh $repo
|
||||
python3 ./python-api-examples/offline-sense-voice-ctc-decode-files.py
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/dict.tar.bz2
|
||||
tar xf dict.tar.bz2
|
||||
rm dict.tar.bz2
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/replace.fst
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/test-hr.wav
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/lexicon.txt
|
||||
|
||||
python3 ./python-api-examples/offline-sense-voice-ctc-decode-files-with-hr.py
|
||||
|
||||
rm -rf dict replace.fst test-hr.wav lexicon.txt
|
||||
|
||||
if [[ $(uname) == Linux ]]; then
|
||||
# It needs ffmpeg
|
||||
log "generate subtitles (Chinese)"
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
function(download_kaldifst)
|
||||
include(FetchContent)
|
||||
|
||||
set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.11.tar.gz")
|
||||
set(kaldifst_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldifst-1.7.11.tar.gz")
|
||||
set(kaldifst_HASH "SHA256=b43b3332faa2961edc730e47995a58cd4e22ead21905d55b0c4a41375b4a525f")
|
||||
set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.13.tar.gz")
|
||||
set(kaldifst_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldifst-1.7.13.tar.gz")
|
||||
set(kaldifst_HASH "SHA256=f8dc15fdaf314d7c9c3551ad8c11ed15da0f34de36446798bbd1b90fa7946eb2")
|
||||
|
||||
# If you don't have access to the Internet,
|
||||
# please pre-download kaldifst
|
||||
set(possible_file_locations
|
||||
$ENV{HOME}/Downloads/kaldifst-1.7.11.tar.gz
|
||||
${CMAKE_SOURCE_DIR}/kaldifst-1.7.11.tar.gz
|
||||
${CMAKE_BINARY_DIR}/kaldifst-1.7.11.tar.gz
|
||||
/tmp/kaldifst-1.7.11.tar.gz
|
||||
/star-fj/fangjun/download/github/kaldifst-1.7.11.tar.gz
|
||||
$ENV{HOME}/Downloads/kaldifst-1.7.13.tar.gz
|
||||
${CMAKE_SOURCE_DIR}/kaldifst-1.7.13.tar.gz
|
||||
${CMAKE_BINARY_DIR}/kaldifst-1.7.13.tar.gz
|
||||
/tmp/kaldifst-1.7.13.tar.gz
|
||||
/star-fj/fangjun/download/github/kaldifst-1.7.13.tar.gz
|
||||
)
|
||||
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
|
||||
75
python-api-examples/offline-sense-voice-ctc-decode-files-with-hr.py
Executable file
75
python-api-examples/offline-sense-voice-ctc-decode-files-with-hr.py
Executable file
@@ -0,0 +1,75 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
This file shows how to use a non-streaming SenseVoice CTC model from
|
||||
https://github.com/FunAudioLLM/SenseVoice
|
||||
to decode files.
|
||||
|
||||
Please download model files from
|
||||
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||
|
||||
For instance,
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
|
||||
tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
|
||||
rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/dict.tar.bz2
|
||||
tar xf dict.tar.bz2
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/replace.fst
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/test-hr.wav
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/lexicon.txt
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import sherpa_onnx
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
def create_recognizer():
|
||||
model = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.onnx"
|
||||
tokens = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt"
|
||||
test_wav = "./test-hr.wav"
|
||||
|
||||
if not Path(model).is_file() or not Path(test_wav).is_file():
|
||||
raise ValueError(
|
||||
"""Please download model files from
|
||||
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||
and
|
||||
https://github.com/k2-fsa/sherpa-onnx/releases/tag/hr-files
|
||||
"""
|
||||
)
|
||||
return (
|
||||
sherpa_onnx.OfflineRecognizer.from_sense_voice(
|
||||
model=model,
|
||||
tokens=tokens,
|
||||
use_itn=True,
|
||||
debug=True,
|
||||
hr_lexicon="./lexicon.txt",
|
||||
hr_dict_dir="./dict",
|
||||
hr_rule_fsts="./replace.fst",
|
||||
),
|
||||
test_wav,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
recognizer, wave_filename = create_recognizer()
|
||||
|
||||
audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
|
||||
audio = audio[:, 0] # only use the first channel
|
||||
|
||||
# audio is a 1-D float32 numpy array normalized to the range [-1, 1]
|
||||
# sample_rate does not need to be 16000 Hz
|
||||
|
||||
stream = recognizer.create_stream()
|
||||
stream.accept_waveform(sample_rate, audio)
|
||||
recognizer.decode_stream(stream)
|
||||
print(wave_filename)
|
||||
print(stream.result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -20,7 +20,9 @@ set(sources
|
||||
features.cc
|
||||
file-utils.cc
|
||||
fst-utils.cc
|
||||
homophone-replacer.cc
|
||||
hypothesis.cc
|
||||
jieba.cc
|
||||
keyword-spotter-impl.cc
|
||||
keyword-spotter.cc
|
||||
offline-ctc-fst-decoder-config.cc
|
||||
|
||||
278
sherpa-onnx/csrc/homophone-replacer.cc
Normal file
278
sherpa-onnx/csrc/homophone-replacer.cc
Normal file
@@ -0,0 +1,278 @@
|
||||
// sherpa-onnx/csrc/homophone-replacer.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/homophone-replacer.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <strstream>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#if __OHOS__
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "kaldifst/csrc/text-normalizer.h"
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/jieba.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void HomophoneReplacerConfig::Register(ParseOptions *po) {
|
||||
po->Register("hr-dict-dir", &dict_dir,
|
||||
"The dict directory for jieba used by HomophoneReplacer");
|
||||
|
||||
po->Register("hr-lexicon", &lexicon,
|
||||
"Path to lexicon.txt used by HomophoneReplacer.");
|
||||
|
||||
po->Register("hr-rule-fsts", &rule_fsts,
|
||||
"Fst files for HomophoneReplacer. If there are multiple, they "
|
||||
"are separated by a comma. E.g., a.fst,b.fst,c.fst");
|
||||
}
|
||||
|
||||
bool HomophoneReplacerConfig::Validate() const {
|
||||
if (!dict_dir.empty()) {
|
||||
std::vector<std::string> required_files = {
|
||||
"jieba.dict.utf8", "hmm_model.utf8", "user.dict.utf8",
|
||||
"idf.utf8", "stop_words.utf8",
|
||||
};
|
||||
|
||||
for (const auto &f : required_files) {
|
||||
if (!FileExists(dict_dir + "/" + f)) {
|
||||
SHERPA_ONNX_LOGE("'%s/%s' does not exist. Please check kokoro-dict-dir",
|
||||
dict_dir.c_str(), f.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!lexicon.empty() && !FileExists(lexicon)) {
|
||||
SHERPA_ONNX_LOGE("--hr-lexicon: '%s' does not exist", lexicon.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!rule_fsts.empty()) {
|
||||
std::vector<std::string> files;
|
||||
SplitStringToVector(rule_fsts, ",", false, &files);
|
||||
|
||||
if (files.size() > 1) {
|
||||
SHERPA_ONNX_LOGE("Only 1 file is supported now.");
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
|
||||
for (const auto &f : files) {
|
||||
if (!FileExists(f)) {
|
||||
SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string HomophoneReplacerConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "HomophoneReplacerConfig(";
|
||||
os << "dict_dir=\"" << dict_dir << "\", ";
|
||||
os << "lexicon=\"" << lexicon << "\", ";
|
||||
os << "rule_fsts=\"" << rule_fsts << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
class HomophoneReplacer::Impl {
|
||||
public:
|
||||
explicit Impl(const HomophoneReplacerConfig &config) : config_(config) {
|
||||
jieba_ = InitJieba(config.dict_dir);
|
||||
|
||||
{
|
||||
std::ifstream is(config.lexicon);
|
||||
InitLexicon(is);
|
||||
}
|
||||
|
||||
if (!config.rule_fsts.empty()) {
|
||||
std::vector<std::string> files;
|
||||
SplitStringToVector(config.rule_fsts, ",", false, &files);
|
||||
replacer_list_.reserve(files.size());
|
||||
for (const auto &f : files) {
|
||||
if (config.debug) {
|
||||
SHERPA_ONNX_LOGE("hr rule fst: %s", f.c_str());
|
||||
}
|
||||
replacer_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(f));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Manager>
|
||||
Impl(Manager *mgr, const HomophoneReplacerConfig &config) : config_(config) {
|
||||
jieba_ = InitJieba(config.dict_dir);
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.lexicon);
|
||||
|
||||
std::istrstream is(buf.data(), buf.size());
|
||||
InitLexicon(is);
|
||||
}
|
||||
|
||||
if (!config.rule_fsts.empty()) {
|
||||
std::vector<std::string> files;
|
||||
SplitStringToVector(config.rule_fsts, ",", false, &files);
|
||||
replacer_list_.reserve(files.size());
|
||||
for (const auto &f : files) {
|
||||
if (config.debug) {
|
||||
SHERPA_ONNX_LOGE("hr rule fst: %s", f.c_str());
|
||||
}
|
||||
auto buf = ReadFile(mgr, f);
|
||||
std::istrstream is(buf.data(), buf.size());
|
||||
replacer_list_.push_back(
|
||||
std::make_unique<kaldifst::TextNormalizer>(is));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string Apply(const std::string &text) const {
|
||||
bool is_hmm = true;
|
||||
|
||||
std::vector<std::string> words;
|
||||
jieba_->Cut(text, words, is_hmm);
|
||||
if (config_.debug) {
|
||||
SHERPA_ONNX_LOGE("Input text: '%s'", text.c_str());
|
||||
std::ostringstream os;
|
||||
os << "After jieba: ";
|
||||
std::string sep;
|
||||
for (const auto &w : words) {
|
||||
os << sep << w;
|
||||
sep = "_";
|
||||
}
|
||||
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||
}
|
||||
|
||||
// convert words to pronunciations
|
||||
std::vector<std::string> pronunciations;
|
||||
|
||||
for (const auto &w : words) {
|
||||
auto p = ConvertWordToPronunciation(w);
|
||||
if (config_.debug) {
|
||||
SHERPA_ONNX_LOGE("%s %s", w.c_str(), p.c_str());
|
||||
}
|
||||
pronunciations.push_back(std::move(p));
|
||||
}
|
||||
|
||||
std::string ans;
|
||||
for (const auto &r : replacer_list_) {
|
||||
ans = r->Normalize(words, pronunciations);
|
||||
// TODO(fangjun): We support only 1 rule fst at present.
|
||||
break;
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string ConvertWordToPronunciation(const std::string &word) const {
|
||||
if (word2pron_.count(word)) {
|
||||
return word2pron_.at(word);
|
||||
}
|
||||
|
||||
if (word.size() <= 3) {
|
||||
// not a Chinese character
|
||||
return word;
|
||||
}
|
||||
|
||||
std::vector<std::string> words = SplitUtf8(word);
|
||||
std::string ans;
|
||||
for (const auto &w : words) {
|
||||
if (word2pron_.count(w)) {
|
||||
ans.append(word2pron_.at(w));
|
||||
} else {
|
||||
ans.append(w);
|
||||
}
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
void InitLexicon(std::istream &is) {
|
||||
std::string word;
|
||||
std::string pron;
|
||||
std::string p;
|
||||
|
||||
std::string line;
|
||||
int32_t line_num = 0;
|
||||
int32_t num_warn = 0;
|
||||
while (std::getline(is, line)) {
|
||||
++line_num;
|
||||
std::istringstream iss(line);
|
||||
|
||||
pron.clear();
|
||||
iss >> word;
|
||||
ToLowerCase(&word);
|
||||
|
||||
if (word2pron_.count(word)) {
|
||||
num_warn += 1;
|
||||
if (num_warn < 10) {
|
||||
SHERPA_ONNX_LOGE("Duplicated word: %s at line %d:%s. Ignore it.",
|
||||
word.c_str(), line_num, line.c_str());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
while (iss >> p) {
|
||||
pron.append(std::move(p));
|
||||
}
|
||||
|
||||
if (pron.empty()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Empty pronunciation for word '%s' at line %d:%s. Ignore it.",
|
||||
word.c_str(), line_num, line.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
word2pron_.insert({std::move(word), std::move(pron)});
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
HomophoneReplacerConfig config_;
|
||||
std::unique_ptr<cppjieba::Jieba> jieba_;
|
||||
std::vector<std::unique_ptr<kaldifst::TextNormalizer>> replacer_list_;
|
||||
std::unordered_map<std::string, std::string> word2pron_;
|
||||
};
|
||||
|
||||
HomophoneReplacer::HomophoneReplacer(const HomophoneReplacerConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
template <typename Manager>
|
||||
HomophoneReplacer::HomophoneReplacer(Manager *mgr,
|
||||
const HomophoneReplacerConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
|
||||
HomophoneReplacer::~HomophoneReplacer() = default;
|
||||
|
||||
std::string HomophoneReplacer::Apply(const std::string &text) const {
|
||||
return impl_->Apply(text);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
template HomophoneReplacer::HomophoneReplacer(
|
||||
AAssetManager *mgr, const HomophoneReplacerConfig &config);
|
||||
#endif
|
||||
|
||||
#if __OHOS__
|
||||
template HomophoneReplacer::HomophoneReplacer(
|
||||
NativeResourceManager *mgr, const HomophoneReplacerConfig &config);
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
58
sherpa-onnx/csrc/homophone-replacer.h
Normal file
58
sherpa-onnx/csrc/homophone-replacer.h
Normal file
@@ -0,0 +1,58 @@
|
||||
// sherpa-onnx/csrc/homophone-replacer.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_HOMOPHONE_REPLACER_H_
|
||||
#define SHERPA_ONNX_CSRC_HOMOPHONE_REPLACER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct HomophoneReplacerConfig {
|
||||
std::string dict_dir;
|
||||
std::string lexicon;
|
||||
|
||||
// comma separated fst files, e.g. a.fst,b.fst,c.fst
|
||||
std::string rule_fsts;
|
||||
|
||||
bool debug;
|
||||
|
||||
HomophoneReplacerConfig() = default;
|
||||
|
||||
HomophoneReplacerConfig(const std::string &dict_dir,
|
||||
const std::string &lexicon,
|
||||
const std::string &rule_fsts, bool debug)
|
||||
: dict_dir(dict_dir),
|
||||
lexicon(lexicon),
|
||||
rule_fsts(rule_fsts),
|
||||
debug(debug) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
class HomophoneReplacer {
|
||||
public:
|
||||
explicit HomophoneReplacer(const HomophoneReplacerConfig &config);
|
||||
|
||||
template <typename Manager>
|
||||
HomophoneReplacer(Manager *mgr, const HomophoneReplacerConfig &config);
|
||||
|
||||
~HomophoneReplacer();
|
||||
|
||||
std::string Apply(const std::string &text) const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_HOMOPHONE_REPLACER_H_
|
||||
@@ -19,8 +19,8 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "cppjieba/Jieba.hpp"
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/jieba.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
@@ -41,20 +41,7 @@ class JiebaLexicon::Impl {
|
||||
Impl(const std::string &lexicon, const std::string &tokens,
|
||||
const std::string &dict_dir, bool debug)
|
||||
: debug_(debug) {
|
||||
std::string dict = dict_dir + "/jieba.dict.utf8";
|
||||
std::string hmm = dict_dir + "/hmm_model.utf8";
|
||||
std::string user_dict = dict_dir + "/user.dict.utf8";
|
||||
std::string idf = dict_dir + "/idf.utf8";
|
||||
std::string stop_word = dict_dir + "/stop_words.utf8";
|
||||
|
||||
AssertFileExists(dict);
|
||||
AssertFileExists(hmm);
|
||||
AssertFileExists(user_dict);
|
||||
AssertFileExists(idf);
|
||||
AssertFileExists(stop_word);
|
||||
|
||||
jieba_ =
|
||||
std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
|
||||
jieba_ = InitJieba(dict_dir);
|
||||
|
||||
{
|
||||
std::ifstream is(tokens);
|
||||
@@ -71,20 +58,7 @@ class JiebaLexicon::Impl {
|
||||
Impl(Manager *mgr, const std::string &lexicon, const std::string &tokens,
|
||||
const std::string &dict_dir, bool debug)
|
||||
: debug_(debug) {
|
||||
std::string dict = dict_dir + "/jieba.dict.utf8";
|
||||
std::string hmm = dict_dir + "/hmm_model.utf8";
|
||||
std::string user_dict = dict_dir + "/user.dict.utf8";
|
||||
std::string idf = dict_dir + "/idf.utf8";
|
||||
std::string stop_word = dict_dir + "/stop_words.utf8";
|
||||
|
||||
AssertFileExists(dict);
|
||||
AssertFileExists(hmm);
|
||||
AssertFileExists(user_dict);
|
||||
AssertFileExists(idf);
|
||||
AssertFileExists(stop_word);
|
||||
|
||||
jieba_ =
|
||||
std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
|
||||
jieba_ = InitJieba(dict_dir);
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, tokens);
|
||||
|
||||
32
sherpa-onnx/csrc/jieba.cc
Normal file
32
sherpa-onnx/csrc/jieba.cc
Normal file
@@ -0,0 +1,32 @@
|
||||
// sherpa-onnx/csrc/jieba.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/jieba.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::unique_ptr<cppjieba::Jieba> InitJieba(const std::string &dict_dir) {
|
||||
if (dict_dir.empty()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
std::string dict = dict_dir + "/jieba.dict.utf8";
|
||||
std::string hmm = dict_dir + "/hmm_model.utf8";
|
||||
std::string user_dict = dict_dir + "/user.dict.utf8";
|
||||
std::string idf = dict_dir + "/idf.utf8";
|
||||
std::string stop_word = dict_dir + "/stop_words.utf8";
|
||||
|
||||
AssertFileExists(dict);
|
||||
AssertFileExists(hmm);
|
||||
AssertFileExists(user_dict);
|
||||
AssertFileExists(idf);
|
||||
AssertFileExists(stop_word);
|
||||
|
||||
return std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf,
|
||||
stop_word);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
18
sherpa-onnx/csrc/jieba.h
Normal file
18
sherpa-onnx/csrc/jieba.h
Normal file
@@ -0,0 +1,18 @@
|
||||
// sherpa-onnx/csrc/jieba.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_JIEBA_H_
|
||||
#define SHERPA_ONNX_CSRC_JIEBA_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "cppjieba/Jieba.hpp"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::unique_ptr<cppjieba::Jieba> InitJieba(const std::string &dict_dir);
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_JIEBA_H_
|
||||
@@ -22,11 +22,11 @@
|
||||
|
||||
#include <codecvt>
|
||||
|
||||
#include "cppjieba/Jieba.hpp"
|
||||
#include "espeak-ng/speak_lib.h"
|
||||
#include "phoneme_ids.hpp"
|
||||
#include "phonemize.hpp"
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/jieba.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
@@ -47,7 +47,7 @@ class KokoroMultiLangLexicon::Impl {
|
||||
|
||||
InitLexicon(lexicon);
|
||||
|
||||
InitJieba(dict_dir);
|
||||
jieba_ = InitJieba(dict_dir);
|
||||
|
||||
InitEspeak(data_dir); // See ./piper-phonemize-lexicon.cc
|
||||
}
|
||||
@@ -62,7 +62,7 @@ class KokoroMultiLangLexicon::Impl {
|
||||
InitLexicon(mgr, lexicon);
|
||||
|
||||
// we assume you have copied dict_dir and data_dir from assets to some path
|
||||
InitJieba(dict_dir);
|
||||
jieba_ = InitJieba(dict_dir);
|
||||
|
||||
InitEspeak(data_dir); // See ./piper-phonemize-lexicon.cc
|
||||
}
|
||||
@@ -456,23 +456,6 @@ class KokoroMultiLangLexicon::Impl {
|
||||
}
|
||||
}
|
||||
|
||||
void InitJieba(const std::string &dict_dir) {
|
||||
std::string dict = dict_dir + "/jieba.dict.utf8";
|
||||
std::string hmm = dict_dir + "/hmm_model.utf8";
|
||||
std::string user_dict = dict_dir + "/user.dict.utf8";
|
||||
std::string idf = dict_dir + "/idf.utf8";
|
||||
std::string stop_word = dict_dir + "/stop_words.utf8";
|
||||
|
||||
AssertFileExists(dict);
|
||||
AssertFileExists(hmm);
|
||||
AssertFileExists(user_dict);
|
||||
AssertFileExists(idf);
|
||||
AssertFileExists(stop_word);
|
||||
|
||||
jieba_ =
|
||||
std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineTtsKokoroModelMetaData meta_data_;
|
||||
|
||||
|
||||
@@ -19,8 +19,8 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "cppjieba/Jieba.hpp"
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/jieba.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
@@ -34,20 +34,7 @@ class MeloTtsLexicon::Impl {
|
||||
const std::string &dict_dir,
|
||||
const OfflineTtsVitsModelMetaData &meta_data, bool debug)
|
||||
: meta_data_(meta_data), debug_(debug) {
|
||||
std::string dict = dict_dir + "/jieba.dict.utf8";
|
||||
std::string hmm = dict_dir + "/hmm_model.utf8";
|
||||
std::string user_dict = dict_dir + "/user.dict.utf8";
|
||||
std::string idf = dict_dir + "/idf.utf8";
|
||||
std::string stop_word = dict_dir + "/stop_words.utf8";
|
||||
|
||||
AssertFileExists(dict);
|
||||
AssertFileExists(hmm);
|
||||
AssertFileExists(user_dict);
|
||||
AssertFileExists(idf);
|
||||
AssertFileExists(stop_word);
|
||||
|
||||
jieba_ =
|
||||
std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
|
||||
jieba_ = InitJieba(dict_dir);
|
||||
|
||||
{
|
||||
std::ifstream is(tokens);
|
||||
@@ -79,20 +66,7 @@ class MeloTtsLexicon::Impl {
|
||||
const std::string &dict_dir,
|
||||
const OfflineTtsVitsModelMetaData &meta_data, bool debug)
|
||||
: meta_data_(meta_data), debug_(debug) {
|
||||
std::string dict = dict_dir + "/jieba.dict.utf8";
|
||||
std::string hmm = dict_dir + "/hmm_model.utf8";
|
||||
std::string user_dict = dict_dir + "/user.dict.utf8";
|
||||
std::string idf = dict_dir + "/idf.utf8";
|
||||
std::string stop_word = dict_dir + "/stop_words.utf8";
|
||||
|
||||
AssertFileExists(dict);
|
||||
AssertFileExists(hmm);
|
||||
AssertFileExists(user_dict);
|
||||
AssertFileExists(idf);
|
||||
AssertFileExists(stop_word);
|
||||
|
||||
jieba_ =
|
||||
std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
|
||||
jieba_ = InitJieba(dict_dir);
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, tokens);
|
||||
|
||||
@@ -239,6 +239,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
|
||||
model_->SubsamplingFactor());
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||
ss[i]->SetResult(r);
|
||||
}
|
||||
}
|
||||
@@ -277,6 +278,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
auto r = Convert(results[0], symbol_table_, frame_shift_ms,
|
||||
model_->SubsamplingFactor());
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||
s->SetResult(r);
|
||||
}
|
||||
|
||||
|
||||
@@ -125,6 +125,7 @@ class OfflineRecognizerFireRedAsrImpl : public OfflineRecognizerImpl {
|
||||
auto r = Convert(results[0], symbol_table_);
|
||||
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||
s->SetResult(r);
|
||||
}
|
||||
|
||||
|
||||
@@ -408,6 +408,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
OfflineRecognizerImpl::OfflineRecognizerImpl(
|
||||
const OfflineRecognizerConfig &config)
|
||||
: config_(config) {
|
||||
// TODO(fangjun): Refactor this function
|
||||
|
||||
if (!config.rule_fsts.empty()) {
|
||||
std::vector<std::string> files;
|
||||
SplitStringToVector(config.rule_fsts, ",", false, &files);
|
||||
@@ -448,6 +450,13 @@ OfflineRecognizerImpl::OfflineRecognizerImpl(
|
||||
SHERPA_ONNX_LOGE("FST archives loaded!");
|
||||
}
|
||||
}
|
||||
|
||||
if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() &&
|
||||
!config.hr.rule_fsts.empty()) {
|
||||
auto hr_config = config.hr;
|
||||
hr_config.debug = config.model_config.debug;
|
||||
hr_ = std::make_unique<HomophoneReplacer>(hr_config);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Manager>
|
||||
@@ -495,6 +504,13 @@ OfflineRecognizerImpl::OfflineRecognizerImpl(
|
||||
} // for (; !reader->Done(); reader->Next())
|
||||
} // for (const auto &f : files)
|
||||
} // if (!config.rule_fars.empty())
|
||||
|
||||
if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() &&
|
||||
!config.hr.rule_fsts.empty()) {
|
||||
auto hr_config = config.hr;
|
||||
hr_config.debug = config.model_config.debug;
|
||||
hr_ = std::make_unique<HomophoneReplacer>(mgr, hr_config);
|
||||
}
|
||||
}
|
||||
|
||||
std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
|
||||
@@ -510,6 +526,15 @@ std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
|
||||
return text;
|
||||
}
|
||||
|
||||
std::string OfflineRecognizerImpl::ApplyHomophoneReplacer(
|
||||
std::string text) const {
|
||||
if (hr_) {
|
||||
text = hr_->Apply(text);
|
||||
}
|
||||
|
||||
return text;
|
||||
}
|
||||
|
||||
void OfflineRecognizerImpl::SetConfig(const OfflineRecognizerConfig &config) {
|
||||
config_ = config;
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "kaldifst/csrc/text-normalizer.h"
|
||||
#include "sherpa-onnx/csrc/homophone-replacer.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||
@@ -48,12 +49,15 @@ class OfflineRecognizerImpl {
|
||||
|
||||
std::string ApplyInverseTextNormalization(std::string text) const;
|
||||
|
||||
std::string ApplyHomophoneReplacer(std::string text) const;
|
||||
|
||||
private:
|
||||
OfflineRecognizerConfig config_;
|
||||
// for inverse text normalization. Used only if
|
||||
// config.rule_fsts is not empty or
|
||||
// config.rule_fars is not empty
|
||||
std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_;
|
||||
std::unique_ptr<HomophoneReplacer> hr_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -121,6 +121,7 @@ class OfflineRecognizerMoonshineImpl : public OfflineRecognizerImpl {
|
||||
|
||||
auto r = Convert(results[0], symbol_table_);
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||
s->SetResult(r);
|
||||
} catch (const Ort::Exception &ex) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
|
||||
@@ -197,6 +197,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
auto r = Convert(results[i], symbol_table_);
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||
ss[i]->SetResult(r);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,6 +222,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl {
|
||||
auto r = ConvertSenseVoiceResult(results[i], symbol_table_,
|
||||
frame_shift_ms, subsampling_factor);
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||
ss[i]->SetResult(r);
|
||||
}
|
||||
}
|
||||
@@ -295,6 +296,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl {
|
||||
subsampling_factor);
|
||||
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||
s->SetResult(r);
|
||||
}
|
||||
|
||||
|
||||
@@ -239,6 +239,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
|
||||
model_->SubsamplingFactor());
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||
|
||||
ss[i]->SetResult(r);
|
||||
}
|
||||
|
||||
@@ -128,6 +128,7 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
|
||||
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
|
||||
model_->SubsamplingFactor());
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||
|
||||
ss[i]->SetResult(r);
|
||||
}
|
||||
|
||||
@@ -160,6 +160,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||
|
||||
std::string s = sym_table[i];
|
||||
s = ApplyInverseTextNormalization(s);
|
||||
s = ApplyHomophoneReplacer(std::move(s));
|
||||
|
||||
text += s;
|
||||
r.tokens.push_back(s);
|
||||
|
||||
@@ -28,6 +28,7 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
|
||||
model_config.Register(po);
|
||||
lm_config.Register(po);
|
||||
ctc_fst_decoder_config.Register(po);
|
||||
hr.Register(po);
|
||||
|
||||
po->Register(
|
||||
"decoding-method", &decoding_method,
|
||||
@@ -120,6 +121,11 @@ bool OfflineRecognizerConfig::Validate() const {
|
||||
}
|
||||
}
|
||||
|
||||
if (!hr.dict_dir.empty() && !hr.lexicon.empty() && !hr.rule_fsts.empty() &&
|
||||
!hr.Validate()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return model_config.Validate();
|
||||
}
|
||||
|
||||
@@ -137,7 +143,8 @@ std::string OfflineRecognizerConfig::ToString() const {
|
||||
os << "hotwords_score=" << hotwords_score << ", ";
|
||||
os << "blank_penalty=" << blank_penalty << ", ";
|
||||
os << "rule_fsts=\"" << rule_fsts << "\", ";
|
||||
os << "rule_fars=\"" << rule_fars << "\")";
|
||||
os << "rule_fars=\"" << rule_fars << "\", ";
|
||||
os << "hr=" << hr.ToString() << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
#include "sherpa-onnx/csrc/homophone-replacer.h"
|
||||
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
@@ -40,6 +41,7 @@ struct OfflineRecognizerConfig {
|
||||
|
||||
// If there are multiple FST archives, they are applied from left to right.
|
||||
std::string rule_fars;
|
||||
HomophoneReplacerConfig hr;
|
||||
|
||||
// only greedy_search is implemented
|
||||
// TODO(fangjun): Implement modified_beam_search
|
||||
@@ -52,7 +54,7 @@ struct OfflineRecognizerConfig {
|
||||
const std::string &decoding_method, int32_t max_active_paths,
|
||||
const std::string &hotwords_file, float hotwords_score,
|
||||
float blank_penalty, const std::string &rule_fsts,
|
||||
const std::string &rule_fars)
|
||||
const std::string &rule_fars, const HomophoneReplacerConfig &hr)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
lm_config(lm_config),
|
||||
@@ -63,7 +65,8 @@ struct OfflineRecognizerConfig {
|
||||
hotwords_score(hotwords_score),
|
||||
blank_penalty(blank_penalty),
|
||||
rule_fsts(rule_fsts),
|
||||
rule_fars(rule_fars) {}
|
||||
rule_fars(rule_fars),
|
||||
hr(hr) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -201,7 +201,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
auto r =
|
||||
ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
||||
r.text = ApplyInverseTextNormalization(r.text);
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
@@ -192,6 +192,13 @@ OnlineRecognizerImpl::OnlineRecognizerImpl(const OnlineRecognizerConfig &config)
|
||||
SHERPA_ONNX_LOGE("FST archives loaded!");
|
||||
}
|
||||
}
|
||||
|
||||
if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() &&
|
||||
!config.hr.rule_fsts.empty()) {
|
||||
auto hr_config = config.hr;
|
||||
hr_config.debug = config.model_config.debug;
|
||||
hr_ = std::make_unique<HomophoneReplacer>(hr_config);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Manager>
|
||||
@@ -239,6 +246,12 @@ OnlineRecognizerImpl::OnlineRecognizerImpl(Manager *mgr,
|
||||
} // for (; !reader->Done(); reader->Next())
|
||||
} // for (const auto &f : files)
|
||||
} // if (!config.rule_fars.empty())
|
||||
if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() &&
|
||||
!config.hr.rule_fsts.empty()) {
|
||||
auto hr_config = config.hr;
|
||||
hr_config.debug = config.model_config.debug;
|
||||
hr_ = std::make_unique<HomophoneReplacer>(mgr, hr_config);
|
||||
}
|
||||
}
|
||||
|
||||
std::string OnlineRecognizerImpl::ApplyInverseTextNormalization(
|
||||
@@ -254,6 +267,15 @@ std::string OnlineRecognizerImpl::ApplyInverseTextNormalization(
|
||||
return text;
|
||||
}
|
||||
|
||||
std::string OnlineRecognizerImpl::ApplyHomophoneReplacer(
|
||||
std::string text) const {
|
||||
if (hr_) {
|
||||
text = hr_->Apply(text);
|
||||
}
|
||||
|
||||
return text;
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
template OnlineRecognizerImpl::OnlineRecognizerImpl(
|
||||
AAssetManager *mgr, const OnlineRecognizerConfig &config);
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "kaldifst/csrc/text-normalizer.h"
|
||||
#include "sherpa-onnx/csrc/homophone-replacer.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
@@ -57,6 +58,7 @@ class OnlineRecognizerImpl {
|
||||
virtual void Reset(OnlineStream *s) const = 0;
|
||||
|
||||
std::string ApplyInverseTextNormalization(std::string text) const;
|
||||
std::string ApplyHomophoneReplacer(std::string text) const;
|
||||
|
||||
private:
|
||||
OnlineRecognizerConfig config_;
|
||||
@@ -64,6 +66,7 @@ class OnlineRecognizerImpl {
|
||||
// config.rule_fsts is not empty or
|
||||
// config.rule_fars is not empty
|
||||
std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_;
|
||||
std::unique_ptr<HomophoneReplacer> hr_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -169,7 +169,8 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
|
||||
auto decoder_result = s->GetParaformerResult();
|
||||
|
||||
auto r = Convert(decoder_result, sym_);
|
||||
r.text = ApplyInverseTextNormalization(r.text);
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
@@ -349,6 +349,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -391,15 +392,14 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
// (the encoder state buffers are kept)
|
||||
for (const auto &it : last_result.hyps) {
|
||||
auto h = it.second;
|
||||
r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size,
|
||||
h.ys.end()),
|
||||
r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size, h.ys.end()),
|
||||
h.log_prob});
|
||||
}
|
||||
|
||||
r.tokens = std::vector<int64_t> (last_result.tokens.end() - context_size,
|
||||
last_result.tokens.end());
|
||||
r.tokens = std::vector<int64_t>(last_result.tokens.end() - context_size,
|
||||
last_result.tokens.end());
|
||||
} else {
|
||||
if(config_.reset_encoder) {
|
||||
if (config_.reset_encoder) {
|
||||
// reset encoder states, use blanks as 'ys' context
|
||||
s->SetStates(model_->GetEncoderInitStates());
|
||||
}
|
||||
|
||||
@@ -100,6 +100,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
|
||||
subsampling_factor, s->GetCurrentSegment(),
|
||||
s->GetNumFramesSinceStart());
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
@@ -88,6 +88,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
endpoint_config.Register(po);
|
||||
lm_config.Register(po);
|
||||
ctc_fst_decoder_config.Register(po);
|
||||
hr.Register(po);
|
||||
|
||||
po->Register("enable-endpoint", &enable_endpoint,
|
||||
"True to enable endpoint detection. False to disable it.");
|
||||
@@ -182,6 +183,11 @@ bool OnlineRecognizerConfig::Validate() const {
|
||||
}
|
||||
}
|
||||
|
||||
if (!hr.dict_dir.empty() && !hr.lexicon.empty() && !hr.rule_fsts.empty() &&
|
||||
!hr.Validate()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return model_config.Validate();
|
||||
}
|
||||
|
||||
@@ -203,7 +209,8 @@ std::string OnlineRecognizerConfig::ToString() const {
|
||||
os << "temperature_scale=" << temperature_scale << ", ";
|
||||
os << "rule_fsts=\"" << rule_fsts << "\", ";
|
||||
os << "rule_fars=\"" << rule_fars << "\", ";
|
||||
os << "reset_encoder=\"" << (reset_encoder ? "True" : "False") << "\")";
|
||||
os << "reset_encoder=" << (reset_encoder ? "True" : "False") << ", ";
|
||||
os << "hr=" << hr.ToString() << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/endpoint.h"
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
#include "sherpa-onnx/csrc/homophone-replacer.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
|
||||
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||
@@ -107,6 +108,8 @@ struct OnlineRecognizerConfig {
|
||||
// currently only in `OnlineRecognizerTransducerImpl`.
|
||||
bool reset_encoder = false;
|
||||
|
||||
HomophoneReplacerConfig hr;
|
||||
|
||||
/// used only for modified_beam_search, if hotwords_buf is non-empty,
|
||||
/// the hotwords will be loaded from the buffered string instead of from the
|
||||
/// "hotwords_file"
|
||||
@@ -123,7 +126,7 @@ struct OnlineRecognizerConfig {
|
||||
int32_t max_active_paths, const std::string &hotwords_file,
|
||||
float hotwords_score, float blank_penalty, float temperature_scale,
|
||||
const std::string &rule_fsts, const std::string &rule_fars,
|
||||
bool reset_encoder)
|
||||
bool reset_encoder, const HomophoneReplacerConfig &hr)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
lm_config(lm_config),
|
||||
@@ -138,7 +141,8 @@ struct OnlineRecognizerConfig {
|
||||
temperature_scale(temperature_scale),
|
||||
rule_fsts(rule_fsts),
|
||||
rule_fars(rule_fars),
|
||||
reset_encoder(reset_encoder) {}
|
||||
reset_encoder(reset_encoder),
|
||||
hr(hr) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -89,7 +89,8 @@ class OnlineRecognizerCtcRknnImpl : public OnlineRecognizerImpl {
|
||||
auto r =
|
||||
ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
||||
r.text = ApplyInverseTextNormalization(r.text);
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
@@ -177,6 +177,7 @@ class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl {
|
||||
auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ set(srcs
|
||||
display.cc
|
||||
endpoint.cc
|
||||
features.cc
|
||||
homophone-replacer.cc
|
||||
keyword-spotter.cc
|
||||
offline-ctc-fst-decoder-config.cc
|
||||
offline-dolphin-model-config.cc
|
||||
|
||||
28
sherpa-onnx/python/csrc/homophone-replacer.cc
Normal file
28
sherpa-onnx/python/csrc/homophone-replacer.cc
Normal file
@@ -0,0 +1,28 @@
|
||||
// sherpa-onnx/python/csrc/homophone-replacer.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/homophone-replacer.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/homophone-replacer.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindHomophoneReplacer(py::module *m) {
|
||||
using PyClass = HomophoneReplacerConfig;
|
||||
py::class_<PyClass>(*m, "HomophoneReplacerConfig")
|
||||
.def(py::init<>())
|
||||
.def(py::init<const std::string &, const std::string &,
|
||||
const std::string &, bool>(),
|
||||
py::arg("dict_dir"), py::arg("lexicon"), py::arg("rule_fsts"),
|
||||
py::arg("debug") = false)
|
||||
.def_readwrite("dict_dir", &PyClass::dict_dir)
|
||||
.def_readwrite("lexicon", &PyClass::lexicon)
|
||||
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
|
||||
.def_readwrite("debug", &PyClass::debug)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/homophone-replacer.h
Normal file
16
sherpa-onnx/python/csrc/homophone-replacer.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/homophone-replacer.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindHomophoneReplacer(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_
|
||||
@@ -17,14 +17,16 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
|
||||
.def(py::init<const FeatureExtractorConfig &, const OfflineModelConfig &,
|
||||
const OfflineLMConfig &, const OfflineCtcFstDecoderConfig &,
|
||||
const std::string &, int32_t, const std::string &, float,
|
||||
float, const std::string &, const std::string &>(),
|
||||
float, const std::string &, const std::string &,
|
||||
const HomophoneReplacerConfig &>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("lm_config") = OfflineLMConfig(),
|
||||
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
|
||||
py::arg("decoding_method") = "greedy_search",
|
||||
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
||||
py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0,
|
||||
py::arg("rule_fsts") = "", py::arg("rule_fars") = "")
|
||||
py::arg("rule_fsts") = "", py::arg("rule_fars") = "",
|
||||
py::arg("hr") = HomophoneReplacerConfig{})
|
||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||
.def_readwrite("model_config", &PyClass::model_config)
|
||||
.def_readwrite("lm_config", &PyClass::lm_config)
|
||||
@@ -36,6 +38,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
|
||||
.def_readwrite("blank_penalty", &PyClass::blank_penalty)
|
||||
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
|
||||
.def_readwrite("rule_fars", &PyClass::rule_fars)
|
||||
.def_readwrite("hr", &PyClass::hr)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
|
||||
@@ -58,7 +58,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
||||
const OnlineLMConfig &, const EndpointConfig &,
|
||||
const OnlineCtcFstDecoderConfig &, bool,
|
||||
const std::string &, int32_t, const std::string &, float,
|
||||
float, float, const std::string &, const std::string &, bool>(),
|
||||
float, float, const std::string &, const std::string &,
|
||||
bool, const HomophoneReplacerConfig &>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("lm_config") = OnlineLMConfig(),
|
||||
py::arg("endpoint_config") = EndpointConfig(),
|
||||
@@ -67,7 +68,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
||||
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
||||
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0,
|
||||
py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "",
|
||||
py::arg("rule_fars") = "", py::arg("reset_encoder") = false)
|
||||
py::arg("rule_fars") = "", py::arg("reset_encoder") = false,
|
||||
py::arg("hr") = HomophoneReplacerConfig{})
|
||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||
.def_readwrite("model_config", &PyClass::model_config)
|
||||
.def_readwrite("lm_config", &PyClass::lm_config)
|
||||
@@ -83,6 +85,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
||||
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
|
||||
.def_readwrite("rule_fars", &PyClass::rule_fars)
|
||||
.def_readwrite("reset_encoder", &PyClass::reset_encoder)
|
||||
.def_readwrite("hr", &PyClass::hr)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "sherpa-onnx/python/csrc/display.h"
|
||||
#include "sherpa-onnx/python/csrc/endpoint.h"
|
||||
#include "sherpa-onnx/python/csrc/features.h"
|
||||
#include "sherpa-onnx/python/csrc/homophone-replacer.h"
|
||||
#include "sherpa-onnx/python/csrc/keyword-spotter.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-lm-config.h"
|
||||
@@ -51,6 +52,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
||||
PybindAudioTagging(&m);
|
||||
PybindOfflinePunctuation(&m);
|
||||
PybindOnlinePunctuation(&m);
|
||||
PybindHomophoneReplacer(&m);
|
||||
|
||||
PybindFeatures(&m);
|
||||
PybindOnlineCtcFstDecoderConfig(&m);
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import List, Optional
|
||||
|
||||
from _sherpa_onnx import (
|
||||
FeatureExtractorConfig,
|
||||
HomophoneReplacerConfig,
|
||||
OfflineCtcFstDecoderConfig,
|
||||
OfflineDolphinModelConfig,
|
||||
OfflineFireRedAsrModelConfig,
|
||||
@@ -64,6 +65,9 @@ class OfflineRecognizer(object):
|
||||
rule_fars: str = "",
|
||||
lm: str = "",
|
||||
lm_scale: float = 0.1,
|
||||
hr_dict_dir: str = "",
|
||||
hr_rule_fsts: str = "",
|
||||
hr_lexicon: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -181,6 +185,11 @@ class OfflineRecognizer(object):
|
||||
blank_penalty=blank_penalty,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
hr=HomophoneReplacerConfig(
|
||||
dict_dir=hr_dict_dir,
|
||||
lexicon=hr_lexicon,
|
||||
rule_fsts=hr_rule_fsts,
|
||||
),
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
@@ -201,6 +210,9 @@ class OfflineRecognizer(object):
|
||||
use_itn: bool = False,
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
hr_dict_dir: str = "",
|
||||
hr_rule_fsts: str = "",
|
||||
hr_lexicon: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -263,6 +275,11 @@ class OfflineRecognizer(object):
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
hr=HomophoneReplacerConfig(
|
||||
dict_dir=hr_dict_dir,
|
||||
lexicon=hr_lexicon,
|
||||
rule_fsts=hr_rule_fsts,
|
||||
),
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
@@ -281,6 +298,9 @@ class OfflineRecognizer(object):
|
||||
provider: str = "cpu",
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
hr_dict_dir: str = "",
|
||||
hr_rule_fsts: str = "",
|
||||
hr_lexicon: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -336,6 +356,11 @@ class OfflineRecognizer(object):
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
hr=HomophoneReplacerConfig(
|
||||
dict_dir=hr_dict_dir,
|
||||
lexicon=hr_lexicon,
|
||||
rule_fsts=hr_rule_fsts,
|
||||
),
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
@@ -354,6 +379,9 @@ class OfflineRecognizer(object):
|
||||
provider: str = "cpu",
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
hr_dict_dir: str = "",
|
||||
hr_rule_fsts: str = "",
|
||||
hr_lexicon: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -411,6 +439,9 @@ class OfflineRecognizer(object):
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
hr=HomophoneReplacerConfig(
|
||||
dict_dir=hr_dict_dir, lexicon=hr_lexicon, rule_fsts=hr_rule_fsts
|
||||
),
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
@@ -429,6 +460,9 @@ class OfflineRecognizer(object):
|
||||
provider: str = "cpu",
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
hr_dict_dir: str = "",
|
||||
hr_rule_fsts: str = "",
|
||||
hr_lexicon: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -483,6 +517,11 @@ class OfflineRecognizer(object):
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
hr=HomophoneReplacerConfig(
|
||||
dict_dir=hr_dict_dir,
|
||||
lexicon=hr_lexicon,
|
||||
rule_fsts=hr_rule_fsts,
|
||||
),
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
@@ -501,6 +540,9 @@ class OfflineRecognizer(object):
|
||||
provider: str = "cpu",
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
hr_dict_dir: str = "",
|
||||
hr_rule_fsts: str = "",
|
||||
hr_lexicon: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -557,6 +599,11 @@ class OfflineRecognizer(object):
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
hr=HomophoneReplacerConfig(
|
||||
dict_dir=hr_dict_dir,
|
||||
lexicon=hr_lexicon,
|
||||
rule_fsts=hr_rule_fsts,
|
||||
),
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
@@ -577,6 +624,9 @@ class OfflineRecognizer(object):
|
||||
tail_paddings: int = -1,
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
hr_dict_dir: str = "",
|
||||
hr_rule_fsts: str = "",
|
||||
hr_lexicon: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -647,6 +697,11 @@ class OfflineRecognizer(object):
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
hr=HomophoneReplacerConfig(
|
||||
dict_dir=hr_dict_dir,
|
||||
lexicon=hr_lexicon,
|
||||
rule_fsts=hr_rule_fsts,
|
||||
),
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
@@ -664,6 +719,9 @@ class OfflineRecognizer(object):
|
||||
provider: str = "cpu",
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
hr_dict_dir: str = "",
|
||||
hr_rule_fsts: str = "",
|
||||
hr_lexicon: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -719,6 +777,11 @@ class OfflineRecognizer(object):
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
hr=HomophoneReplacerConfig(
|
||||
dict_dir=hr_dict_dir,
|
||||
lexicon=hr_lexicon,
|
||||
rule_fsts=hr_rule_fsts,
|
||||
),
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
@@ -738,6 +801,9 @@ class OfflineRecognizer(object):
|
||||
provider: str = "cpu",
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
hr_dict_dir: str = "",
|
||||
hr_rule_fsts: str = "",
|
||||
hr_lexicon: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -800,6 +866,11 @@ class OfflineRecognizer(object):
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
hr=HomophoneReplacerConfig(
|
||||
dict_dir=hr_dict_dir,
|
||||
lexicon=hr_lexicon,
|
||||
rule_fsts=hr_rule_fsts,
|
||||
),
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
@@ -818,6 +889,9 @@ class OfflineRecognizer(object):
|
||||
provider: str = "cpu",
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
hr_dict_dir: str = "",
|
||||
hr_rule_fsts: str = "",
|
||||
hr_lexicon: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -873,6 +947,11 @@ class OfflineRecognizer(object):
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
hr=HomophoneReplacerConfig(
|
||||
dict_dir=hr_dict_dir,
|
||||
lexicon=hr_lexicon,
|
||||
rule_fsts=hr_rule_fsts,
|
||||
),
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
@@ -891,6 +970,9 @@ class OfflineRecognizer(object):
|
||||
provider: str = "cpu",
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
hr_dict_dir: str = "",
|
||||
hr_rule_fsts: str = "",
|
||||
hr_lexicon: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -947,6 +1029,11 @@ class OfflineRecognizer(object):
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
hr=HomophoneReplacerConfig(
|
||||
dict_dir=hr_dict_dir,
|
||||
lexicon=hr_lexicon,
|
||||
rule_fsts=hr_rule_fsts,
|
||||
),
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
|
||||
@@ -3,25 +3,26 @@ from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from _sherpa_onnx import (
|
||||
CudaConfig,
|
||||
EndpointConfig,
|
||||
FeatureExtractorConfig,
|
||||
HomophoneReplacerConfig,
|
||||
OnlineCtcFstDecoderConfig,
|
||||
OnlineLMConfig,
|
||||
OnlineModelConfig,
|
||||
OnlineNeMoCtcModelConfig,
|
||||
OnlineParaformerModelConfig,
|
||||
)
|
||||
from _sherpa_onnx import OnlineRecognizer as _Recognizer
|
||||
from _sherpa_onnx import (
|
||||
CudaConfig,
|
||||
TensorrtConfig,
|
||||
ProviderConfig,
|
||||
OnlineRecognizerConfig,
|
||||
OnlineRecognizerResult,
|
||||
OnlineStream,
|
||||
OnlineTransducerModelConfig,
|
||||
OnlineWenetCtcModelConfig,
|
||||
OnlineNeMoCtcModelConfig,
|
||||
OnlineZipformer2CtcModelConfig,
|
||||
OnlineCtcFstDecoderConfig,
|
||||
ProviderConfig,
|
||||
TensorrtConfig,
|
||||
)
|
||||
|
||||
|
||||
@@ -82,9 +83,12 @@ class OnlineRecognizer(object):
|
||||
trt_detailed_build_log: bool = False,
|
||||
trt_engine_cache_enable: bool = True,
|
||||
trt_timing_cache_enable: bool = True,
|
||||
trt_engine_cache_path: str ="",
|
||||
trt_timing_cache_path: str ="",
|
||||
trt_engine_cache_path: str = "",
|
||||
trt_timing_cache_path: str = "",
|
||||
trt_dump_subgraphs: bool = False,
|
||||
hr_dict_dir: str = "",
|
||||
hr_rule_fsts: str = "",
|
||||
hr_lexicon: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -228,27 +232,27 @@ class OnlineRecognizer(object):
|
||||
)
|
||||
|
||||
cuda_config = CudaConfig(
|
||||
cudnn_conv_algo_search=cudnn_conv_algo_search,
|
||||
cudnn_conv_algo_search=cudnn_conv_algo_search,
|
||||
)
|
||||
|
||||
trt_config = TensorrtConfig(
|
||||
trt_max_workspace_size=trt_max_workspace_size,
|
||||
trt_max_partition_iterations=trt_max_partition_iterations,
|
||||
trt_min_subgraph_size=trt_min_subgraph_size,
|
||||
trt_fp16_enable=trt_fp16_enable,
|
||||
trt_detailed_build_log=trt_detailed_build_log,
|
||||
trt_engine_cache_enable=trt_engine_cache_enable,
|
||||
trt_timing_cache_enable=trt_timing_cache_enable,
|
||||
trt_engine_cache_path=trt_engine_cache_path,
|
||||
trt_timing_cache_path=trt_timing_cache_path,
|
||||
trt_dump_subgraphs=trt_dump_subgraphs,
|
||||
trt_max_workspace_size=trt_max_workspace_size,
|
||||
trt_max_partition_iterations=trt_max_partition_iterations,
|
||||
trt_min_subgraph_size=trt_min_subgraph_size,
|
||||
trt_fp16_enable=trt_fp16_enable,
|
||||
trt_detailed_build_log=trt_detailed_build_log,
|
||||
trt_engine_cache_enable=trt_engine_cache_enable,
|
||||
trt_timing_cache_enable=trt_timing_cache_enable,
|
||||
trt_engine_cache_path=trt_engine_cache_path,
|
||||
trt_timing_cache_path=trt_timing_cache_path,
|
||||
trt_dump_subgraphs=trt_dump_subgraphs,
|
||||
)
|
||||
|
||||
provider_config = ProviderConfig(
|
||||
trt_config=trt_config,
|
||||
cuda_config=cuda_config,
|
||||
provider=provider,
|
||||
device=device,
|
||||
trt_config=trt_config,
|
||||
cuda_config=cuda_config,
|
||||
provider=provider,
|
||||
device=device,
|
||||
)
|
||||
|
||||
model_config = OnlineModelConfig(
|
||||
@@ -311,6 +315,11 @@ class OnlineRecognizer(object):
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
reset_encoder=reset_encoder,
|
||||
hr=HomophoneReplacerConfig(
|
||||
dict_dir=hr_dict_dir,
|
||||
lexicon=hr_lexicon,
|
||||
rule_fsts=hr_rule_fsts,
|
||||
),
|
||||
)
|
||||
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
@@ -336,6 +345,9 @@ class OnlineRecognizer(object):
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
device: int = 0,
|
||||
hr_dict_dir: str = "",
|
||||
hr_rule_fsts: str = "",
|
||||
hr_lexicon: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -402,8 +414,8 @@ class OnlineRecognizer(object):
|
||||
)
|
||||
|
||||
provider_config = ProviderConfig(
|
||||
provider=provider,
|
||||
device=device,
|
||||
provider=provider,
|
||||
device=device,
|
||||
)
|
||||
|
||||
model_config = OnlineModelConfig(
|
||||
@@ -434,6 +446,11 @@ class OnlineRecognizer(object):
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
hr=HomophoneReplacerConfig(
|
||||
dict_dir=hr_dict_dir,
|
||||
lexicon=hr_lexicon,
|
||||
rule_fsts=hr_rule_fsts,
|
||||
),
|
||||
)
|
||||
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
@@ -460,6 +477,9 @@ class OnlineRecognizer(object):
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
device: int = 0,
|
||||
hr_dict_dir: str = "",
|
||||
hr_rule_fsts: str = "",
|
||||
hr_lexicon: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -526,8 +546,8 @@ class OnlineRecognizer(object):
|
||||
zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model)
|
||||
|
||||
provider_config = ProviderConfig(
|
||||
provider=provider,
|
||||
device=device,
|
||||
provider=provider,
|
||||
device=device,
|
||||
)
|
||||
|
||||
model_config = OnlineModelConfig(
|
||||
@@ -563,6 +583,11 @@ class OnlineRecognizer(object):
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
hr=HomophoneReplacerConfig(
|
||||
dict_dir=hr_dict_dir,
|
||||
lexicon=hr_lexicon,
|
||||
rule_fsts=hr_rule_fsts,
|
||||
),
|
||||
)
|
||||
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
@@ -587,6 +612,9 @@ class OnlineRecognizer(object):
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
device: int = 0,
|
||||
hr_dict_dir: str = "",
|
||||
hr_rule_fsts: str = "",
|
||||
hr_lexicon: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -650,8 +678,8 @@ class OnlineRecognizer(object):
|
||||
)
|
||||
|
||||
provider_config = ProviderConfig(
|
||||
provider=provider,
|
||||
device=device,
|
||||
provider=provider,
|
||||
device=device,
|
||||
)
|
||||
|
||||
model_config = OnlineModelConfig(
|
||||
@@ -681,6 +709,11 @@ class OnlineRecognizer(object):
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
hr=HomophoneReplacerConfig(
|
||||
dict_dir=hr_dict_dir,
|
||||
lexicon=hr_lexicon,
|
||||
rule_fsts=hr_rule_fsts,
|
||||
),
|
||||
)
|
||||
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
@@ -707,6 +740,9 @@ class OnlineRecognizer(object):
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
device: int = 0,
|
||||
hr_dict_dir: str = "",
|
||||
hr_rule_fsts: str = "",
|
||||
hr_lexicon: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -775,8 +811,8 @@ class OnlineRecognizer(object):
|
||||
)
|
||||
|
||||
provider_config = ProviderConfig(
|
||||
provider=provider,
|
||||
device=device,
|
||||
provider=provider,
|
||||
device=device,
|
||||
)
|
||||
|
||||
model_config = OnlineModelConfig(
|
||||
@@ -806,6 +842,11 @@ class OnlineRecognizer(object):
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
hr=HomophoneReplacerConfig(
|
||||
dict_dir=hr_dict_dir,
|
||||
lexicon=hr_lexicon,
|
||||
rule_fsts=hr_rule_fsts,
|
||||
),
|
||||
)
|
||||
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
|
||||
Reference in New Issue
Block a user