Add inverse text normalization for online ASR (#1020)

This commit is contained in:
Fangjun Kuang
2024-06-17 18:39:23 +08:00
committed by GitHub
parent 6e09933d99
commit 349d957da2
12 changed files with 390 additions and 32 deletions

View File

@@ -68,7 +68,8 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
public:
explicit OnlineRecognizerCtcImpl(const OnlineRecognizerConfig &config)
: config_(config),
: OnlineRecognizerImpl(config),
config_(config),
model_(OnlineCtcModel::Create(config.model_config)),
sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) {
@@ -84,7 +85,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
#if __ANDROID_API__ >= 9
explicit OnlineRecognizerCtcImpl(AAssetManager *mgr,
const OnlineRecognizerConfig &config)
: config_(config),
: OnlineRecognizerImpl(mgr, config),
config_(config),
model_(OnlineCtcModel::Create(mgr, config.model_config)),
sym_(mgr, config.model_config.tokens),
endpoint_(config_.endpoint_config) {
@@ -182,8 +184,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
// TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10;
int32_t subsampling_factor = 4;
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
r.text = ApplyInverseTextNormalization(r.text);
return r;
}
bool IsEndpoint(OnlineStream *s) const override {