Add endpointing (#54)

This commit is contained in:
Fangjun Kuang
2023-02-22 15:35:55 +08:00
committed by GitHub
parent 1c6f79f096
commit 124384369a
23 changed files with 2190 additions and 21 deletions

View File

@@ -4,6 +4,7 @@ pybind11_add_module(_sherpa_onnx
features.cc
online-transducer-model-config.cc
sherpa-onnx.cc
endpoint.cc
online-stream.cc
online-recognizer.cc
)

View File

@@ -0,0 +1,100 @@
// sherpa-onnx/csrc/endpoint.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/endpoint.h"
#include <memory>
#include <string>
#include "sherpa-onnx/csrc/endpoint.h"
namespace sherpa_onnx {
static constexpr const char *kEndpointRuleInitDoc = R"doc(
Constructor for EndpointRule.
Args:
must_contain_nonsilence:
If True, for this endpointing rule to apply there must be nonsilence in the
best-path traceback. For decoding, a non-blank token is considered as
non-silence.
min_trailing_silence:
This endpointing rule requires duration of trailing silence (in seconds)
to be ``>=`` this value.
min_utterance_length:
This endpointing rule requires utterance-length (in seconds) to
be ``>=`` this value.
)doc";
static constexpr const char *kEndpointConfigInitDoc = R"doc(
If any rule in EndpointConfig is activated, it is said that an endpointing
is detected.
Args:
rule1:
By default, it times out after 2.4 seconds of silence, even if
we decoded nothing.
rule2:
By default, it times out after 1.2 seconds of silence after decoding
something.
rule3:
By default, it times out after the utterance is 20 seconds long, regardless of
anything else.
)doc";
static void PybindEndpointRule(py::module *m) {
using PyClass = EndpointRule;
py::class_<PyClass>(*m, "EndpointRule")
.def(py::init<bool, float, float>(), py::arg("must_contain_nonsilence"),
py::arg("min_trailing_silence"), py::arg("min_utterance_length"),
kEndpointRuleInitDoc)
.def("__str__", &PyClass::ToString)
.def_readwrite("must_contain_nonsilence",
&PyClass::must_contain_nonsilence)
.def_readwrite("min_trailing_silence", &PyClass::min_trailing_silence)
.def_readwrite("min_utterance_length", &PyClass::min_utterance_length);
}
static void PybindEndpointConfig(py::module *m) {
using PyClass = EndpointConfig;
py::class_<PyClass>(*m, "EndpointConfig")
.def(
py::init(
[](float rule1_min_trailing_silence,
float rule2_min_trailing_silence,
float rule3_min_utterance_length) -> std::unique_ptr<PyClass> {
EndpointRule rule1(false, rule1_min_trailing_silence, 0);
EndpointRule rule2(true, rule2_min_trailing_silence, 0);
EndpointRule rule3(false, 0, rule3_min_utterance_length);
return std::make_unique<EndpointConfig>(rule1, rule2, rule3);
}),
py::arg("rule1_min_trailing_silence"),
py::arg("rule2_min_trailing_silence"),
py::arg("rule3_min_utterance_length"))
.def(py::init([](const EndpointRule &rule1, const EndpointRule &rule2,
const EndpointRule &rule3) -> std::unique_ptr<PyClass> {
auto ans = std::make_unique<PyClass>();
ans->rule1 = rule1;
ans->rule2 = rule2;
ans->rule3 = rule3;
return ans;
}),
py::arg("rule1") = EndpointRule(false, 2.4, 0),
py::arg("rule2") = EndpointRule(true, 1.2, 0),
py::arg("rule3") = EndpointRule(false, 0, 20),
kEndpointConfigInitDoc)
.def("__str__",
[](const PyClass &self) -> std::string { return self.ToString(); })
.def_readwrite("rule1", &PyClass::rule1)
.def_readwrite("rule2", &PyClass::rule2)
.def_readwrite("rule3", &PyClass::rule3);
}
void PybindEndpoint(py::module *m) {
PybindEndpointRule(m);
PybindEndpointConfig(m);
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,16 @@
// sherpa-onnx/csrc/endpoint.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_
#define SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindEndpoint(py::module *m);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_

View File

@@ -21,11 +21,15 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
using PyClass = OnlineRecognizerConfig;
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
.def(py::init<const FeatureExtractorConfig &,
const OnlineTransducerModelConfig &, const std::string &>(),
py::arg("feat_config"), py::arg("model_config"), py::arg("tokens"))
const OnlineTransducerModelConfig &, const std::string &,
const EndpointConfig &, bool>(),
py::arg("feat_config"), py::arg("model_config"), py::arg("tokens"),
py::arg("endpoint_config"), py::arg("enable_endpoint"))
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
.def("__str__", &PyClass::ToString);
}
@@ -43,7 +47,9 @@ void PybindOnlineRecognizer(py::module *m) {
[](PyClass &self, std::vector<OnlineStream *> ss) {
self.DecodeStreams(ss.data(), ss.size());
})
.def("get_result", &PyClass::GetResult);
.def("get_result", &PyClass::GetResult)
.def("is_endpoint", &PyClass::IsEndpoint)
.def("reset", &PyClass::Reset);
}
} // namespace sherpa_onnx

View File

@@ -4,6 +4,7 @@
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
#include "sherpa-onnx/python/csrc/endpoint.h"
#include "sherpa-onnx/python/csrc/features.h"
#include "sherpa-onnx/python/csrc/online-recognizer.h"
#include "sherpa-onnx/python/csrc/online-stream.h"
@@ -16,6 +17,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindFeatures(&m);
PybindOnlineTransducerModelConfig(&m);
PybindOnlineStream(&m);
PybindEndpoint(&m);
PybindOnlineRecognizer(&m);
}

View File

@@ -1,4 +1,5 @@
from _sherpa_onnx import (
EndpointConfig,
FeatureExtractorConfig,
OnlineRecognizerConfig,
OnlineStream,

View File

@@ -2,12 +2,13 @@ from pathlib import Path
from typing import List
from _sherpa_onnx import (
EndpointConfig,
FeatureExtractorConfig,
OnlineRecognizer as _Recognizer,
OnlineRecognizerConfig,
OnlineStream,
OnlineTransducerModelConfig,
FeatureExtractorConfig,
OnlineRecognizerConfig,
)
from _sherpa_onnx import OnlineRecognizer as _Recognizer
def _assert_file_exists(f: str):
@@ -26,6 +27,10 @@ class OnlineRecognizer(object):
num_threads: int = 4,
sample_rate: float = 16000,
feature_dim: int = 80,
enable_endpoint_detection: bool = False,
rule1_min_trailing_silence: int = 2.4,
rule2_min_trailing_silence: int = 1.2,
rule3_min_utterance_length: int = 20,
):
"""
Please refer to
@@ -52,6 +57,22 @@ class OnlineRecognizer(object):
Sample rate of the training data used to train the model.
feature_dim:
Dimension of the feature used to train the model.
enable_endpoint_detection:
True to enable endpoint detection. False to disable endpoint
detection.
rule1_min_trailing_silence:
Used only when enable_endpoint_detection is True. If the duration
of trailing silence in seconds is larger than this value, we assume
an endpoint is detected.
rule2_min_trailing_silence:
Used only when enable_endpoint_detection is True. If we have decoded
something that is nonsilence and if the duration of trailing silence
in seconds is larger than this value, we assume an endpoint is
detected.
rule3_min_utterance_length:
Used only when enable_endpoint_detection is True. If the utterance
length in seconds is larger than this value, we assume an endpoint
is detected.
"""
_assert_file_exists(tokens)
_assert_file_exists(encoder)
@@ -72,10 +93,18 @@ class OnlineRecognizer(object):
feature_dim=feature_dim,
)
endpoint_config = EndpointConfig(
rule1_min_trailing_silence=rule1_min_trailing_silence,
rule2_min_trailing_silence=rule2_min_trailing_silence,
rule3_min_utterance_length=rule3_min_utterance_length,
)
recognizer_config = OnlineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
tokens=tokens,
endpoint_config=endpoint_config,
enable_endpoint=enable_endpoint_detection,
)
self.recognizer = _Recognizer(recognizer_config)
@@ -93,4 +122,10 @@ class OnlineRecognizer(object):
return self.recognizer.is_ready(s)
def get_result(self, s: OnlineStream) -> str:
return self.recognizer.get_result(s).text
return self.recognizer.get_result(s).text.strip()
def is_endpoint(self, s: OnlineStream) -> bool:
return self.recognizer.is_endpoint(s)
def reset(self, s: OnlineStream) -> bool:
return self.recognizer.reset(s)