Support replacing homonphonic phrases (#2153)

This commit is contained in:
Fangjun Kuang
2025-04-27 15:31:11 +08:00
committed by GitHub
parent e3280027f9
commit f64c58342b
42 changed files with 834 additions and 134 deletions

View File

@@ -7,6 +7,7 @@ set(srcs
display.cc
endpoint.cc
features.cc
homophone-replacer.cc
keyword-spotter.cc
offline-ctc-fst-decoder-config.cc
offline-dolphin-model-config.cc

View File

@@ -0,0 +1,28 @@
// sherpa-onnx/python/csrc/homophone-replacer.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/homophone-replacer.h"
#include <string>
#include "sherpa-onnx/csrc/homophone-replacer.h"
namespace sherpa_onnx {
void PybindHomophoneReplacer(py::module *m) {
using PyClass = HomophoneReplacerConfig;
py::class_<PyClass>(*m, "HomophoneReplacerConfig")
.def(py::init<>())
.def(py::init<const std::string &, const std::string &,
const std::string &, bool>(),
py::arg("dict_dir"), py::arg("lexicon"), py::arg("rule_fsts"),
py::arg("debug") = false)
.def_readwrite("dict_dir", &PyClass::dict_dir)
.def_readwrite("lexicon", &PyClass::lexicon)
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
.def_readwrite("debug", &PyClass::debug)
.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/homophone-replacer.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_
#define SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindHomophoneReplacer(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_

View File

@@ -17,14 +17,16 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
.def(py::init<const FeatureExtractorConfig &, const OfflineModelConfig &,
const OfflineLMConfig &, const OfflineCtcFstDecoderConfig &,
const std::string &, int32_t, const std::string &, float,
float, const std::string &, const std::string &>(),
float, const std::string &, const std::string &,
const HomophoneReplacerConfig &>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OfflineLMConfig(),
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
py::arg("decoding_method") = "greedy_search",
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0,
py::arg("rule_fsts") = "", py::arg("rule_fars") = "")
py::arg("rule_fsts") = "", py::arg("rule_fars") = "",
py::arg("hr") = HomophoneReplacerConfig{})
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
@@ -36,6 +38,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
.def_readwrite("blank_penalty", &PyClass::blank_penalty)
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
.def_readwrite("rule_fars", &PyClass::rule_fars)
.def_readwrite("hr", &PyClass::hr)
.def("__str__", &PyClass::ToString);
}

View File

@@ -58,7 +58,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
const OnlineLMConfig &, const EndpointConfig &,
const OnlineCtcFstDecoderConfig &, bool,
const std::string &, int32_t, const std::string &, float,
float, float, const std::string &, const std::string &, bool>(),
float, float, const std::string &, const std::string &,
bool, const HomophoneReplacerConfig &>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OnlineLMConfig(),
py::arg("endpoint_config") = EndpointConfig(),
@@ -67,7 +68,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0,
py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "",
py::arg("rule_fars") = "", py::arg("reset_encoder") = false)
py::arg("rule_fars") = "", py::arg("reset_encoder") = false,
py::arg("hr") = HomophoneReplacerConfig{})
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
@@ -83,6 +85,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
.def_readwrite("rule_fars", &PyClass::rule_fars)
.def_readwrite("reset_encoder", &PyClass::reset_encoder)
.def_readwrite("hr", &PyClass::hr)
.def("__str__", &PyClass::ToString);
}

View File

@@ -10,6 +10,7 @@
#include "sherpa-onnx/python/csrc/display.h"
#include "sherpa-onnx/python/csrc/endpoint.h"
#include "sherpa-onnx/python/csrc/features.h"
#include "sherpa-onnx/python/csrc/homophone-replacer.h"
#include "sherpa-onnx/python/csrc/keyword-spotter.h"
#include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h"
#include "sherpa-onnx/python/csrc/offline-lm-config.h"
@@ -51,6 +52,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindAudioTagging(&m);
PybindOfflinePunctuation(&m);
PybindOnlinePunctuation(&m);
PybindHomophoneReplacer(&m);
PybindFeatures(&m);
PybindOnlineCtcFstDecoderConfig(&m);