Add inverse text normalization for non-streaming ASR (#1017)

This commit is contained in:
Fangjun Kuang
2024-06-17 14:28:53 +08:00
committed by GitHub
parent dd69a1b56b
commit b0f7ed3ee3
13 changed files with 380 additions and 19 deletions

View File

@@ -73,7 +73,8 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
public:
explicit OfflineRecognizerCtcImpl(const OfflineRecognizerConfig &config)
: config_(config),
: OfflineRecognizerImpl(config),
config_(config),
symbol_table_(config_.model_config.tokens),
model_(OfflineCtcModel::Create(config_.model_config)) {
Init();
@@ -82,7 +83,8 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
#if __ANDROID_API__ >= 9
OfflineRecognizerCtcImpl(AAssetManager *mgr,
const OfflineRecognizerConfig &config)
: config_(config),
: OfflineRecognizerImpl(mgr, config),
config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(OfflineCtcModel::Create(mgr, config_.model_config)) {
Init();
@@ -205,6 +207,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
for (int32_t i = 0; i != n; ++i) {
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
model_->SubsamplingFactor());
r.text = ApplyInverseTextNormalization(std::move(r.text));
ss[i]->SetResult(r);
}
}
@@ -238,6 +241,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
auto r = Convert(results[0], symbol_table_, frame_shift_ms,
model_->SubsamplingFactor());
r.text = ApplyInverseTextNormalization(std::move(r.text));
s->SetResult(r);
}

View File

@@ -5,7 +5,18 @@
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include <strstream>
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "fst/extensions/far/far.h"
#include "kaldifst/csrc/kaldi-fst-io.h"
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
@@ -316,4 +327,111 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
}
#endif
// Builds the inverse text normalization (ITN) pipeline from files on disk.
//
// config.rule_fsts and config.rule_fars are comma-separated lists of paths.
// A normalizer is appended to itn_list_ for every rule FST first, then for
// every FST stored inside each FST archive, preserving the order in which
// they are listed.
//
// @param config Recognizer configuration; a copy is kept in config_.
OfflineRecognizerImpl::OfflineRecognizerImpl(
    const OfflineRecognizerConfig &config)
    : config_(config) {
  if (!config.rule_fsts.empty()) {
    std::vector<std::string> files;
    SplitStringToVector(config.rule_fsts, ",", false, &files);
    itn_list_.reserve(files.size());
    for (const auto &f : files) {
      if (config.model_config.debug) {
        SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
      }
      itn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(f));
    }
  }

  if (!config.rule_fars.empty()) {
    if (config.model_config.debug) {
      SHERPA_ONNX_LOGE("Loading FST archives");
    }

    std::vector<std::string> files;
    SplitStringToVector(config.rule_fars, ",", false, &files);

    // Each archive contributes at least one entry in the common case, so
    // reserve room for the archives on top of the rule FSTs already loaded.
    itn_list_.reserve(files.size() + itn_list_.size());

    for (const auto &f : files) {
      if (config.model_config.debug) {
        SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
      }

      std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
          fst::FarReader<fst::StdArc>::Open(f));
      if (!reader) {
        // Open() hands back a null pointer when the archive cannot be read
        // or is not a valid FAR file. Dereferencing it would crash, so
        // report the problem and move on to the next archive.
        SHERPA_ONNX_LOGE("Failed to open rule far: '%s'. Skip it.", f.c_str());
        continue;
      }

      for (; !reader->Done(); reader->Next()) {
        // Copy() detaches the FST from the reader, whose current entry is
        // replaced by Next().
        std::unique_ptr<fst::StdConstFst> r(
            fst::CastOrConvertToConstFst(reader->GetFst()->Copy()));
        itn_list_.push_back(
            std::make_unique<kaldifst::TextNormalizer>(std::move(r)));
      }
    }

    if (config.model_config.debug) {
      SHERPA_ONNX_LOGE("FST archives loaded!");
    }
  }
}
#if __ANDROID_API__ >= 9
// Android variant: builds the inverse text normalization (ITN) pipeline from
// files read through the asset manager.
//
// config.rule_fsts and config.rule_fars are comma-separated lists of asset
// paths. A normalizer is appended to itn_list_ for every rule FST first, then
// for every FST stored inside each FST archive, preserving the order in which
// they are listed.
//
// @param mgr    Asset manager passed to ReadFile() to load each file.
// @param config Recognizer configuration; a copy is kept in config_.
OfflineRecognizerImpl::OfflineRecognizerImpl(
    AAssetManager *mgr, const OfflineRecognizerConfig &config)
    : config_(config) {
  if (!config.rule_fsts.empty()) {
    std::vector<std::string> files;
    SplitStringToVector(config.rule_fsts, ",", false, &files);
    itn_list_.reserve(files.size());
    for (const auto &f : files) {
      if (config.model_config.debug) {
        SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
      }

      auto buf = ReadFile(mgr, f);

      // The stream reads directly from buf, which outlives it.
      std::istrstream is(buf.data(), buf.size());
      itn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(is));
    }
  }

  if (!config.rule_fars.empty()) {
    if (config.model_config.debug) {
      SHERPA_ONNX_LOGE("Loading FST archives");
    }

    std::vector<std::string> files;
    SplitStringToVector(config.rule_fars, ",", false, &files);
    itn_list_.reserve(files.size() + itn_list_.size());

    for (const auto &f : files) {
      if (config.model_config.debug) {
        SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
      }

      // buf is declared before the reader so that it is destroyed after it;
      // the stream handed to the reader refers to buf's storage.
      auto buf = ReadFile(mgr, f);

      std::unique_ptr<std::istream> s(
          new std::istrstream(buf.data(), buf.size()));

      std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
          fst::FarReader<fst::StdArc>::Open(std::move(s)));
      if (!reader) {
        // Open() hands back a null pointer when the data is not a valid FAR
        // archive. Dereferencing it would crash, so report the problem and
        // move on to the next archive.
        SHERPA_ONNX_LOGE("Failed to open rule far: '%s'. Skip it.", f.c_str());
        continue;
      }

      for (; !reader->Done(); reader->Next()) {
        // Copy() detaches the FST from the reader, whose current entry is
        // replaced by Next().
        std::unique_ptr<fst::StdConstFst> r(
            fst::CastOrConvertToConstFst(reader->GetFst()->Copy()));
        itn_list_.push_back(
            std::make_unique<kaldifst::TextNormalizer>(std::move(r)));
      }  // for (; !reader->Done(); reader->Next())
    }    // for (const auto &f : files)

    if (config.model_config.debug) {
      SHERPA_ONNX_LOGE("FST archives loaded!");
    }
  }  // if (!config.rule_fars.empty())
}
#endif
// Passes `text` through every normalizer in itn_list_, in list order, feeding
// the output of one into the next, and returns the final string. When
// itn_list_ is empty the input is returned unchanged.
//
// With config_.model_config.debug set, the intermediate text is logged after
// each normalizer has run.
std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
    std::string text) const {
  const bool verbose = config_.model_config.debug;

  // A range-for over an empty list is a no-op, so no separate emptiness
  // check is needed.
  for (const auto &normalizer : itn_list_) {
    text = normalizer->Normalize(text);

    if (verbose) {
      SHERPA_ONNX_LOGE("After inverse text normalization: %s", text.c_str());
    }
  }

  return text;
}
} // namespace sherpa_onnx

View File

@@ -14,6 +14,7 @@
#include "android/asset_manager_jni.h"
#endif
#include "kaldifst/csrc/text-normalizer.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/offline-stream.h"
@@ -22,10 +23,15 @@ namespace sherpa_onnx {
class OfflineRecognizerImpl {
public:
explicit OfflineRecognizerImpl(const OfflineRecognizerConfig &config);
static std::unique_ptr<OfflineRecognizerImpl> Create(
const OfflineRecognizerConfig &config);
#if __ANDROID_API__ >= 9
OfflineRecognizerImpl(AAssetManager *mgr,
const OfflineRecognizerConfig &config);
static std::unique_ptr<OfflineRecognizerImpl> Create(
AAssetManager *mgr, const OfflineRecognizerConfig &config);
#endif
@@ -41,6 +47,15 @@ class OfflineRecognizerImpl {
virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;
std::string ApplyInverseTextNormalization(std::string text) const;
private:
OfflineRecognizerConfig config_;
// for inverse text normalization. Used only if
// config.rule_fsts is not empty or
// config.rule_fars is not empty
std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_;
};
} // namespace sherpa_onnx

View File

@@ -89,7 +89,8 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
public:
explicit OfflineRecognizerParaformerImpl(
const OfflineRecognizerConfig &config)
: config_(config),
: OfflineRecognizerImpl(config),
config_(config),
symbol_table_(config_.model_config.tokens),
model_(std::make_unique<OfflineParaformerModel>(config.model_config)) {
if (config.decoding_method == "greedy_search") {
@@ -109,7 +110,8 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
#if __ANDROID_API__ >= 9
OfflineRecognizerParaformerImpl(AAssetManager *mgr,
const OfflineRecognizerConfig &config)
: config_(config),
: OfflineRecognizerImpl(mgr, config),
config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(std::make_unique<OfflineParaformerModel>(mgr,
config.model_config)) {
@@ -204,6 +206,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
for (int32_t i = 0; i != n; ++i) {
auto r = Convert(results[i], symbol_table_);
r.text = ApplyInverseTextNormalization(std::move(r.text));
ss[i]->SetResult(r);
}
}

View File

@@ -74,7 +74,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
public:
explicit OfflineRecognizerTransducerImpl(
const OfflineRecognizerConfig &config)
: config_(config),
: OfflineRecognizerImpl(config),
config_(config),
symbol_table_(config_.model_config.tokens),
model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
if (config_.decoding_method == "greedy_search") {
@@ -107,7 +108,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
#if __ANDROID_API__ >= 9
explicit OfflineRecognizerTransducerImpl(
AAssetManager *mgr, const OfflineRecognizerConfig &config)
: config_(config),
: OfflineRecognizerImpl(mgr, config),
config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(std::make_unique<OfflineTransducerModel>(mgr,
config_.model_config)) {
@@ -230,6 +232,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
for (int32_t i = 0; i != n; ++i) {
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
model_->SubsamplingFactor());
r.text = ApplyInverseTextNormalization(std::move(r.text));
ss[i]->SetResult(r);
}

View File

@@ -41,7 +41,8 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
public:
explicit OfflineRecognizerTransducerNeMoImpl(
const OfflineRecognizerConfig &config)
: config_(config),
: OfflineRecognizerImpl(config),
config_(config),
symbol_table_(config_.model_config.tokens),
model_(std::make_unique<OfflineTransducerNeMoModel>(
config_.model_config)) {
@@ -59,7 +60,8 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
#if __ANDROID_API__ >= 9
explicit OfflineRecognizerTransducerNeMoImpl(
AAssetManager *mgr, const OfflineRecognizerConfig &config)
: config_(config),
: OfflineRecognizerImpl(mgr, config),
config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(std::make_unique<OfflineTransducerNeMoModel>(
mgr, config_.model_config)) {
@@ -131,6 +133,7 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
for (int32_t i = 0; i != n; ++i) {
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
model_->SubsamplingFactor());
r.text = ApplyInverseTextNormalization(std::move(r.text));
ss[i]->SetResult(r);
}

View File

@@ -52,7 +52,8 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
public:
explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config)
: config_(config),
: OfflineRecognizerImpl(config),
config_(config),
symbol_table_(config_.model_config.tokens),
model_(std::make_unique<OfflineWhisperModel>(config.model_config)) {
Init();
@@ -61,7 +62,8 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
#if __ANDROID_API__ >= 9
OfflineRecognizerWhisperImpl(AAssetManager *mgr,
const OfflineRecognizerConfig &config)
: config_(config),
: OfflineRecognizerImpl(mgr, config),
config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(
std::make_unique<OfflineWhisperModel>(mgr, config.model_config)) {
@@ -150,6 +152,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
std::move(cross_kv.second));
auto r = Convert(results[0], symbol_table_);
r.text = ApplyInverseTextNormalization(std::move(r.text));
s->SetResult(r);
} catch (const Ort::Exception &ex) {
SHERPA_ONNX_LOGE(

View File

@@ -10,7 +10,7 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
void OfflineRecognizerConfig::Register(ParseOptions *po) {
@@ -44,6 +44,16 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
po->Register("hotwords-score", &hotwords_score,
"The bonus score for each token in context word/phrase. "
"Used only when decoding_method is modified_beam_search");
po->Register(
"rule-fsts", &rule_fsts,
"If not empty, it specifies fsts for inverse text normalization. "
"If there are multiple fsts, they are separated by a comma.");
po->Register(
"rule-fars", &rule_fars,
"If not empty, it specifies fst archives for inverse text normalization. "
"If there are multiple archives, they are separated by a comma.");
}
bool OfflineRecognizerConfig::Validate() const {
@@ -61,7 +71,7 @@ bool OfflineRecognizerConfig::Validate() const {
if (!hotwords_file.empty() && decoding_method != "modified_beam_search") {
SHERPA_ONNX_LOGE(
"Please use --decoding-method=modified_beam_search if you"
" provide --hotwords-file. Given --decoding-method=%s",
" provide --hotwords-file. Given --decoding-method='%s'",
decoding_method.c_str());
return false;
}
@@ -72,6 +82,34 @@ bool OfflineRecognizerConfig::Validate() const {
return false;
}
if (!hotwords_file.empty() && !FileExists(hotwords_file)) {
SHERPA_ONNX_LOGE("--hotwords-file: '%s' does not exist",
hotwords_file.c_str());
return false;
}
if (!rule_fsts.empty()) {
std::vector<std::string> files;
SplitStringToVector(rule_fsts, ",", false, &files);
for (const auto &f : files) {
if (!FileExists(f)) {
SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str());
return false;
}
}
}
if (!rule_fars.empty()) {
std::vector<std::string> files;
SplitStringToVector(rule_fars, ",", false, &files);
for (const auto &f : files) {
if (!FileExists(f)) {
SHERPA_ONNX_LOGE("Rule far '%s' does not exist. ", f.c_str());
return false;
}
}
}
return model_config.Validate();
}
@@ -87,7 +125,9 @@ std::string OfflineRecognizerConfig::ToString() const {
os << "max_active_paths=" << max_active_paths << ", ";
os << "hotwords_file=\"" << hotwords_file << "\", ";
os << "hotwords_score=" << hotwords_score << ", ";
os << "blank_penalty=" << blank_penalty << ")";
os << "blank_penalty=" << blank_penalty << ", ";
os << "rule_fsts=\"" << rule_fsts << "\", ";
os << "rule_fars=\"" << rule_fars << "\")";
return os.str();
}

View File

@@ -40,6 +40,12 @@ struct OfflineRecognizerConfig {
float blank_penalty = 0.0;
// If there are multiple rules, they are applied from left to right.
std::string rule_fsts;
// If there are multiple FST archives, they are applied from left to right.
std::string rule_fars;
// only greedy_search is implemented
// TODO(fangjun): Implement modified_beam_search
@@ -50,7 +56,8 @@ struct OfflineRecognizerConfig {
const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config,
const std::string &decoding_method, int32_t max_active_paths,
const std::string &hotwords_file, float hotwords_score,
float blank_penalty)
float blank_penalty, const std::string &rule_fsts,
const std::string &rule_fars)
: feat_config(feat_config),
model_config(model_config),
lm_config(lm_config),
@@ -59,7 +66,9 @@ struct OfflineRecognizerConfig {
max_active_paths(max_active_paths),
hotwords_file(hotwords_file),
hotwords_score(hotwords_score),
blank_penalty(blank_penalty) {}
blank_penalty(blank_penalty),
rule_fsts(rule_fsts),
rule_fars(rule_fars) {}
void Register(ParseOptions *po);
bool Validate() const;