// sherpa-onnx/python/csrc/keyword-spotter.cc // // Copyright (c) 2024 Xiaomi Corporation #include "sherpa-onnx/python/csrc/keyword-spotter.h" #include #include #include "sherpa-onnx/csrc/keyword-spotter.h" namespace sherpa_onnx { static void PybindKeywordResult(py::module *m) { using PyClass = KeywordResult; py::class_(*m, "KeywordResult") .def_property_readonly( "keyword", [](PyClass &self) -> py::str { return py::str(PyUnicode_DecodeUTF8(self.keyword.c_str(), self.keyword.size(), "ignore")); }) .def_property_readonly( "tokens", [](PyClass &self) -> std::vector { return self.tokens; }) .def_property_readonly( "timestamps", [](PyClass &self) -> std::vector { return self.timestamps; }); } static void PybindKeywordSpotterConfig(py::module *m) { using PyClass = KeywordSpotterConfig; py::class_(*m, "KeywordSpotterConfig") .def(py::init(), py::arg("feat_config"), py::arg("model_config"), py::arg("max_active_paths") = 4, py::arg("num_trailing_blanks") = 1, py::arg("keywords_score") = 1.0, py::arg("keywords_threshold") = 0.25, py::arg("keywords_file") = "") .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("max_active_paths", &PyClass::max_active_paths) .def_readwrite("num_trailing_blanks", &PyClass::num_trailing_blanks) .def_readwrite("keywords_score", &PyClass::keywords_score) .def_readwrite("keywords_threshold", &PyClass::keywords_threshold) .def_readwrite("keywords_file", &PyClass::keywords_file) .def("__str__", &PyClass::ToString); } void PybindKeywordSpotter(py::module *m) { PybindKeywordResult(m); PybindKeywordSpotterConfig(m); using PyClass = KeywordSpotter; py::class_(*m, "KeywordSpotter") .def(py::init(), py::arg("config"), py::call_guard()) .def( "create_stream", [](const PyClass &self) { return self.CreateStream(); }, py::call_guard()) .def( "create_stream", [](PyClass &self, const std::string &keywords) { return self.CreateStream(keywords); }, py::arg("keywords"), py::call_guard()) .def("is_ready", &PyClass::IsReady, py::call_guard()) .def("decode_stream", &PyClass::DecodeStream, py::call_guard()) .def( "decode_streams", [](PyClass &self, std::vector ss) { self.DecodeStreams(ss.data(), ss.size()); }, py::call_guard()) .def("get_result", &PyClass::GetResult, py::call_guard()); } } // namespace sherpa_onnx