Code refactoring (#74)

* Don't reset model state and feature extractor on endpointing

* support passing decoding_method from commandline

* Add modified_beam_search to Python API

* fix C API example

* Fix style issues
This commit is contained in:
Fangjun Kuang
2023-03-03 12:10:59 +08:00
committed by GitHub
parent c241f93c40
commit 7f72c13d9a
34 changed files with 744 additions and 374 deletions

View File

@@ -1,12 +1,13 @@
include_directories(${CMAKE_SOURCE_DIR})
pybind11_add_module(_sherpa_onnx
display.cc
endpoint.cc
features.cc
online-recognizer.cc
online-stream.cc
online-transducer-model-config.cc
sherpa-onnx.cc
endpoint.cc
online-stream.cc
online-recognizer.cc
)
if(APPLE)

View File

@@ -0,0 +1,18 @@
// sherpa-onnx/python/csrc/display.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/display.h"
#include "sherpa-onnx/csrc/display.h"
namespace sherpa_onnx {
void PybindDisplay(py::module *m) {
using PyClass = Display;
py::class_<PyClass>(*m, "Display")
.def(py::init<int32_t>(), py::arg("max_word_per_line") = 60)
.def("print", &PyClass::Print, py::arg("idx"), py::arg("s"));
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/display.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
#define SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindDisplay(py::module *m);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_

View File

@@ -11,10 +11,12 @@ namespace sherpa_onnx {
static void PybindFeatureExtractorConfig(py::module *m) {
using PyClass = FeatureExtractorConfig;
py::class_<PyClass>(*m, "FeatureExtractorConfig")
.def(py::init<float, int32_t>(), py::arg("sampling_rate") = 16000,
py::arg("feature_dim") = 80)
.def(py::init<float, int32_t, int32_t>(),
py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80,
py::arg("max_feature_vectors") = -1)
.def_readwrite("sampling_rate", &PyClass::sampling_rate)
.def_readwrite("feature_dim", &PyClass::feature_dim)
.def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors)
.def("__str__", &PyClass::ToString);
}

View File

@@ -22,13 +22,16 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
.def(py::init<const FeatureExtractorConfig &,
const OnlineTransducerModelConfig &, const EndpointConfig &,
bool>(),
bool, const std::string &, int32_t>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("endpoint_config"), py::arg("enable_endpoint"))
py::arg("endpoint_config"), py::arg("enable_endpoint"),
py::arg("decoding_method"), py::arg("max_active_paths"))
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
.def_readwrite("decoding_method", &PyClass::decoding_method)
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
.def("__str__", &PyClass::ToString);
}

View File

@@ -4,6 +4,7 @@
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
#include "sherpa-onnx/python/csrc/display.h"
#include "sherpa-onnx/python/csrc/endpoint.h"
#include "sherpa-onnx/python/csrc/features.h"
#include "sherpa-onnx/python/csrc/online-recognizer.h"
@@ -19,6 +20,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindOnlineStream(&m);
PybindEndpoint(&m);
PybindOnlineRecognizer(&m);
PybindDisplay(&m);
}
} // namespace sherpa_onnx