Add inverse text normalization for online ASR (#1020)

This commit is contained in:
Fangjun Kuang
2024-06-17 18:39:23 +08:00
committed by GitHub
parent 6e09933d99
commit 349d957da2
12 changed files with 390 additions and 32 deletions

View File

@@ -42,7 +42,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
public:
explicit OnlineRecognizerTransducerNeMoImpl(
const OnlineRecognizerConfig &config)
: config_(config),
: OnlineRecognizerImpl(config),
config_(config),
symbol_table_(config.model_config.tokens),
endpoint_(config_.endpoint_config),
model_(
@@ -61,7 +62,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
#if __ANDROID_API__ >= 9
explicit OnlineRecognizerTransducerNeMoImpl(
AAssetManager *mgr, const OnlineRecognizerConfig &config)
: config_(config),
: OnlineRecognizerImpl(mgr, config),
config_(config),
symbol_table_(mgr, config.model_config.tokens),
endpoint_(config_.endpoint_config),
model_(std::make_unique<OnlineTransducerNeMoModel>(
@@ -94,9 +96,11 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
// TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10;
int32_t subsampling_factor = model_->SubsamplingFactor();
return Convert(s->GetResult(), symbol_table_, frame_shift_ms,
subsampling_factor, s->GetCurrentSegment(),
s->GetNumFramesSinceStart());
auto r = Convert(s->GetResult(), symbol_table_, frame_shift_ms,
subsampling_factor, s->GetCurrentSegment(),
s->GetNumFramesSinceStart());
r.text = ApplyInverseTextNormalization(std::move(r.text));
return r;
}
bool IsEndpoint(OnlineStream *s) const override {