Add endpointing (#54)
This commit is contained in:
@@ -4,6 +4,7 @@ pybind11_add_module(_sherpa_onnx
|
||||
features.cc
|
||||
online-transducer-model-config.cc
|
||||
sherpa-onnx.cc
|
||||
endpoint.cc
|
||||
online-stream.cc
|
||||
online-recognizer.cc
|
||||
)
|
||||
|
||||
100
sherpa-onnx/python/csrc/endpoint.cc
Normal file
100
sherpa-onnx/python/csrc/endpoint.cc
Normal file
@@ -0,0 +1,100 @@
|
||||
// sherpa-onnx/csrc/endpoint.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/endpoint.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/endpoint.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static constexpr const char *kEndpointRuleInitDoc = R"doc(
|
||||
Constructor for EndpointRule.
|
||||
|
||||
Args:
|
||||
must_contain_nonsilence:
|
||||
If True, for this endpointing rule to apply there must be nonsilence in the
|
||||
best-path traceback. For decoding, a non-blank token is considered as
|
||||
non-silence.
|
||||
min_trailing_silence:
|
||||
This endpointing rule requires duration of trailing silence (in seconds)
|
||||
to be ``>=`` this value.
|
||||
min_utterance_length:
|
||||
This endpointing rule requires utterance-length (in seconds) to
|
||||
be ``>=`` this value.
|
||||
)doc";
|
||||
|
||||
static constexpr const char *kEndpointConfigInitDoc = R"doc(
|
||||
If any rule in EndpointConfig is activated, it is said that an endpointing
|
||||
is detected.
|
||||
|
||||
Args:
|
||||
rule1:
|
||||
By default, it times out after 2.4 seconds of silence, even if
|
||||
we decoded nothing.
|
||||
rule2:
|
||||
By default, it times out after 1.2 seconds of silence after decoding
|
||||
something.
|
||||
rule3:
|
||||
By default, it times out after the utterance is 20 seconds long, regardless of
|
||||
anything else.
|
||||
)doc";
|
||||
|
||||
static void PybindEndpointRule(py::module *m) {
|
||||
using PyClass = EndpointRule;
|
||||
py::class_<PyClass>(*m, "EndpointRule")
|
||||
.def(py::init<bool, float, float>(), py::arg("must_contain_nonsilence"),
|
||||
py::arg("min_trailing_silence"), py::arg("min_utterance_length"),
|
||||
kEndpointRuleInitDoc)
|
||||
.def("__str__", &PyClass::ToString)
|
||||
.def_readwrite("must_contain_nonsilence",
|
||||
&PyClass::must_contain_nonsilence)
|
||||
.def_readwrite("min_trailing_silence", &PyClass::min_trailing_silence)
|
||||
.def_readwrite("min_utterance_length", &PyClass::min_utterance_length);
|
||||
}
|
||||
|
||||
static void PybindEndpointConfig(py::module *m) {
|
||||
using PyClass = EndpointConfig;
|
||||
py::class_<PyClass>(*m, "EndpointConfig")
|
||||
.def(
|
||||
py::init(
|
||||
[](float rule1_min_trailing_silence,
|
||||
float rule2_min_trailing_silence,
|
||||
float rule3_min_utterance_length) -> std::unique_ptr<PyClass> {
|
||||
EndpointRule rule1(false, rule1_min_trailing_silence, 0);
|
||||
EndpointRule rule2(true, rule2_min_trailing_silence, 0);
|
||||
EndpointRule rule3(false, 0, rule3_min_utterance_length);
|
||||
|
||||
return std::make_unique<EndpointConfig>(rule1, rule2, rule3);
|
||||
}),
|
||||
py::arg("rule1_min_trailing_silence"),
|
||||
py::arg("rule2_min_trailing_silence"),
|
||||
py::arg("rule3_min_utterance_length"))
|
||||
.def(py::init([](const EndpointRule &rule1, const EndpointRule &rule2,
|
||||
const EndpointRule &rule3) -> std::unique_ptr<PyClass> {
|
||||
auto ans = std::make_unique<PyClass>();
|
||||
ans->rule1 = rule1;
|
||||
ans->rule2 = rule2;
|
||||
ans->rule3 = rule3;
|
||||
return ans;
|
||||
}),
|
||||
py::arg("rule1") = EndpointRule(false, 2.4, 0),
|
||||
py::arg("rule2") = EndpointRule(true, 1.2, 0),
|
||||
py::arg("rule3") = EndpointRule(false, 0, 20),
|
||||
kEndpointConfigInitDoc)
|
||||
.def("__str__",
|
||||
[](const PyClass &self) -> std::string { return self.ToString(); })
|
||||
.def_readwrite("rule1", &PyClass::rule1)
|
||||
.def_readwrite("rule2", &PyClass::rule2)
|
||||
.def_readwrite("rule3", &PyClass::rule3);
|
||||
}
|
||||
|
||||
void PybindEndpoint(py::module *m) {
|
||||
PybindEndpointRule(m);
|
||||
PybindEndpointConfig(m);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/endpoint.h
Normal file
16
sherpa-onnx/python/csrc/endpoint.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/csrc/endpoint.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindEndpoint(py::module *m);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_
|
||||
@@ -21,11 +21,15 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
||||
using PyClass = OnlineRecognizerConfig;
|
||||
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
|
||||
.def(py::init<const FeatureExtractorConfig &,
|
||||
const OnlineTransducerModelConfig &, const std::string &>(),
|
||||
py::arg("feat_config"), py::arg("model_config"), py::arg("tokens"))
|
||||
const OnlineTransducerModelConfig &, const std::string &,
|
||||
const EndpointConfig &, bool>(),
|
||||
py::arg("feat_config"), py::arg("model_config"), py::arg("tokens"),
|
||||
py::arg("endpoint_config"), py::arg("enable_endpoint"))
|
||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||
.def_readwrite("model_config", &PyClass::model_config)
|
||||
.def_readwrite("tokens", &PyClass::tokens)
|
||||
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
|
||||
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
@@ -43,7 +47,9 @@ void PybindOnlineRecognizer(py::module *m) {
|
||||
[](PyClass &self, std::vector<OnlineStream *> ss) {
|
||||
self.DecodeStreams(ss.data(), ss.size());
|
||||
})
|
||||
.def("get_result", &PyClass::GetResult);
|
||||
.def("get_result", &PyClass::GetResult)
|
||||
.def("is_endpoint", &PyClass::IsEndpoint)
|
||||
.def("reset", &PyClass::Reset);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
#include "sherpa-onnx/python/csrc/endpoint.h"
|
||||
#include "sherpa-onnx/python/csrc/features.h"
|
||||
#include "sherpa-onnx/python/csrc/online-recognizer.h"
|
||||
#include "sherpa-onnx/python/csrc/online-stream.h"
|
||||
@@ -16,6 +17,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
||||
PybindFeatures(&m);
|
||||
PybindOnlineTransducerModelConfig(&m);
|
||||
PybindOnlineStream(&m);
|
||||
PybindEndpoint(&m);
|
||||
PybindOnlineRecognizer(&m);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from _sherpa_onnx import (
|
||||
EndpointConfig,
|
||||
FeatureExtractorConfig,
|
||||
OnlineRecognizerConfig,
|
||||
OnlineStream,
|
||||
|
||||
@@ -2,12 +2,13 @@ from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from _sherpa_onnx import (
|
||||
EndpointConfig,
|
||||
FeatureExtractorConfig,
|
||||
OnlineRecognizer as _Recognizer,
|
||||
OnlineRecognizerConfig,
|
||||
OnlineStream,
|
||||
OnlineTransducerModelConfig,
|
||||
FeatureExtractorConfig,
|
||||
OnlineRecognizerConfig,
|
||||
)
|
||||
from _sherpa_onnx import OnlineRecognizer as _Recognizer
|
||||
|
||||
|
||||
def _assert_file_exists(f: str):
|
||||
@@ -26,6 +27,10 @@ class OnlineRecognizer(object):
|
||||
num_threads: int = 4,
|
||||
sample_rate: float = 16000,
|
||||
feature_dim: int = 80,
|
||||
enable_endpoint_detection: bool = False,
|
||||
rule1_min_trailing_silence: int = 2.4,
|
||||
rule2_min_trailing_silence: int = 1.2,
|
||||
rule3_min_utterance_length: int = 20,
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -52,6 +57,22 @@ class OnlineRecognizer(object):
|
||||
Sample rate of the training data used to train the model.
|
||||
feature_dim:
|
||||
Dimension of the feature used to train the model.
|
||||
enable_endpoint_detection:
|
||||
True to enable endpoint detection. False to disable endpoint
|
||||
detection.
|
||||
rule1_min_trailing_silence:
|
||||
Used only when enable_endpoint_detection is True. If the duration
|
||||
of trailing silence in seconds is larger than this value, we assume
|
||||
an endpoint is detected.
|
||||
rule2_min_trailing_silence:
|
||||
Used only when enable_endpoint_detection is True. If we have decoded
|
||||
something that is nonsilence and if the duration of trailing silence
|
||||
in seconds is larger than this value, we assume an endpoint is
|
||||
detected.
|
||||
rule3_min_utterance_length:
|
||||
Used only when enable_endpoint_detection is True. If the utterance
|
||||
length in seconds is larger than this value, we assume an endpoint
|
||||
is detected.
|
||||
"""
|
||||
_assert_file_exists(tokens)
|
||||
_assert_file_exists(encoder)
|
||||
@@ -72,10 +93,18 @@ class OnlineRecognizer(object):
|
||||
feature_dim=feature_dim,
|
||||
)
|
||||
|
||||
endpoint_config = EndpointConfig(
|
||||
rule1_min_trailing_silence=rule1_min_trailing_silence,
|
||||
rule2_min_trailing_silence=rule2_min_trailing_silence,
|
||||
rule3_min_utterance_length=rule3_min_utterance_length,
|
||||
)
|
||||
|
||||
recognizer_config = OnlineRecognizerConfig(
|
||||
feat_config=feat_config,
|
||||
model_config=model_config,
|
||||
tokens=tokens,
|
||||
endpoint_config=endpoint_config,
|
||||
enable_endpoint=enable_endpoint_detection,
|
||||
)
|
||||
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
@@ -93,4 +122,10 @@ class OnlineRecognizer(object):
|
||||
return self.recognizer.is_ready(s)
|
||||
|
||||
def get_result(self, s: OnlineStream) -> str:
|
||||
return self.recognizer.get_result(s).text
|
||||
return self.recognizer.get_result(s).text.strip()
|
||||
|
||||
def is_endpoint(self, s: OnlineStream) -> bool:
|
||||
return self.recognizer.is_endpoint(s)
|
||||
|
||||
def reset(self, s: OnlineStream) -> bool:
|
||||
return self.recognizer.reset(s)
|
||||
|
||||
Reference in New Issue
Block a user