Add inverse text normalization for online ASR (#1020)

This commit is contained in:
Fangjun Kuang
2024-06-17 18:39:23 +08:00
committed by GitHub
parent 6e09933d99
commit 349d957da2
12 changed files with 390 additions and 32 deletions

View File

@@ -80,7 +80,8 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
public:
explicit OnlineRecognizerTransducerImpl(const OnlineRecognizerConfig &config)
: config_(config),
: OnlineRecognizerImpl(config),
config_(config),
model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) {
@@ -124,7 +125,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
#if __ANDROID_API__ >= 9
explicit OnlineRecognizerTransducerImpl(AAssetManager *mgr,
const OnlineRecognizerConfig &config)
: config_(config),
: OnlineRecognizerImpl(mgr, config),
config_(config),
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
sym_(mgr, config.model_config.tokens),
endpoint_(config_.endpoint_config) {
@@ -332,8 +334,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
// TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10;
int32_t subsampling_factor = 4;
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
r.text = ApplyInverseTextNormalization(std::move(r.text));
return r;
}
bool IsEndpoint(OnlineStream *s) const override {