diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h index 13357f79..c439319e 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h @@ -78,9 +78,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { config_(config), symbol_table_(config_.model_config.tokens), model_(std::make_unique(config_.model_config)) { + if (symbol_table_.Contains("")) { + unk_id_ = symbol_table_[""]; + } + if (config_.decoding_method == "greedy_search") { decoder_ = std::make_unique( - model_.get(), config_.blank_penalty); + model_.get(), unk_id_, config_.blank_penalty); } else if (config_.decoding_method == "modified_beam_search") { if (!config_.lm_config.model.empty()) { lm_ = OfflineLM::Create(config.lm_config); @@ -97,7 +101,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { decoder_ = std::make_unique( model_.get(), lm_.get(), config_.max_active_paths, - config_.lm_config.scale, config_.blank_penalty); + config_.lm_config.scale, unk_id_, config_.blank_penalty); } else { SHERPA_ONNX_LOGE("Unsupported decoding method: %s", config_.decoding_method.c_str()); @@ -113,9 +117,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { symbol_table_(mgr, config_.model_config.tokens), model_(std::make_unique(mgr, config_.model_config)) { + if (symbol_table_.Contains("")) { + unk_id_ = symbol_table_[""]; + } + if (config_.decoding_method == "greedy_search") { decoder_ = std::make_unique( - model_.get(), config_.blank_penalty); + model_.get(), unk_id_, config_.blank_penalty); } else if (config_.decoding_method == "modified_beam_search") { if (!config_.lm_config.model.empty()) { lm_ = OfflineLM::Create(mgr, config.lm_config); @@ -133,7 +141,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { decoder_ = std::make_unique( model_.get(), lm_.get(), config_.max_active_paths, - config_.lm_config.scale, config_.blank_penalty); + config_.lm_config.scale, unk_id_, config_.blank_penalty); } else { SHERPA_ONNX_LOGE("Unsupported decoding method: %s", config_.decoding_method.c_str()); @@ -293,6 +301,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { std::unique_ptr model_; std::unique_ptr decoder_; std::unique_ptr lm_; + int32_t unk_id_ = -1; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc index c8809a9f..6fd3bf40 100644 --- a/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc @@ -57,7 +57,9 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out, std::max_element(static_cast(p_logit), static_cast(p_logit) + vocab_size))); p_logit += vocab_size; - if (y != 0) { + // blank id is hardcoded to 0 + // also, it treats unk as blank + if (y != 0 && y != unk_id_) { ans[i].tokens.push_back(y); ans[i].timestamps.push_back(t); emitted = true; diff --git a/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h index b284d22a..79109e60 100644 --- a/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h +++ b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h @@ -15,8 +15,9 @@ namespace sherpa_onnx { class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { public: OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model, + int32_t unk_id, float blank_penalty) - : model_(model), blank_penalty_(blank_penalty) {} + : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {} std::vector Decode( Ort::Value encoder_out, Ort::Value encoder_out_length, @@ -24,6 +25,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { private: OfflineTransducerModel *model_; // Not owned + int32_t unk_id_; float blank_penalty_; }; diff --git a/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc index 391620a0..7e81624e 100644 --- a/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc @@ -131,8 +131,9 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( float context_score = 0; auto context_state = new_hyp.context_state; - if (new_token != 0) { - // blank id is fixed to 0 + // blank is hardcoded to 0 + // also, it treats unk as blank + if (new_token != 0 && new_token != unk_id_) { new_hyp.ys.push_back(new_token); new_hyp.timestamps.push_back(t); if (context_graphs[i] != nullptr) { diff --git a/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h index 08fa4182..2e67cd71 100644 --- a/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h +++ b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h @@ -19,12 +19,13 @@ class OfflineTransducerModifiedBeamSearchDecoder OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model, OfflineLM *lm, int32_t max_active_paths, - float lm_scale, + float lm_scale, int32_t unk_id, float blank_penalty) : model_(model), lm_(lm), max_active_paths_(max_active_paths), lm_scale_(lm_scale), + unk_id_(unk_id), blank_penalty_(blank_penalty) {} std::vector Decode( @@ -37,6 +38,7 @@ class OfflineTransducerModifiedBeamSearchDecoder int32_t max_active_paths_; float lm_scale_; // used only when lm_ is not nullptr + int32_t unk_id_; float blank_penalty_; };