Add endpointing (#54)

This commit is contained in:
Fangjun Kuang
2023-02-22 15:35:55 +08:00
committed by GitHub
parent 1c6f79f096
commit 124384369a
23 changed files with 2190 additions and 21 deletions

View File

@@ -4,6 +4,7 @@ pybind11_add_module(_sherpa_onnx
features.cc
online-transducer-model-config.cc
sherpa-onnx.cc
endpoint.cc
online-stream.cc
online-recognizer.cc
)

View File

@@ -0,0 +1,100 @@
// sherpa-onnx/csrc/endpoint.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/endpoint.h"
#include <memory>
#include <string>
#include "sherpa-onnx/csrc/endpoint.h"
namespace sherpa_onnx {
static constexpr const char *kEndpointRuleInitDoc = R"doc(
Constructor for EndpointRule.
Args:
must_contain_nonsilence:
If True, for this endpointing rule to apply there must be nonsilence in the
best-path traceback. For decoding, a non-blank token is considered as
non-silence.
min_trailing_silence:
This endpointing rule requires duration of trailing silence (in seconds)
to be ``>=`` this value.
min_utterance_length:
This endpointing rule requires utterance-length (in seconds) to
be ``>=`` this value.
)doc";
static constexpr const char *kEndpointConfigInitDoc = R"doc(
If any rule in EndpointConfig is activated, it is said that an endpointing
is detected.
Args:
rule1:
By default, it times out after 2.4 seconds of silence, even if
we decoded nothing.
rule2:
By default, it times out after 1.2 seconds of silence after decoding
something.
rule3:
By default, it times out after the utterance is 20 seconds long, regardless of
anything else.
)doc";
static void PybindEndpointRule(py::module *m) {
using PyClass = EndpointRule;
py::class_<PyClass>(*m, "EndpointRule")
.def(py::init<bool, float, float>(), py::arg("must_contain_nonsilence"),
py::arg("min_trailing_silence"), py::arg("min_utterance_length"),
kEndpointRuleInitDoc)
.def("__str__", &PyClass::ToString)
.def_readwrite("must_contain_nonsilence",
&PyClass::must_contain_nonsilence)
.def_readwrite("min_trailing_silence", &PyClass::min_trailing_silence)
.def_readwrite("min_utterance_length", &PyClass::min_utterance_length);
}
static void PybindEndpointConfig(py::module *m) {
using PyClass = EndpointConfig;
py::class_<PyClass>(*m, "EndpointConfig")
.def(
py::init(
[](float rule1_min_trailing_silence,
float rule2_min_trailing_silence,
float rule3_min_utterance_length) -> std::unique_ptr<PyClass> {
EndpointRule rule1(false, rule1_min_trailing_silence, 0);
EndpointRule rule2(true, rule2_min_trailing_silence, 0);
EndpointRule rule3(false, 0, rule3_min_utterance_length);
return std::make_unique<EndpointConfig>(rule1, rule2, rule3);
}),
py::arg("rule1_min_trailing_silence"),
py::arg("rule2_min_trailing_silence"),
py::arg("rule3_min_utterance_length"))
.def(py::init([](const EndpointRule &rule1, const EndpointRule &rule2,
const EndpointRule &rule3) -> std::unique_ptr<PyClass> {
auto ans = std::make_unique<PyClass>();
ans->rule1 = rule1;
ans->rule2 = rule2;
ans->rule3 = rule3;
return ans;
}),
py::arg("rule1") = EndpointRule(false, 2.4, 0),
py::arg("rule2") = EndpointRule(true, 1.2, 0),
py::arg("rule3") = EndpointRule(false, 0, 20),
kEndpointConfigInitDoc)
.def("__str__",
[](const PyClass &self) -> std::string { return self.ToString(); })
.def_readwrite("rule1", &PyClass::rule1)
.def_readwrite("rule2", &PyClass::rule2)
.def_readwrite("rule3", &PyClass::rule3);
}
void PybindEndpoint(py::module *m) {
PybindEndpointRule(m);
PybindEndpointConfig(m);
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,16 @@
// sherpa-onnx/csrc/endpoint.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_
#define SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindEndpoint(py::module *m);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_

View File

@@ -21,11 +21,15 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
using PyClass = OnlineRecognizerConfig;
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
.def(py::init<const FeatureExtractorConfig &,
const OnlineTransducerModelConfig &, const std::string &>(),
py::arg("feat_config"), py::arg("model_config"), py::arg("tokens"))
const OnlineTransducerModelConfig &, const std::string &,
const EndpointConfig &, bool>(),
py::arg("feat_config"), py::arg("model_config"), py::arg("tokens"),
py::arg("endpoint_config"), py::arg("enable_endpoint"))
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
.def("__str__", &PyClass::ToString);
}
@@ -43,7 +47,9 @@ void PybindOnlineRecognizer(py::module *m) {
[](PyClass &self, std::vector<OnlineStream *> ss) {
self.DecodeStreams(ss.data(), ss.size());
})
.def("get_result", &PyClass::GetResult);
.def("get_result", &PyClass::GetResult)
.def("is_endpoint", &PyClass::IsEndpoint)
.def("reset", &PyClass::Reset);
}
} // namespace sherpa_onnx

View File

@@ -4,6 +4,7 @@
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
#include "sherpa-onnx/python/csrc/endpoint.h"
#include "sherpa-onnx/python/csrc/features.h"
#include "sherpa-onnx/python/csrc/online-recognizer.h"
#include "sherpa-onnx/python/csrc/online-stream.h"
@@ -16,6 +17,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindFeatures(&m);
PybindOnlineTransducerModelConfig(&m);
PybindOnlineStream(&m);
PybindEndpoint(&m);
PybindOnlineRecognizer(&m);
}