Add inverse text normalization for online ASR (#1020)

This commit is contained in:
Fangjun Kuang
2024-06-17 18:39:23 +08:00
committed by GitHub
parent 6e09933d99
commit 349d957da2
12 changed files with 390 additions and 32 deletions

View File

@@ -256,7 +256,18 @@ if [[ x$OS != x'windows-latest' ]]; then
$repo/test_wavs/3.wav \ $repo/test_wavs/3.wav \
$repo/test_wavs/8k.wav $repo/test_wavs/8k.wav
ln -s $repo $PWD/
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav
python3 ./python-api-examples/inverse-text-normalization-online-asr.py
python3 sherpa-onnx/python/tests/test_online_recognizer.py --verbose python3 sherpa-onnx/python/tests/test_online_recognizer.py --verbose
rm -rfv sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20
rm -rf $repo
fi fi
log "Test non-streaming transducer models" log "Test non-streaming transducer models"

View File

@@ -0,0 +1,91 @@
#!/usr/bin/env python3
#
# Copyright (c) 2024 Xiaomi Corporation
"""
This script shows how to use inverse text normalization with streaming ASR.
Usage:
(1) Download the test model
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
(2) Download rule fst
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst
Please refer to
https://github.com/k2-fsa/colab/blob/master/sherpa-onnx/itn_zh_number.ipynb
for how itn_zh_number.fst is generated.
(3) Download test wave
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav
(4) Run this script
python3 ./python-api-examples/inverse-text-normalization-online-asr.py
"""
from pathlib import Path
import sherpa_onnx
import soundfile as sf
def create_recognizer():
    """Create a streaming recognizer with inverse text normalization enabled.

    Returns:
      A sherpa_onnx.OnlineRecognizer built from a streaming zipformer
      transducer model, with ``rule_fsts`` pointing to an ITN rule FST
      so that recognition results are inverse-text-normalized.

    Raises:
      ValueError: If any required model file or the rule FST is missing,
        listing exactly which files were not found.
    """
    model_dir = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
    encoder = f"{model_dir}/encoder-epoch-99-avg-1.int8.onnx"
    decoder = f"{model_dir}/decoder-epoch-99-avg-1.onnx"
    joiner = f"{model_dir}/joiner-epoch-99-avg-1.int8.onnx"
    tokens = f"{model_dir}/tokens.txt"
    rule_fsts = "./itn_zh_number.fst"

    # Report exactly which files are missing instead of a generic message.
    missing = [
        f
        for f in (encoder, decoder, joiner, tokens, rule_fsts)
        if not Path(f).is_file()
    ]
    if missing:
        raise ValueError(
            "The following files do not exist:\n  "
            + "\n  ".join(missing)
            + "\nPlease download them from\n"
            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models"
        )

    return sherpa_onnx.OnlineRecognizer.from_transducer(
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
        tokens=tokens,
        debug=True,
        rule_fsts=rule_fsts,
    )
def main():
    """Decode the test wave file and print the ITN-normalized result.

    Raises:
      ValueError: If the test wave file has not been downloaded.
    """
    recognizer = create_recognizer()

    wave_filename = "./itn-zh-number.wav"
    if not Path(wave_filename).is_file():
        # Fixed message: it is the test WAVE that is missing here,
        # not the model files.
        raise ValueError(
            f"'{wave_filename}' does not exist. "
            "Please download the test wave file from\n"
            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models"
        )

    # always_2d=True guarantees shape (num_samples, num_channels).
    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, audio)

    # Append ~0.3 s of silence so the final frames are flushed
    # through the streaming model.
    tail_padding = [0] * int(0.3 * sample_rate)
    stream.accept_waveform(sample_rate, tail_padding)

    while recognizer.is_ready(stream):
        recognizer.decode_stream(stream)

    print(wave_filename)
    print(recognizer.get_result_all(stream))
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

View File

@@ -68,7 +68,8 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
public: public:
explicit OnlineRecognizerCtcImpl(const OnlineRecognizerConfig &config) explicit OnlineRecognizerCtcImpl(const OnlineRecognizerConfig &config)
: config_(config), : OnlineRecognizerImpl(config),
config_(config),
model_(OnlineCtcModel::Create(config.model_config)), model_(OnlineCtcModel::Create(config.model_config)),
sym_(config.model_config.tokens), sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) { endpoint_(config_.endpoint_config) {
@@ -84,7 +85,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
#if __ANDROID_API__ >= 9 #if __ANDROID_API__ >= 9
explicit OnlineRecognizerCtcImpl(AAssetManager *mgr, explicit OnlineRecognizerCtcImpl(AAssetManager *mgr,
const OnlineRecognizerConfig &config) const OnlineRecognizerConfig &config)
: config_(config), : OnlineRecognizerImpl(mgr, config),
config_(config),
model_(OnlineCtcModel::Create(mgr, config.model_config)), model_(OnlineCtcModel::Create(mgr, config.model_config)),
sym_(mgr, config.model_config.tokens), sym_(mgr, config.model_config.tokens),
endpoint_(config_.endpoint_config) { endpoint_(config_.endpoint_config) {
@@ -182,8 +184,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
// TODO(fangjun): Remember to change these constants if needed // TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10; int32_t frame_shift_ms = 10;
int32_t subsampling_factor = 4; int32_t subsampling_factor = 4;
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment(), s->GetNumFramesSinceStart()); s->GetCurrentSegment(), s->GetNumFramesSinceStart());
r.text = ApplyInverseTextNormalization(r.text);
return r;
} }
bool IsEndpoint(OnlineStream *s) const override { bool IsEndpoint(OnlineStream *s) const override {

View File

@@ -4,11 +4,22 @@
#include "sherpa-onnx/csrc/online-recognizer-impl.h" #include "sherpa-onnx/csrc/online-recognizer-impl.h"
#if __ANDROID_API__ >= 9
#include <strstream>
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "fst/extensions/far/far.h"
#include "kaldifst/csrc/kaldi-fst-io.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-recognizer-ctc-impl.h" #include "sherpa-onnx/csrc/online-recognizer-ctc-impl.h"
#include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h" #include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h" #include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h" #include "sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h"
#include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx { namespace sherpa_onnx {
@@ -78,4 +89,110 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
} }
#endif #endif
OnlineRecognizerImpl::OnlineRecognizerImpl(const OnlineRecognizerConfig &config)
: config_(config) {
if (!config.rule_fsts.empty()) {
std::vector<std::string> files;
SplitStringToVector(config.rule_fsts, ",", false, &files);
itn_list_.reserve(files.size());
for (const auto &f : files) {
if (config.model_config.debug) {
SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
}
itn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(f));
}
}
if (!config.rule_fars.empty()) {
if (config.model_config.debug) {
SHERPA_ONNX_LOGE("Loading FST archives");
}
std::vector<std::string> files;
SplitStringToVector(config.rule_fars, ",", false, &files);
itn_list_.reserve(files.size() + itn_list_.size());
for (const auto &f : files) {
if (config.model_config.debug) {
SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
}
std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
fst::FarReader<fst::StdArc>::Open(f));
for (; !reader->Done(); reader->Next()) {
std::unique_ptr<fst::StdConstFst> r(
fst::CastOrConvertToConstFst(reader->GetFst()->Copy()));
itn_list_.push_back(
std::make_unique<kaldifst::TextNormalizer>(std::move(r)));
}
}
if (config.model_config.debug) {
SHERPA_ONNX_LOGE("FST archives loaded!");
}
}
}
#if __ANDROID_API__ >= 9
OnlineRecognizerImpl::OnlineRecognizerImpl(AAssetManager *mgr,
const OnlineRecognizerConfig &config)
: config_(config) {
if (!config.rule_fsts.empty()) {
std::vector<std::string> files;
SplitStringToVector(config.rule_fsts, ",", false, &files);
itn_list_.reserve(files.size());
for (const auto &f : files) {
if (config.model_config.debug) {
SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
}
auto buf = ReadFile(mgr, f);
std::istrstream is(buf.data(), buf.size());
itn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(is));
}
}
if (!config.rule_fars.empty()) {
std::vector<std::string> files;
SplitStringToVector(config.rule_fars, ",", false, &files);
itn_list_.reserve(files.size() + itn_list_.size());
for (const auto &f : files) {
if (config.model_config.debug) {
SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
}
auto buf = ReadFile(mgr, f);
std::unique_ptr<std::istream> s(
new std::istrstream(buf.data(), buf.size()));
std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
fst::FarReader<fst::StdArc>::Open(std::move(s)));
for (; !reader->Done(); reader->Next()) {
std::unique_ptr<fst::StdConstFst> r(
fst::CastOrConvertToConstFst(reader->GetFst()->Copy()));
itn_list_.push_back(
std::make_unique<kaldifst::TextNormalizer>(std::move(r)));
} // for (; !reader->Done(); reader->Next())
} // for (const auto &f : files)
} // if (!config.rule_fars.empty())
}
#endif
std::string OnlineRecognizerImpl::ApplyInverseTextNormalization(
std::string text) const {
if (!itn_list_.empty()) {
for (const auto &tn : itn_list_) {
text = tn->Normalize(text);
if (config_.model_config.debug) {
SHERPA_ONNX_LOGE("After inverse text normalization: %s", text.c_str());
}
}
}
return text;
}
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -9,6 +9,12 @@
#include <string> #include <string>
#include <vector> #include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "kaldifst/csrc/text-normalizer.h"
#include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-recognizer.h" #include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/online-stream.h" #include "sherpa-onnx/csrc/online-stream.h"
@@ -17,10 +23,15 @@ namespace sherpa_onnx {
class OnlineRecognizerImpl { class OnlineRecognizerImpl {
public: public:
explicit OnlineRecognizerImpl(const OnlineRecognizerConfig &config);
static std::unique_ptr<OnlineRecognizerImpl> Create( static std::unique_ptr<OnlineRecognizerImpl> Create(
const OnlineRecognizerConfig &config); const OnlineRecognizerConfig &config);
#if __ANDROID_API__ >= 9 #if __ANDROID_API__ >= 9
OnlineRecognizerImpl(AAssetManager *mgr,
const OnlineRecognizerConfig &config);
static std::unique_ptr<OnlineRecognizerImpl> Create( static std::unique_ptr<OnlineRecognizerImpl> Create(
AAssetManager *mgr, const OnlineRecognizerConfig &config); AAssetManager *mgr, const OnlineRecognizerConfig &config);
#endif #endif
@@ -50,6 +61,15 @@ class OnlineRecognizerImpl {
virtual bool IsEndpoint(OnlineStream *s) const = 0; virtual bool IsEndpoint(OnlineStream *s) const = 0;
virtual void Reset(OnlineStream *s) const = 0; virtual void Reset(OnlineStream *s) const = 0;
std::string ApplyInverseTextNormalization(std::string text) const;
private:
OnlineRecognizerConfig config_;
// for inverse text normalization. Used only if
// config.rule_fsts is not empty or
// config.rule_fars is not empty
std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_;
}; };
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -96,7 +96,8 @@ static void Scale(const float *x, int32_t n, float scale, float *y) {
class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
public: public:
explicit OnlineRecognizerParaformerImpl(const OnlineRecognizerConfig &config) explicit OnlineRecognizerParaformerImpl(const OnlineRecognizerConfig &config)
: config_(config), : OnlineRecognizerImpl(config),
config_(config),
model_(config.model_config), model_(config.model_config),
sym_(config.model_config.tokens), sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) { endpoint_(config_.endpoint_config) {
@@ -116,7 +117,8 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
#if __ANDROID_API__ >= 9 #if __ANDROID_API__ >= 9
explicit OnlineRecognizerParaformerImpl(AAssetManager *mgr, explicit OnlineRecognizerParaformerImpl(AAssetManager *mgr,
const OnlineRecognizerConfig &config) const OnlineRecognizerConfig &config)
: config_(config), : OnlineRecognizerImpl(mgr, config),
config_(config),
model_(mgr, config.model_config), model_(mgr, config.model_config),
sym_(mgr, config.model_config.tokens), sym_(mgr, config.model_config.tokens),
endpoint_(config_.endpoint_config) { endpoint_(config_.endpoint_config) {
@@ -160,7 +162,9 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
OnlineRecognizerResult GetResult(OnlineStream *s) const override { OnlineRecognizerResult GetResult(OnlineStream *s) const override {
auto decoder_result = s->GetParaformerResult(); auto decoder_result = s->GetParaformerResult();
return Convert(decoder_result, sym_); auto r = Convert(decoder_result, sym_);
r.text = ApplyInverseTextNormalization(r.text);
return r;
} }
bool IsEndpoint(OnlineStream *s) const override { bool IsEndpoint(OnlineStream *s) const override {

View File

@@ -80,7 +80,8 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
public: public:
explicit OnlineRecognizerTransducerImpl(const OnlineRecognizerConfig &config) explicit OnlineRecognizerTransducerImpl(const OnlineRecognizerConfig &config)
: config_(config), : OnlineRecognizerImpl(config),
config_(config),
model_(OnlineTransducerModel::Create(config.model_config)), model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.model_config.tokens), sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) { endpoint_(config_.endpoint_config) {
@@ -124,7 +125,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
#if __ANDROID_API__ >= 9 #if __ANDROID_API__ >= 9
explicit OnlineRecognizerTransducerImpl(AAssetManager *mgr, explicit OnlineRecognizerTransducerImpl(AAssetManager *mgr,
const OnlineRecognizerConfig &config) const OnlineRecognizerConfig &config)
: config_(config), : OnlineRecognizerImpl(mgr, config),
config_(config),
model_(OnlineTransducerModel::Create(mgr, config.model_config)), model_(OnlineTransducerModel::Create(mgr, config.model_config)),
sym_(mgr, config.model_config.tokens), sym_(mgr, config.model_config.tokens),
endpoint_(config_.endpoint_config) { endpoint_(config_.endpoint_config) {
@@ -332,8 +334,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
// TODO(fangjun): Remember to change these constants if needed // TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10; int32_t frame_shift_ms = 10;
int32_t subsampling_factor = 4; int32_t subsampling_factor = 4;
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment(), s->GetNumFramesSinceStart()); s->GetCurrentSegment(), s->GetNumFramesSinceStart());
r.text = ApplyInverseTextNormalization(std::move(r.text));
return r;
} }
bool IsEndpoint(OnlineStream *s) const override { bool IsEndpoint(OnlineStream *s) const override {

View File

@@ -42,7 +42,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
public: public:
explicit OnlineRecognizerTransducerNeMoImpl( explicit OnlineRecognizerTransducerNeMoImpl(
const OnlineRecognizerConfig &config) const OnlineRecognizerConfig &config)
: config_(config), : OnlineRecognizerImpl(config),
config_(config),
symbol_table_(config.model_config.tokens), symbol_table_(config.model_config.tokens),
endpoint_(config_.endpoint_config), endpoint_(config_.endpoint_config),
model_( model_(
@@ -61,7 +62,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
#if __ANDROID_API__ >= 9 #if __ANDROID_API__ >= 9
explicit OnlineRecognizerTransducerNeMoImpl( explicit OnlineRecognizerTransducerNeMoImpl(
AAssetManager *mgr, const OnlineRecognizerConfig &config) AAssetManager *mgr, const OnlineRecognizerConfig &config)
: config_(config), : OnlineRecognizerImpl(mgr, config),
config_(config),
symbol_table_(mgr, config.model_config.tokens), symbol_table_(mgr, config.model_config.tokens),
endpoint_(config_.endpoint_config), endpoint_(config_.endpoint_config),
model_(std::make_unique<OnlineTransducerNeMoModel>( model_(std::make_unique<OnlineTransducerNeMoModel>(
@@ -94,9 +96,11 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
// TODO(fangjun): Remember to change these constants if needed // TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10; int32_t frame_shift_ms = 10;
int32_t subsampling_factor = model_->SubsamplingFactor(); int32_t subsampling_factor = model_->SubsamplingFactor();
return Convert(s->GetResult(), symbol_table_, frame_shift_ms, auto r = Convert(s->GetResult(), symbol_table_, frame_shift_ms,
subsampling_factor, s->GetCurrentSegment(), subsampling_factor, s->GetCurrentSegment(),
s->GetNumFramesSinceStart()); s->GetNumFramesSinceStart());
r.text = ApplyInverseTextNormalization(std::move(r.text));
return r;
} }
bool IsEndpoint(OnlineStream *s) const override { bool IsEndpoint(OnlineStream *s) const override {

View File

@@ -14,7 +14,9 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/online-recognizer-impl.h" #include "sherpa-onnx/csrc/online-recognizer-impl.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx { namespace sherpa_onnx {
@@ -100,6 +102,15 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
"now support greedy_search and modified_beam_search."); "now support greedy_search and modified_beam_search.");
po->Register("temperature-scale", &temperature_scale, po->Register("temperature-scale", &temperature_scale,
"Temperature scale for confidence computation in decoding."); "Temperature scale for confidence computation in decoding.");
po->Register(
"rule-fsts", &rule_fsts,
"If not empty, it specifies fsts for inverse text normalization. "
"If there are multiple fsts, they are separated by a comma.");
po->Register(
"rule-fars", &rule_fars,
"If not empty, it specifies fst archives for inverse text normalization. "
"If there are multiple archives, they are separated by a comma.");
} }
bool OnlineRecognizerConfig::Validate() const { bool OnlineRecognizerConfig::Validate() const {
@@ -129,6 +140,34 @@ bool OnlineRecognizerConfig::Validate() const {
return false; return false;
} }
if (!hotwords_file.empty() && !FileExists(hotwords_file)) {
SHERPA_ONNX_LOGE("--hotwords-file: '%s' does not exist",
hotwords_file.c_str());
return false;
}
if (!rule_fsts.empty()) {
std::vector<std::string> files;
SplitStringToVector(rule_fsts, ",", false, &files);
for (const auto &f : files) {
if (!FileExists(f)) {
SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str());
return false;
}
}
}
if (!rule_fars.empty()) {
std::vector<std::string> files;
SplitStringToVector(rule_fars, ",", false, &files);
for (const auto &f : files) {
if (!FileExists(f)) {
SHERPA_ONNX_LOGE("Rule far '%s' does not exist. ", f.c_str());
return false;
}
}
}
return model_config.Validate(); return model_config.Validate();
} }
@@ -147,7 +186,9 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "hotwords_file=\"" << hotwords_file << "\", "; os << "hotwords_file=\"" << hotwords_file << "\", ";
os << "decoding_method=\"" << decoding_method << "\", "; os << "decoding_method=\"" << decoding_method << "\", ";
os << "blank_penalty=" << blank_penalty << ", "; os << "blank_penalty=" << blank_penalty << ", ";
os << "temperature_scale=" << temperature_scale << ")"; os << "temperature_scale=" << temperature_scale << ", ";
os << "rule_fsts=\"" << rule_fsts << "\", ";
os << "rule_fars=\"" << rule_fars << "\")";
return os.str(); return os.str();
} }

View File

@@ -100,6 +100,12 @@ struct OnlineRecognizerConfig {
float temperature_scale = 2.0; float temperature_scale = 2.0;
// If there are multiple rules, they are applied from left to right.
std::string rule_fsts;
// If there are multiple FST archives, they are applied from left to right.
std::string rule_fars;
OnlineRecognizerConfig() = default; OnlineRecognizerConfig() = default;
OnlineRecognizerConfig( OnlineRecognizerConfig(
@@ -109,7 +115,8 @@ struct OnlineRecognizerConfig {
const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config, const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config,
bool enable_endpoint, const std::string &decoding_method, bool enable_endpoint, const std::string &decoding_method,
int32_t max_active_paths, const std::string &hotwords_file, int32_t max_active_paths, const std::string &hotwords_file,
float hotwords_score, float blank_penalty, float temperature_scale) float hotwords_score, float blank_penalty, float temperature_scale,
const std::string &rule_fsts, const std::string &rule_fars)
: feat_config(feat_config), : feat_config(feat_config),
model_config(model_config), model_config(model_config),
lm_config(lm_config), lm_config(lm_config),
@@ -121,7 +128,9 @@ struct OnlineRecognizerConfig {
hotwords_file(hotwords_file), hotwords_file(hotwords_file),
hotwords_score(hotwords_score), hotwords_score(hotwords_score),
blank_penalty(blank_penalty), blank_penalty(blank_penalty),
temperature_scale(temperature_scale) {} temperature_scale(temperature_scale),
rule_fsts(rule_fsts),
rule_fars(rule_fars) {}
void Register(ParseOptions *po); void Register(ParseOptions *po);
bool Validate() const; bool Validate() const;

View File

@@ -54,19 +54,20 @@ static void PybindOnlineRecognizerResult(py::module *m) {
static void PybindOnlineRecognizerConfig(py::module *m) { static void PybindOnlineRecognizerConfig(py::module *m) {
using PyClass = OnlineRecognizerConfig; using PyClass = OnlineRecognizerConfig;
py::class_<PyClass>(*m, "OnlineRecognizerConfig") py::class_<PyClass>(*m, "OnlineRecognizerConfig")
.def( .def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
py::init<const FeatureExtractorConfig &, const OnlineModelConfig &, const OnlineLMConfig &, const EndpointConfig &,
const OnlineLMConfig &, const EndpointConfig &, const OnlineCtcFstDecoderConfig &, bool,
const OnlineCtcFstDecoderConfig &, bool, const std::string &, const std::string &, int32_t, const std::string &, float,
int32_t, const std::string &, float, float, float>(), float, float, const std::string &, const std::string &>(),
py::arg("feat_config"), py::arg("model_config"), py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OnlineLMConfig(), py::arg("lm_config") = OnlineLMConfig(),
py::arg("endpoint_config") = EndpointConfig(), py::arg("endpoint_config") = EndpointConfig(),
py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(), py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(),
py::arg("enable_endpoint"), py::arg("decoding_method"), py::arg("enable_endpoint"), py::arg("decoding_method"),
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0, py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0,
py::arg("temperature_scale") = 2.0) py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "",
py::arg("rule_fars") = "")
.def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config) .def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config) .def_readwrite("lm_config", &PyClass::lm_config)
@@ -79,6 +80,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
.def_readwrite("hotwords_score", &PyClass::hotwords_score) .def_readwrite("hotwords_score", &PyClass::hotwords_score)
.def_readwrite("blank_penalty", &PyClass::blank_penalty) .def_readwrite("blank_penalty", &PyClass::blank_penalty)
.def_readwrite("temperature_scale", &PyClass::temperature_scale) .def_readwrite("temperature_scale", &PyClass::temperature_scale)
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
.def_readwrite("rule_fars", &PyClass::rule_fars)
.def("__str__", &PyClass::ToString); .def("__str__", &PyClass::ToString);
} }

View File

@@ -64,6 +64,8 @@ class OnlineRecognizer(object):
lm_scale: float = 0.1, lm_scale: float = 0.1,
temperature_scale: float = 2.0, temperature_scale: float = 2.0,
debug: bool = False, debug: bool = False,
rule_fsts: str = "",
rule_fars: str = "",
): ):
""" """
Please refer to Please refer to
@@ -148,6 +150,12 @@ class OnlineRecognizer(object):
the log probability, you can get it from the directory where the log probability, you can get it from the directory where
your bpe model is generated. Only used when hotwords provided your bpe model is generated. Only used when hotwords provided
and the modeling unit is bpe or cjkchar+bpe. and the modeling unit is bpe or cjkchar+bpe.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
""" """
self = cls.__new__(cls) self = cls.__new__(cls)
_assert_file_exists(tokens) _assert_file_exists(tokens)
@@ -217,6 +225,8 @@ class OnlineRecognizer(object):
hotwords_file=hotwords_file, hotwords_file=hotwords_file,
blank_penalty=blank_penalty, blank_penalty=blank_penalty,
temperature_scale=temperature_scale, temperature_scale=temperature_scale,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
) )
self.recognizer = _Recognizer(recognizer_config) self.recognizer = _Recognizer(recognizer_config)
@@ -239,6 +249,8 @@ class OnlineRecognizer(object):
decoding_method: str = "greedy_search", decoding_method: str = "greedy_search",
provider: str = "cpu", provider: str = "cpu",
debug: bool = False, debug: bool = False,
rule_fsts: str = "",
rule_fars: str = "",
): ):
""" """
Please refer to Please refer to
@@ -283,6 +295,12 @@ class OnlineRecognizer(object):
The only valid value is greedy_search. The only valid value is greedy_search.
provider: provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml. onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
""" """
self = cls.__new__(cls) self = cls.__new__(cls)
_assert_file_exists(tokens) _assert_file_exists(tokens)
@@ -322,6 +340,8 @@ class OnlineRecognizer(object):
endpoint_config=endpoint_config, endpoint_config=endpoint_config,
enable_endpoint=enable_endpoint_detection, enable_endpoint=enable_endpoint_detection,
decoding_method=decoding_method, decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
) )
self.recognizer = _Recognizer(recognizer_config) self.recognizer = _Recognizer(recognizer_config)
@@ -345,6 +365,8 @@ class OnlineRecognizer(object):
ctc_max_active: int = 3000, ctc_max_active: int = 3000,
provider: str = "cpu", provider: str = "cpu",
debug: bool = False, debug: bool = False,
rule_fsts: str = "",
rule_fars: str = "",
): ):
""" """
Please refer to Please refer to
@@ -393,6 +415,12 @@ class OnlineRecognizer(object):
active paths at a time. active paths at a time.
provider: provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml. onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
""" """
self = cls.__new__(cls) self = cls.__new__(cls)
_assert_file_exists(tokens) _assert_file_exists(tokens)
@@ -433,6 +461,8 @@ class OnlineRecognizer(object):
ctc_fst_decoder_config=ctc_fst_decoder_config, ctc_fst_decoder_config=ctc_fst_decoder_config,
enable_endpoint=enable_endpoint_detection, enable_endpoint=enable_endpoint_detection,
decoding_method=decoding_method, decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
) )
self.recognizer = _Recognizer(recognizer_config) self.recognizer = _Recognizer(recognizer_config)
@@ -454,6 +484,8 @@ class OnlineRecognizer(object):
decoding_method: str = "greedy_search", decoding_method: str = "greedy_search",
provider: str = "cpu", provider: str = "cpu",
debug: bool = False, debug: bool = False,
rule_fsts: str = "",
rule_fars: str = "",
): ):
""" """
Please refer to Please refer to
@@ -497,6 +529,12 @@ class OnlineRecognizer(object):
onnxruntime execution providers. Valid values are: cpu, cuda, coreml. onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
debug: debug:
True to show meta data in the model. True to show meta data in the model.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
""" """
self = cls.__new__(cls) self = cls.__new__(cls)
_assert_file_exists(tokens) _assert_file_exists(tokens)
@@ -533,6 +571,8 @@ class OnlineRecognizer(object):
endpoint_config=endpoint_config, endpoint_config=endpoint_config,
enable_endpoint=enable_endpoint_detection, enable_endpoint=enable_endpoint_detection,
decoding_method=decoding_method, decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
) )
self.recognizer = _Recognizer(recognizer_config) self.recognizer = _Recognizer(recognizer_config)
@@ -556,6 +596,8 @@ class OnlineRecognizer(object):
decoding_method: str = "greedy_search", decoding_method: str = "greedy_search",
provider: str = "cpu", provider: str = "cpu",
debug: bool = False, debug: bool = False,
rule_fsts: str = "",
rule_fars: str = "",
): ):
""" """
Please refer to Please refer to
@@ -602,6 +644,12 @@ class OnlineRecognizer(object):
The only valid value is greedy_search. The only valid value is greedy_search.
provider: provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml. onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
""" """
self = cls.__new__(cls) self = cls.__new__(cls)
_assert_file_exists(tokens) _assert_file_exists(tokens)
@@ -640,6 +688,8 @@ class OnlineRecognizer(object):
endpoint_config=endpoint_config, endpoint_config=endpoint_config,
enable_endpoint=enable_endpoint_detection, enable_endpoint=enable_endpoint_detection,
decoding_method=decoding_method, decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
) )
self.recognizer = _Recognizer(recognizer_config) self.recognizer = _Recognizer(recognizer_config)