Add inverse text normalization for non-streaming ASR (#1017)
This commit is contained in:
13
.github/scripts/test-python.sh
vendored
13
.github/scripts/test-python.sh
vendored
@@ -248,7 +248,7 @@ if [[ x$OS != x'windows-latest' ]]; then
|
||||
python3 ./python-api-examples/online-decode-files.py \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \
|
||||
--decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \
|
||||
--decoder=$repo/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
@@ -286,7 +286,7 @@ python3 ./python-api-examples/offline-decode-files.py \
|
||||
python3 ./python-api-examples/offline-decode-files.py \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \
|
||||
--decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \
|
||||
--decoder=$repo/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
@@ -330,6 +330,15 @@ if [[ x$OS != x'windows-latest' ]]; then
|
||||
|
||||
python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
|
||||
|
||||
ln -s $repo $PWD/
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav
|
||||
|
||||
python3 ./python-api-examples/inverse-text-normalization-offline-asr.py
|
||||
|
||||
rm -rfv sherpa-onnx-paraformer-zh-2023-03-28
|
||||
|
||||
rm -rf $repo
|
||||
fi
|
||||
|
||||
|
||||
81
python-api-examples/inverse-text-normalization-offline-asr.py
Executable file
81
python-api-examples/inverse-text-normalization-offline-asr.py
Executable file
@@ -0,0 +1,81 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
"""
|
||||
This script shows how to use inverse text normalization with non-streaming ASR.
|
||||
|
||||
Usage:
|
||||
|
||||
(1) Download the test model
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2
|
||||
tar xvf sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2
|
||||
rm sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2
|
||||
|
||||
(2) Download rule fst
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst
|
||||
|
||||
Please refer to
|
||||
https://github.com/k2-fsa/colab/blob/master/sherpa-onnx/itn_zh_number.ipynb
|
||||
for how itn_zh_number.fst is generated.
|
||||
|
||||
(3) Download test wave
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav
|
||||
|
||||
(4) Run this script
|
||||
|
||||
python3 ./python-api-examples/inverse-text-normalization-offline-asr.py
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import sherpa_onnx
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
def create_recognizer():
    """Build a non-streaming paraformer recognizer that applies a rule FST.

    The model, the token table and the rule FST are looked up at fixed paths
    relative to the current working directory.

    Returns:
      The object returned by ``sherpa_onnx.OfflineRecognizer.from_paraformer``,
      created with ``debug=True`` and ``rule_fsts`` set to the rule FST path.

    Raises:
      ValueError: If any of the three required files is not a regular file.
    """
    model = "./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx"
    tokens = "./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt"
    rule_fsts = "./itn_zh_number.fst"

    # Checked in this order; the first missing file stops the scan.
    required_files = (model, tokens, rule_fsts)
    if not all(Path(name).is_file() for name in required_files):
        raise ValueError(
            """Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
        )

    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=model,
        tokens=tokens,
        debug=True,
        rule_fsts=rule_fsts,
    )
    return recognizer
|
||||
|
||||
|
||||
def main():
    """Decode ``./itn-zh-number.wav`` and print the recognition result.

    The recognizer is created first, then the test wave is read, fed to a
    single stream and decoded. The wave filename and the stream result are
    printed to stdout.

    Raises:
      ValueError: If the test wave file is not a regular file. Anything
        raised by ``create_recognizer()`` propagates unchanged.
    """
    recognizer = create_recognizer()
    wave_filename = "./itn-zh-number.wav"
    if not Path(wave_filename).is_file():
        # Name the file that is actually missing (the wave, not the model)
        # and point at its direct download URL.
        raise ValueError(
            """Please download the test wave ./itn-zh-number.wav from
https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav
"""
        )

    # always_2d=True yields a 2-D array even for mono files, so the
    # column selection below works for any channel count.
    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, audio)
    recognizer.decode_stream(stream)

    print(wave_filename)
    print(stream.result)


if __name__ == "__main__":
    main()
|
||||
@@ -73,7 +73,8 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
|
||||
class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
public:
|
||||
explicit OfflineRecognizerCtcImpl(const OfflineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
: OfflineRecognizerImpl(config),
|
||||
config_(config),
|
||||
symbol_table_(config_.model_config.tokens),
|
||||
model_(OfflineCtcModel::Create(config_.model_config)) {
|
||||
Init();
|
||||
@@ -82,7 +83,8 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineRecognizerCtcImpl(AAssetManager *mgr,
|
||||
const OfflineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
: OfflineRecognizerImpl(mgr, config),
|
||||
config_(config),
|
||||
symbol_table_(mgr, config_.model_config.tokens),
|
||||
model_(OfflineCtcModel::Create(mgr, config_.model_config)) {
|
||||
Init();
|
||||
@@ -205,6 +207,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
|
||||
model_->SubsamplingFactor());
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
ss[i]->SetResult(r);
|
||||
}
|
||||
}
|
||||
@@ -238,6 +241,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
|
||||
auto r = Convert(results[0], symbol_table_, frame_shift_ms,
|
||||
model_->SubsamplingFactor());
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
s->SetResult(r);
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,18 @@
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include <strstream>
|
||||
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "fst/extensions/far/far.h"
|
||||
#include "kaldifst/csrc/kaldi-fst-io.h"
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
|
||||
@@ -316,4 +327,111 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
}
|
||||
#endif
|
||||
|
||||
// Loads the inverse text normalization rules named in `config`.
//
// `config.rule_fsts` and `config.rule_fars` are comma-separated lists of file
// names. One normalizer is appended to itn_list_ per rule fst and one per FST
// stored inside each archive; rule fsts are appended before archives, and
// ApplyInverseTextNormalization() runs them in that order.
OfflineRecognizerImpl::OfflineRecognizerImpl(
    const OfflineRecognizerConfig &config)
    : config_(config) {
  if (!config.rule_fsts.empty()) {
    std::vector<std::string> files;
    SplitStringToVector(config.rule_fsts, ",", false, &files);
    itn_list_.reserve(files.size());
    for (const auto &f : files) {
      if (config.model_config.debug) {
        SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
      }
      itn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(f));
    }
  }

  if (!config.rule_fars.empty()) {
    if (config.model_config.debug) {
      SHERPA_ONNX_LOGE("Loading FST archives");
    }
    std::vector<std::string> files;
    SplitStringToVector(config.rule_fars, ",", false, &files);

    // Lower bound only: an archive may contain more than one FST.
    itn_list_.reserve(files.size() + itn_list_.size());

    for (const auto &f : files) {
      if (config.model_config.debug) {
        SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
      }
      std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
          fst::FarReader<fst::StdArc>::Open(f));
      if (!reader) {
        // Open() returns a null pointer when the archive cannot be read.
        // Report it and move on instead of dereferencing the null reader.
        SHERPA_ONNX_LOGE("Failed to open rule far '%s'. Skip it.", f.c_str());
        continue;
      }

      for (; !reader->Done(); reader->Next()) {
        // Copy() gives us an FST we own; the reader keeps its own instance.
        std::unique_ptr<fst::StdConstFst> r(
            fst::CastOrConvertToConstFst(reader->GetFst()->Copy()));

        itn_list_.push_back(
            std::make_unique<kaldifst::TextNormalizer>(std::move(r)));
      }
    }

    if (config.model_config.debug) {
      SHERPA_ONNX_LOGE("FST archives loaded!");
    }
  }
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
// Android variant: loads the inverse text normalization rules named in
// `config` through the asset manager `mgr` instead of the file system.
//
// `config.rule_fsts` and `config.rule_fars` are comma-separated lists of
// asset names. Rule fsts are appended to itn_list_ before the FSTs found in
// the archives.
OfflineRecognizerImpl::OfflineRecognizerImpl(
    AAssetManager *mgr, const OfflineRecognizerConfig &config)
    : config_(config) {
  if (!config.rule_fsts.empty()) {
    std::vector<std::string> files;
    SplitStringToVector(config.rule_fsts, ",", false, &files);
    itn_list_.reserve(files.size());
    for (const auto &f : files) {
      if (config.model_config.debug) {
        SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
      }
      auto buf = ReadFile(mgr, f);
      std::istrstream is(buf.data(), buf.size());
      itn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(is));
    }
  }

  if (!config.rule_fars.empty()) {
    std::vector<std::string> files;
    SplitStringToVector(config.rule_fars, ",", false, &files);

    // Lower bound only: an archive may contain more than one FST.
    itn_list_.reserve(files.size() + itn_list_.size());

    for (const auto &f : files) {
      if (config.model_config.debug) {
        SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
      }

      // `buf` is declared before the stream and the reader, so it outlives
      // both within this iteration.
      auto buf = ReadFile(mgr, f);

      std::unique_ptr<std::istream> s(
          new std::istrstream(buf.data(), buf.size()));

      std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
          fst::FarReader<fst::StdArc>::Open(std::move(s)));
      if (!reader) {
        // Guard against a null reader for an unreadable or malformed
        // archive; report it and move on instead of dereferencing it.
        SHERPA_ONNX_LOGE("Failed to open rule far '%s'. Skip it.", f.c_str());
        continue;
      }

      for (; !reader->Done(); reader->Next()) {
        std::unique_ptr<fst::StdConstFst> r(
            fst::CastOrConvertToConstFst(reader->GetFst()->Copy()));

        itn_list_.push_back(
            std::make_unique<kaldifst::TextNormalizer>(std::move(r)));
      }  // for (; !reader->Done(); reader->Next())
    }    // for (const auto &f : files)
  }      // if (!config.rule_fars.empty())
}
#endif
|
||||
|
||||
std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
|
||||
std::string text) const {
|
||||
if (!itn_list_.empty()) {
|
||||
for (const auto &tn : itn_list_) {
|
||||
text = tn->Normalize(text);
|
||||
if (config_.model_config.debug) {
|
||||
SHERPA_ONNX_LOGE("After inverse text normalization: %s", text.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return text;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "kaldifst/csrc/text-normalizer.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||
@@ -22,10 +23,15 @@ namespace sherpa_onnx {
|
||||
|
||||
class OfflineRecognizerImpl {
|
||||
public:
|
||||
explicit OfflineRecognizerImpl(const OfflineRecognizerConfig &config);
|
||||
|
||||
static std::unique_ptr<OfflineRecognizerImpl> Create(
|
||||
const OfflineRecognizerConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineRecognizerImpl(AAssetManager *mgr,
|
||||
const OfflineRecognizerConfig &config);
|
||||
|
||||
static std::unique_ptr<OfflineRecognizerImpl> Create(
|
||||
AAssetManager *mgr, const OfflineRecognizerConfig &config);
|
||||
#endif
|
||||
@@ -41,6 +47,15 @@ class OfflineRecognizerImpl {
|
||||
virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
|
||||
|
||||
virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;
|
||||
|
||||
std::string ApplyInverseTextNormalization(std::string text) const;
|
||||
|
||||
private:
|
||||
OfflineRecognizerConfig config_;
|
||||
// for inverse text normalization. Used only if
|
||||
// config.rule_fsts is not empty or
|
||||
// config.rule_fars is not empty
|
||||
std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -89,7 +89,8 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
|
||||
public:
|
||||
explicit OfflineRecognizerParaformerImpl(
|
||||
const OfflineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
: OfflineRecognizerImpl(config),
|
||||
config_(config),
|
||||
symbol_table_(config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineParaformerModel>(config.model_config)) {
|
||||
if (config.decoding_method == "greedy_search") {
|
||||
@@ -109,7 +110,8 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineRecognizerParaformerImpl(AAssetManager *mgr,
|
||||
const OfflineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
: OfflineRecognizerImpl(mgr, config),
|
||||
config_(config),
|
||||
symbol_table_(mgr, config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineParaformerModel>(mgr,
|
||||
config.model_config)) {
|
||||
@@ -204,6 +206,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
auto r = Convert(results[i], symbol_table_);
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
ss[i]->SetResult(r);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +74,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
public:
|
||||
explicit OfflineRecognizerTransducerImpl(
|
||||
const OfflineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
: OfflineRecognizerImpl(config),
|
||||
config_(config),
|
||||
symbol_table_(config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
|
||||
if (config_.decoding_method == "greedy_search") {
|
||||
@@ -107,7 +108,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
#if __ANDROID_API__ >= 9
|
||||
explicit OfflineRecognizerTransducerImpl(
|
||||
AAssetManager *mgr, const OfflineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
: OfflineRecognizerImpl(mgr, config),
|
||||
config_(config),
|
||||
symbol_table_(mgr, config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineTransducerModel>(mgr,
|
||||
config_.model_config)) {
|
||||
@@ -230,6 +232,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
|
||||
model_->SubsamplingFactor());
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
|
||||
ss[i]->SetResult(r);
|
||||
}
|
||||
|
||||
@@ -41,7 +41,8 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
|
||||
public:
|
||||
explicit OfflineRecognizerTransducerNeMoImpl(
|
||||
const OfflineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
: OfflineRecognizerImpl(config),
|
||||
config_(config),
|
||||
symbol_table_(config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineTransducerNeMoModel>(
|
||||
config_.model_config)) {
|
||||
@@ -59,7 +60,8 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
|
||||
#if __ANDROID_API__ >= 9
|
||||
explicit OfflineRecognizerTransducerNeMoImpl(
|
||||
AAssetManager *mgr, const OfflineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
: OfflineRecognizerImpl(mgr, config),
|
||||
config_(config),
|
||||
symbol_table_(mgr, config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineTransducerNeMoModel>(
|
||||
mgr, config_.model_config)) {
|
||||
@@ -131,6 +133,7 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
|
||||
model_->SubsamplingFactor());
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
|
||||
ss[i]->SetResult(r);
|
||||
}
|
||||
|
||||
@@ -52,7 +52,8 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
|
||||
class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||
public:
|
||||
explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
: OfflineRecognizerImpl(config),
|
||||
config_(config),
|
||||
symbol_table_(config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineWhisperModel>(config.model_config)) {
|
||||
Init();
|
||||
@@ -61,7 +62,8 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineRecognizerWhisperImpl(AAssetManager *mgr,
|
||||
const OfflineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
: OfflineRecognizerImpl(mgr, config),
|
||||
config_(config),
|
||||
symbol_table_(mgr, config_.model_config.tokens),
|
||||
model_(
|
||||
std::make_unique<OfflineWhisperModel>(mgr, config.model_config)) {
|
||||
@@ -150,6 +152,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||
std::move(cross_kv.second));
|
||||
|
||||
auto r = Convert(results[0], symbol_table_);
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
s->SetResult(r);
|
||||
} catch (const Ort::Exception &ex) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineRecognizerConfig::Register(ParseOptions *po) {
|
||||
@@ -44,6 +44,16 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
|
||||
po->Register("hotwords-score", &hotwords_score,
|
||||
"The bonus score for each token in context word/phrase. "
|
||||
"Used only when decoding_method is modified_beam_search");
|
||||
|
||||
po->Register(
|
||||
"rule-fsts", &rule_fsts,
|
||||
"If not empty, it specifies fsts for inverse text normalization. "
|
||||
"If there are multiple fsts, they are separated by a comma.");
|
||||
|
||||
po->Register(
|
||||
"rule-fars", &rule_fars,
|
||||
"If not empty, it specifies fst archives for inverse text normalization. "
|
||||
"If there are multiple archives, they are separated by a comma.");
|
||||
}
|
||||
|
||||
bool OfflineRecognizerConfig::Validate() const {
|
||||
@@ -61,7 +71,7 @@ bool OfflineRecognizerConfig::Validate() const {
|
||||
if (!hotwords_file.empty() && decoding_method != "modified_beam_search") {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Please use --decoding-method=modified_beam_search if you"
|
||||
" provide --hotwords-file. Given --decoding-method=%s",
|
||||
" provide --hotwords-file. Given --decoding-method='%s'",
|
||||
decoding_method.c_str());
|
||||
return false;
|
||||
}
|
||||
@@ -72,6 +82,34 @@ bool OfflineRecognizerConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!hotwords_file.empty() && !FileExists(hotwords_file)) {
|
||||
SHERPA_ONNX_LOGE("--hotwords-file: '%s' does not exist",
|
||||
hotwords_file.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!rule_fsts.empty()) {
|
||||
std::vector<std::string> files;
|
||||
SplitStringToVector(rule_fsts, ",", false, &files);
|
||||
for (const auto &f : files) {
|
||||
if (!FileExists(f)) {
|
||||
SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!rule_fars.empty()) {
|
||||
std::vector<std::string> files;
|
||||
SplitStringToVector(rule_fars, ",", false, &files);
|
||||
for (const auto &f : files) {
|
||||
if (!FileExists(f)) {
|
||||
SHERPA_ONNX_LOGE("Rule far '%s' does not exist. ", f.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return model_config.Validate();
|
||||
}
|
||||
|
||||
@@ -87,7 +125,9 @@ std::string OfflineRecognizerConfig::ToString() const {
|
||||
os << "max_active_paths=" << max_active_paths << ", ";
|
||||
os << "hotwords_file=\"" << hotwords_file << "\", ";
|
||||
os << "hotwords_score=" << hotwords_score << ", ";
|
||||
os << "blank_penalty=" << blank_penalty << ")";
|
||||
os << "blank_penalty=" << blank_penalty << ", ";
|
||||
os << "rule_fsts=\"" << rule_fsts << "\", ";
|
||||
os << "rule_fars=\"" << rule_fars << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -40,6 +40,12 @@ struct OfflineRecognizerConfig {
|
||||
|
||||
float blank_penalty = 0.0;
|
||||
|
||||
// If there are multiple rules, they are applied from left to right.
|
||||
std::string rule_fsts;
|
||||
|
||||
// If there are multiple FST archives, they are applied from left to right.
|
||||
std::string rule_fars;
|
||||
|
||||
// only greedy_search is implemented
|
||||
// TODO(fangjun): Implement modified_beam_search
|
||||
|
||||
@@ -50,7 +56,8 @@ struct OfflineRecognizerConfig {
|
||||
const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config,
|
||||
const std::string &decoding_method, int32_t max_active_paths,
|
||||
const std::string &hotwords_file, float hotwords_score,
|
||||
float blank_penalty)
|
||||
float blank_penalty, const std::string &rule_fsts,
|
||||
const std::string &rule_fars)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
lm_config(lm_config),
|
||||
@@ -59,7 +66,9 @@ struct OfflineRecognizerConfig {
|
||||
max_active_paths(max_active_paths),
|
||||
hotwords_file(hotwords_file),
|
||||
hotwords_score(hotwords_score),
|
||||
blank_penalty(blank_penalty) {}
|
||||
blank_penalty(blank_penalty),
|
||||
rule_fsts(rule_fsts),
|
||||
rule_fars(rule_fars) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -17,13 +17,14 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
|
||||
.def(py::init<const FeatureExtractorConfig &, const OfflineModelConfig &,
|
||||
const OfflineLMConfig &, const OfflineCtcFstDecoderConfig &,
|
||||
const std::string &, int32_t, const std::string &, float,
|
||||
float>(),
|
||||
float, const std::string &, const std::string &>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("lm_config") = OfflineLMConfig(),
|
||||
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
|
||||
py::arg("decoding_method") = "greedy_search",
|
||||
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
||||
py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0)
|
||||
py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0,
|
||||
py::arg("rule_fsts") = "", py::arg("rule_fars") = "")
|
||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||
.def_readwrite("model_config", &PyClass::model_config)
|
||||
.def_readwrite("lm_config", &PyClass::lm_config)
|
||||
@@ -33,6 +34,8 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
|
||||
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
|
||||
.def_readwrite("hotwords_score", &PyClass::hotwords_score)
|
||||
.def_readwrite("blank_penalty", &PyClass::blank_penalty)
|
||||
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
|
||||
.def_readwrite("rule_fars", &PyClass::rule_fars)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
|
||||
@@ -54,6 +54,8 @@ class OfflineRecognizer(object):
|
||||
debug: bool = False,
|
||||
provider: str = "cpu",
|
||||
model_type: str = "transducer",
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -107,6 +109,12 @@ class OfflineRecognizer(object):
|
||||
True to show debug messages.
|
||||
provider:
|
||||
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||
rule_fsts:
|
||||
If not empty, it specifies fsts for inverse text normalization.
|
||||
If there are multiple fsts, they are separated by a comma.
|
||||
rule_fars:
|
||||
If not empty, it specifies fst archives for inverse text normalization.
|
||||
If there are multiple archives, they are separated by a comma.
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
model_config = OfflineModelConfig(
|
||||
@@ -143,6 +151,8 @@ class OfflineRecognizer(object):
|
||||
hotwords_file=hotwords_file,
|
||||
hotwords_score=hotwords_score,
|
||||
blank_penalty=blank_penalty,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
@@ -159,6 +169,8 @@ class OfflineRecognizer(object):
|
||||
decoding_method: str = "greedy_search",
|
||||
debug: bool = False,
|
||||
provider: str = "cpu",
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -186,6 +198,12 @@ class OfflineRecognizer(object):
|
||||
True to show debug messages.
|
||||
provider:
|
||||
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||
rule_fsts:
|
||||
If not empty, it specifies fsts for inverse text normalization.
|
||||
If there are multiple fsts, they are separated by a comma.
|
||||
rule_fars:
|
||||
If not empty, it specifies fst archives for inverse text normalization.
|
||||
If there are multiple archives, they are separated by a comma.
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
model_config = OfflineModelConfig(
|
||||
@@ -206,6 +224,8 @@ class OfflineRecognizer(object):
|
||||
feat_config=feat_config,
|
||||
model_config=model_config,
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
@@ -222,6 +242,8 @@ class OfflineRecognizer(object):
|
||||
decoding_method: str = "greedy_search",
|
||||
debug: bool = False,
|
||||
provider: str = "cpu",
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -251,6 +273,12 @@ class OfflineRecognizer(object):
|
||||
True to show debug messages.
|
||||
provider:
|
||||
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||
rule_fsts:
|
||||
If not empty, it specifies fsts for inverse text normalization.
|
||||
If there are multiple fsts, they are separated by a comma.
|
||||
rule_fars:
|
||||
If not empty, it specifies fst archives for inverse text normalization.
|
||||
If there are multiple archives, they are separated by a comma.
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
model_config = OfflineModelConfig(
|
||||
@@ -271,6 +299,8 @@ class OfflineRecognizer(object):
|
||||
feat_config=feat_config,
|
||||
model_config=model_config,
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
@@ -287,6 +317,8 @@ class OfflineRecognizer(object):
|
||||
decoding_method: str = "greedy_search",
|
||||
debug: bool = False,
|
||||
provider: str = "cpu",
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -315,6 +347,12 @@ class OfflineRecognizer(object):
|
||||
True to show debug messages.
|
||||
provider:
|
||||
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||
rule_fsts:
|
||||
If not empty, it specifies fsts for inverse text normalization.
|
||||
If there are multiple fsts, they are separated by a comma.
|
||||
rule_fars:
|
||||
If not empty, it specifies fst archives for inverse text normalization.
|
||||
If there are multiple archives, they are separated by a comma.
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
model_config = OfflineModelConfig(
|
||||
@@ -335,6 +373,8 @@ class OfflineRecognizer(object):
|
||||
feat_config=feat_config,
|
||||
model_config=model_config,
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
@@ -353,6 +393,8 @@ class OfflineRecognizer(object):
|
||||
debug: bool = False,
|
||||
provider: str = "cpu",
|
||||
tail_paddings: int = -1,
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -389,6 +431,12 @@ class OfflineRecognizer(object):
|
||||
True to show debug messages.
|
||||
provider:
|
||||
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||
rule_fsts:
|
||||
If not empty, it specifies fsts for inverse text normalization.
|
||||
If there are multiple fsts, they are separated by a comma.
|
||||
rule_fars:
|
||||
If not empty, it specifies fst archives for inverse text normalization.
|
||||
If there are multiple archives, they are separated by a comma.
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
model_config = OfflineModelConfig(
|
||||
@@ -415,6 +463,8 @@ class OfflineRecognizer(object):
|
||||
feat_config=feat_config,
|
||||
model_config=model_config,
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
@@ -431,6 +481,8 @@ class OfflineRecognizer(object):
|
||||
decoding_method: str = "greedy_search",
|
||||
debug: bool = False,
|
||||
provider: str = "cpu",
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -458,6 +510,12 @@ class OfflineRecognizer(object):
|
||||
True to show debug messages.
|
||||
provider:
|
||||
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||
rule_fsts:
|
||||
If not empty, it specifies fsts for inverse text normalization.
|
||||
If there are multiple fsts, they are separated by a comma.
|
||||
rule_fars:
|
||||
If not empty, it specifies fst archives for inverse text normalization.
|
||||
If there are multiple archives, they are separated by a comma.
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
model_config = OfflineModelConfig(
|
||||
@@ -478,6 +536,8 @@ class OfflineRecognizer(object):
|
||||
feat_config=feat_config,
|
||||
model_config=model_config,
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
@@ -494,6 +554,8 @@ class OfflineRecognizer(object):
|
||||
decoding_method: str = "greedy_search",
|
||||
debug: bool = False,
|
||||
provider: str = "cpu",
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -522,6 +584,12 @@ class OfflineRecognizer(object):
|
||||
True to show debug messages.
|
||||
provider:
|
||||
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||
rule_fsts:
|
||||
If not empty, it specifies fsts for inverse text normalization.
|
||||
If there are multiple fsts, they are separated by a comma.
|
||||
rule_fars:
|
||||
If not empty, it specifies fst archives for inverse text normalization.
|
||||
If there are multiple archives, they are separated by a comma.
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
model_config = OfflineModelConfig(
|
||||
@@ -542,6 +610,8 @@ class OfflineRecognizer(object):
|
||||
feat_config=feat_config,
|
||||
model_config=model_config,
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
|
||||
Reference in New Issue
Block a user