diff --git a/python-api-examples/non_streaming_server.py b/python-api-examples/non_streaming_server.py index 0d8febb4..034061de 100755 --- a/python-api-examples/non_streaming_server.py +++ b/python-api-examples/non_streaming_server.py @@ -383,6 +383,19 @@ def add_hotwords_args(parser: argparse.ArgumentParser): """, ) +def add_blank_penalty_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + def check_args(args): if not Path(args.tokens).is_file(): @@ -414,6 +427,7 @@ def get_args(): add_feature_config_args(parser) add_decoding_args(parser) add_hotwords_args(parser) + add_blank_penalty_args(parser) parser.add_argument( "--port", @@ -862,6 +876,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: max_active_paths=args.max_active_paths, hotwords_file=args.hotwords_file, hotwords_score=args.hotwords_score, + blank_penalty=args.blank_penalty, provider=args.provider, ) elif args.paraformer: diff --git a/python-api-examples/offline-decode-files.py b/python-api-examples/offline-decode-files.py index 78a1af04..a058d843 100755 --- a/python-api-examples/offline-decode-files.py +++ b/python-api-examples/offline-decode-files.py @@ -231,6 +231,18 @@ def get_args(): """, ) + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + parser.add_argument( "--decoding-method", type=str, @@ -335,6 +347,7 @@ def main(): decoding_method=args.decoding_method, hotwords_file=args.hotwords_file, hotwords_score=args.hotwords_score, + blank_penalty=args.blank_penalty, debug=args.debug, ) elif args.paraformer: diff --git a/python-api-examples/vad-with-non-streaming-asr.py b/python-api-examples/vad-with-non-streaming-asr.py index 4f497438..c67dd8e8 100755 --- a/python-api-examples/vad-with-non-streaming-asr.py +++ b/python-api-examples/vad-with-non-streaming-asr.py @@ -177,6 +177,18 @@ def get_args(): """, ) + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + parser.add_argument( "--decoding-method", type=str, @@ -237,6 +249,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: sample_rate=args.sample_rate, feature_dim=args.feature_dim, decoding_method=args.decoding_method, + blank_penalty=args.blank_penalty, debug=args.debug, ) elif args.paraformer: diff --git a/sherpa-onnx/csrc/math.h b/sherpa-onnx/csrc/math.h index 086b064e..ba01835f 100644 --- a/sherpa-onnx/csrc/math.h +++ b/sherpa-onnx/csrc/math.h @@ -96,6 +96,15 @@ void LogSoftmax(T *in, int32_t w, int32_t h) { } } +template +void SubtractBlank(T *in, int32_t w, int32_t h, + int32_t blank_idx, float blank_penalty) { + for (int32_t i = 0; i != h; ++i) { + in[blank_idx] -= blank_penalty; + in += w; + } +} + template std::vector TopkIndex(const T *vec, int32_t size, int32_t topk) { std::vector vec_index(size); diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h index bf30eb30..084d39fe 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h @@ -79,7 +79,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { } if (config_.decoding_method == "greedy_search") { decoder_ = - std::make_unique(model_.get()); + std::make_unique( + model_.get(), config_.blank_penalty); } else if (config_.decoding_method == "modified_beam_search") { if (!config_.lm_config.model.empty()) { lm_ = OfflineLM::Create(config.lm_config); @@ -87,7 +88,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { decoder_ = std::make_unique( model_.get(), lm_.get(), config_.max_active_paths, - config_.lm_config.scale); + config_.lm_config.scale, config_.blank_penalty); } else { SHERPA_ONNX_LOGE("Unsupported decoding method: %s", config_.decoding_method.c_str()); @@ -104,7 +105,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { config_.model_config)) { if (config_.decoding_method == "greedy_search") { decoder_ = - std::make_unique(model_.get()); + std::make_unique( + model_.get(), config_.blank_penalty); } else if (config_.decoding_method == "modified_beam_search") { if (!config_.lm_config.model.empty()) { lm_ = OfflineLM::Create(mgr, config.lm_config); @@ -112,7 +114,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { decoder_ = std::make_unique( model_.get(), lm_.get(), config_.max_active_paths, - config_.lm_config.scale); + config_.lm_config.scale, config_.blank_penalty); } else { SHERPA_ONNX_LOGE("Unsupported decoding method: %s", config_.decoding_method.c_str()); diff --git a/sherpa-onnx/csrc/offline-recognizer.cc b/sherpa-onnx/csrc/offline-recognizer.cc index 7ab7849c..5c10eb3a 100644 --- a/sherpa-onnx/csrc/offline-recognizer.cc +++ b/sherpa-onnx/csrc/offline-recognizer.cc @@ -28,6 +28,13 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { po->Register("max-active-paths", &max_active_paths, "Used only when decoding_method is modified_beam_search"); + po->Register("blank-penalty", &blank_penalty, + "The penalty applied on blank symbol during decoding. " + "Note: It is a positive value. " + "Increasing value will lead to lower deletion at the cost" + "of higher insertions. " + "Currently only applicable for transducer models."); + po->Register( "hotwords-file", &hotwords_file, "The file containing hotwords, one words/phrases per line, and for each" @@ -74,7 +81,8 @@ std::string OfflineRecognizerConfig::ToString() const { os << "decoding_method=\"" << decoding_method << "\", "; os << "max_active_paths=" << max_active_paths << ", "; os << "hotwords_file=\"" << hotwords_file << "\", "; - os << "hotwords_score=" << hotwords_score << ")"; + os << "hotwords_score=" << hotwords_score << ", "; + os << "blank_penalty=" << blank_penalty << ")"; return os.str(); } diff --git a/sherpa-onnx/csrc/offline-recognizer.h b/sherpa-onnx/csrc/offline-recognizer.h index 16d2aa92..1a878d63 100644 --- a/sherpa-onnx/csrc/offline-recognizer.h +++ b/sherpa-onnx/csrc/offline-recognizer.h @@ -37,6 +37,8 @@ struct OfflineRecognizerConfig { std::string hotwords_file; float hotwords_score = 1.5; + float blank_penalty = 0.0; + // only greedy_search is implemented // TODO(fangjun): Implement modified_beam_search @@ -46,7 +48,8 @@ struct OfflineRecognizerConfig { const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config, const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config, const std::string &decoding_method, int32_t max_active_paths, - const std::string &hotwords_file, float hotwords_score) + const std::string &hotwords_file, float hotwords_score, + float blank_penalty) : feat_config(feat_config), model_config(model_config), lm_config(lm_config), @@ -54,7 +57,8 @@ struct OfflineRecognizerConfig { decoding_method(decoding_method), max_active_paths(max_active_paths), hotwords_file(hotwords_file), - hotwords_score(hotwords_score) {} + hotwords_score(hotwords_score), + blank_penalty(blank_penalty) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc index 99ac3338..c8809a9f 100644 --- a/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc @@ -46,9 +46,12 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out, start += n; Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out), std::move(cur_decoder_out)); - const float *p_logit = logit.GetTensorData(); + float *p_logit = logit.GetTensorMutableData(); bool emitted = false; for (int32_t i = 0; i != n; ++i) { + if (blank_penalty_ > 0.0) { + p_logit[0] -= blank_penalty_; // assuming blank id is 0 + } auto y = static_cast(std::distance( static_cast(p_logit), std::max_element(static_cast(p_logit), diff --git a/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h index ff172250..f90ce911 100644 --- a/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h +++ b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h @@ -14,8 +14,10 @@ namespace sherpa_onnx { class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { public: - explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model) - : model_(model) {} + explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model, + float blank_penalty) + : model_(model), + blank_penalty_(blank_penalty) {} std::vector Decode( Ort::Value encoder_out, Ort::Value encoder_out_length, @@ -23,6 +25,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { private: OfflineTransducerModel *model_; // Not owned + float blank_penalty_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc index 3730ff6b..317c7ad8 100644 --- a/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc @@ -97,6 +97,10 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out)); float *p_logit = logit.GetTensorMutableData(); + if (blank_penalty_ > 0.0) { + // assuming blank id is 0 + SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_); + } LogSoftmax(p_logit, vocab_size, num_hyps); // now p_logit contains log_softmax output, we rename it to p_logprob diff --git a/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h index 89b277c6..08fa4182 100644 --- a/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h +++ b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h @@ -19,11 +19,13 @@ class OfflineTransducerModifiedBeamSearchDecoder OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model, OfflineLM *lm, int32_t max_active_paths, - float lm_scale) + float lm_scale, + float blank_penalty) : model_(model), lm_(lm), max_active_paths_(max_active_paths), - lm_scale_(lm_scale) {} + lm_scale_(lm_scale), + blank_penalty_(blank_penalty) {} std::vector Decode( Ort::Value encoder_out, Ort::Value encoder_out_length, @@ -35,6 +37,7 @@ class OfflineTransducerModifiedBeamSearchDecoder int32_t max_active_paths_; float lm_scale_; // used only when lm_ is not nullptr + float blank_penalty_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-recognizer.cc b/sherpa-onnx/python/csrc/offline-recognizer.cc index e70ffe66..c0ebf7a8 100644 --- a/sherpa-onnx/python/csrc/offline-recognizer.cc +++ b/sherpa-onnx/python/csrc/offline-recognizer.cc @@ -17,13 +17,14 @@ static void PybindOfflineRecognizerConfig(py::module *m) { .def(py::init(), + int32_t, const std::string &, float, float>(), py::arg("feat_config"), py::arg("model_config"), py::arg("lm_config") = OfflineLMConfig(), py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), py::arg("decoding_method") = "greedy_search", py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", - py::arg("hotwords_score") = 1.5) + py::arg("hotwords_score") = 1.5, + py::arg("blank_penalty") = 0.0) .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("lm_config", &PyClass::lm_config) @@ -32,6 +33,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) { .def_readwrite("max_active_paths", &PyClass::max_active_paths) .def_readwrite("hotwords_file", &PyClass::hotwords_file) .def_readwrite("hotwords_score", &PyClass::hotwords_score) + .def_readwrite("blank_penalty", &PyClass::blank_penalty) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index 214f98a1..6a214297 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -48,6 +48,7 @@ class OfflineRecognizer(object): max_active_paths: int = 4, hotwords_file: str = "", hotwords_score: float = 1.5, + blank_penalty: float = 0.0, debug: bool = False, provider: str = "cpu", ): @@ -81,6 +82,8 @@ class OfflineRecognizer(object): max_active_paths: Maximum number of active paths to keep. Used only when decoding_method is modified_beam_search. + blank_penalty: + The penalty applied on blank symbol during decoding. debug: True to show debug messages. provider: @@ -117,6 +120,7 @@ class OfflineRecognizer(object): decoding_method=decoding_method, hotwords_file=hotwords_file, hotwords_score=hotwords_score, + blank_penalty=blank_penalty, ) self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config