Code refactoring (#74)
* Don't reset model state and feature extractor on endpointing
* Support passing decoding_method from the command line
* Add modified_beam_search to the Python API
* Fix the C API example
* Fix style issues
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
include_directories(${CMAKE_SOURCE_DIR})
|
||||
|
||||
pybind11_add_module(_sherpa_onnx
|
||||
display.cc
|
||||
endpoint.cc
|
||||
features.cc
|
||||
online-recognizer.cc
|
||||
online-stream.cc
|
||||
online-transducer-model-config.cc
|
||||
sherpa-onnx.cc
|
||||
endpoint.cc
|
||||
online-stream.cc
|
||||
online-recognizer.cc
|
||||
)
|
||||
|
||||
if(APPLE)
|
||||
|
||||
18
sherpa-onnx/python/csrc/display.cc
Normal file
18
sherpa-onnx/python/csrc/display.cc
Normal file
@@ -0,0 +1,18 @@
|
||||
// sherpa-onnx/python/csrc/display.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/display.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/display.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Registers the C++ `Display` class on the pybind11 module `*m` under the
// Python name "Display".
//
// The Python class exposes:
//   - a constructor taking one int32_t keyword argument `max_word_per_line`,
//     which defaults to 60;
//   - a method `print(idx, s)` bound to `Display::Print`.
//
// NOTE(review): `Display` is presumably declared in
// "sherpa-onnx/csrc/display.h", which this file includes - confirm.
//
// @param m  Module that receives the binding; it is dereferenced
//           unconditionally.
void PybindDisplay(py::module *m) {
|
||||
using PyClass = Display;
|
||||
py::class_<PyClass>(*m, "Display")
|
||||
.def(py::init<int32_t>(), py::arg("max_word_per_line") = 60)
|
||||
.def("print", &PyClass::Print, py::arg("idx"), py::arg("s"));
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/display.h
Normal file
16
sherpa-onnx/python/csrc/display.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/display.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindDisplay(py::module *m);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
|
||||
@@ -11,10 +11,12 @@ namespace sherpa_onnx {
|
||||
static void PybindFeatureExtractorConfig(py::module *m) {
|
||||
using PyClass = FeatureExtractorConfig;
|
||||
py::class_<PyClass>(*m, "FeatureExtractorConfig")
|
||||
.def(py::init<float, int32_t>(), py::arg("sampling_rate") = 16000,
|
||||
py::arg("feature_dim") = 80)
|
||||
.def(py::init<float, int32_t, int32_t>(),
|
||||
py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80,
|
||||
py::arg("max_feature_vectors") = -1)
|
||||
.def_readwrite("sampling_rate", &PyClass::sampling_rate)
|
||||
.def_readwrite("feature_dim", &PyClass::feature_dim)
|
||||
.def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
|
||||
@@ -22,13 +22,16 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
||||
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
|
||||
.def(py::init<const FeatureExtractorConfig &,
|
||||
const OnlineTransducerModelConfig &, const EndpointConfig &,
|
||||
bool>(),
|
||||
bool, const std::string &, int32_t>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("endpoint_config"), py::arg("enable_endpoint"))
|
||||
py::arg("endpoint_config"), py::arg("enable_endpoint"),
|
||||
py::arg("decoding_method"), py::arg("max_active_paths"))
|
||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||
.def_readwrite("model_config", &PyClass::model_config)
|
||||
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
|
||||
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
|
||||
.def_readwrite("decoding_method", &PyClass::decoding_method)
|
||||
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
#include "sherpa-onnx/python/csrc/display.h"
|
||||
#include "sherpa-onnx/python/csrc/endpoint.h"
|
||||
#include "sherpa-onnx/python/csrc/features.h"
|
||||
#include "sherpa-onnx/python/csrc/online-recognizer.h"
|
||||
@@ -19,6 +20,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
||||
PybindOnlineStream(&m);
|
||||
PybindEndpoint(&m);
|
||||
PybindOnlineRecognizer(&m);
|
||||
|
||||
PybindDisplay(&m);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -1,9 +1,3 @@
|
||||
from _sherpa_onnx import (
|
||||
EndpointConfig,
|
||||
FeatureExtractorConfig,
|
||||
OnlineRecognizerConfig,
|
||||
OnlineStream,
|
||||
OnlineTransducerModelConfig,
|
||||
)
|
||||
from _sherpa_onnx import Display
|
||||
|
||||
from .online_recognizer import OnlineRecognizer
|
||||
|
||||
@@ -32,6 +32,9 @@ class OnlineRecognizer(object):
|
||||
rule1_min_trailing_silence: int = 2.4,
|
||||
rule2_min_trailing_silence: int = 1.2,
|
||||
rule3_min_utterance_length: int = 20,
|
||||
decoding_method: str = "greedy_search",
|
||||
max_active_paths: int = 4,
|
||||
max_feature_vectors: int = -1,
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -74,6 +77,14 @@ class OnlineRecognizer(object):
|
||||
Used only when enable_endpoint_detection is True. If the utterance
|
||||
length in seconds is larger than this value, we assume an endpoint
|
||||
is detected.
|
||||
decoding_method:
|
||||
Valid values are greedy_search, modified_beam_search.
|
||||
max_active_paths:
|
||||
Used only when decoding_method is modified_beam_search. It specifies
|
||||
the maximum number of active paths during beam search.
|
||||
max_feature_vectors:
|
||||
Number of feature vectors to cache. -1 means to cache all feature
|
||||
frames that have been processed.
|
||||
"""
|
||||
_assert_file_exists(tokens)
|
||||
_assert_file_exists(encoder)
|
||||
@@ -93,6 +104,7 @@ class OnlineRecognizer(object):
|
||||
feat_config = FeatureExtractorConfig(
|
||||
sampling_rate=sample_rate,
|
||||
feature_dim=feature_dim,
|
||||
max_feature_vectors=max_feature_vectors,
|
||||
)
|
||||
|
||||
endpoint_config = EndpointConfig(
|
||||
@@ -106,6 +118,8 @@ class OnlineRecognizer(object):
|
||||
model_config=model_config,
|
||||
endpoint_config=endpoint_config,
|
||||
enable_endpoint=enable_endpoint_detection,
|
||||
decoding_method=decoding_method,
|
||||
max_active_paths=max_active_paths,
|
||||
)
|
||||
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
|
||||
Reference in New Issue
Block a user