Add C++ runtime for non-streaming faster conformer transducer from NeMo. (#854)

This commit is contained in:
Fangjun Kuang
2024-05-10 12:15:39 +08:00
committed by GitHub
parent 5d8c35e44e
commit 17cd3a5f01
31 changed files with 1093 additions and 153 deletions

View File

@@ -87,7 +87,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) {
if (sym_.contains("<unk>")) {
if (sym_.Contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
@@ -103,19 +103,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(),
lm_.get(),
config_.max_active_paths,
config_.lm_config.scale,
unk_id_,
config_.blank_penalty,
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, unk_id_, config_.blank_penalty,
config_.temperature_scale);
} else if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
model_.get(),
unk_id_,
config_.blank_penalty,
model_.get(), unk_id_, config_.blank_penalty,
config_.temperature_scale);
} else {
@@ -132,7 +126,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
sym_(mgr, config.model_config.tokens),
endpoint_(config_.endpoint_config) {
if (sym_.contains("<unk>")) {
if (sym_.Contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
@@ -151,19 +145,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(),
lm_.get(),
config_.max_active_paths,
config_.lm_config.scale,
unk_id_,
config_.blank_penalty,
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, unk_id_, config_.blank_penalty,
config_.temperature_scale);
} else if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
model_.get(),
unk_id_,
config_.blank_penalty,
model_.get(), unk_id_, config_.blank_penalty,
config_.temperature_scale);
} else {