Support streaming paraformer (#263)

This commit is contained in:
Fangjun Kuang
2023-08-14 10:32:14 +08:00
committed by GitHub
parent a4bff28e21
commit 6038e2aa62
38 changed files with 1488 additions and 112 deletions

View File

@@ -15,6 +15,7 @@ pybind11_add_module(_sherpa_onnx
offline-whisper-model-config.cc
online-lm-config.cc
online-model-config.cc
online-paraformer-model-config.cc
online-recognizer.cc
online-stream.cc
online-transducer-model-config.cc

View File

@@ -1,6 +1,6 @@
// sherpa-onnx/python/csrc/online-model-config.cc
//
// Copyright (c) 2023 by manyeyes
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/online-model-config.h"
@@ -9,21 +9,26 @@
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
namespace sherpa_onnx {
void PybindOnlineModelConfig(py::module *m) {
PybindOnlineTransducerModelConfig(m);
PybindOnlineParaformerModelConfig(m);
using PyClass = OnlineModelConfig;
py::class_<PyClass>(*m, "OnlineModelConfig")
.def(py::init<const OnlineTransducerModelConfig &, std::string &, int32_t,
.def(py::init<const OnlineTransducerModelConfig &,
const OnlineParaformerModelConfig &, std::string &, int32_t,
bool, const std::string &, const std::string &>(),
py::arg("transducer") = OnlineTransducerModelConfig(),
py::arg("paraformer") = OnlineParaformerModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "")
.def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)

View File

@@ -1,6 +1,6 @@
// sherpa-onnx/python/csrc/online-model-config.h
//
// Copyright (c) 2023 by manyeyes
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_

View File

@@ -0,0 +1,24 @@
// sherpa-onnx/python/csrc/online-paraformer-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
namespace sherpa_onnx {
void PybindOnlineParaformerModelConfig(py::module *m) {
using PyClass = OnlineParaformerModelConfig;
py::class_<PyClass>(*m, "OnlineParaformerModelConfig")
.def(py::init<const std::string &, const std::string &>(),
py::arg("encoder"), py::arg("decoder"))
.def_readwrite("encoder", &PyClass::encoder)
.def_readwrite("decoder", &PyClass::decoder)
.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/online-paraformer-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOnlineParaformerModelConfig(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_

View File

@@ -33,7 +33,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
py::arg("enable_endpoint"), py::arg("decoding_method"),
py::arg("max_active_paths"), py::arg("context_score"))
py::arg("max_active_paths") = 4, py::arg("context_score") = 0)
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("endpoint_config", &PyClass::endpoint_config)

View File

@@ -6,6 +6,7 @@ from _sherpa_onnx import (
EndpointConfig,
FeatureExtractorConfig,
OnlineModelConfig,
OnlineParaformerModelConfig,
OnlineRecognizer as _Recognizer,
OnlineRecognizerConfig,
OnlineStream,
@@ -32,7 +33,7 @@ class OnlineRecognizer(object):
encoder: str,
decoder: str,
joiner: str,
num_threads: int = 4,
num_threads: int = 2,
sample_rate: float = 16000,
feature_dim: int = 80,
enable_endpoint_detection: bool = False,
@@ -144,6 +145,109 @@ class OnlineRecognizer(object):
self.config = recognizer_config
return self
@classmethod
def from_paraformer(
cls,
tokens: str,
encoder: str,
decoder: str,
num_threads: int = 2,
sample_rate: float = 16000,
feature_dim: int = 80,
enable_endpoint_detection: bool = False,
rule1_min_trailing_silence: float = 2.4,
rule2_min_trailing_silence: float = 1.2,
rule3_min_utterance_length: float = 20.0,
decoding_method: str = "greedy_search",
provider: str = "cpu",
):
"""
Please refer to
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
to download pre-trained models for different languages, e.g., Chinese,
English, etc.
Args:
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
encoder:
Path to ``encoder.onnx``.
decoder:
Path to ``decoder.onnx``.
num_threads:
Number of threads for neural network computation.
sample_rate:
Sample rate of the training data used to train the model.
feature_dim:
Dimension of the feature used to train the model.
enable_endpoint_detection:
True to enable endpoint detection. False to disable endpoint
detection.
rule1_min_trailing_silence:
Used only when enable_endpoint_detection is True. If the duration
of trailing silence in seconds is larger than this value, we assume
an endpoint is detected.
rule2_min_trailing_silence:
Used only when enable_endpoint_detection is True. If we have decoded
something that is nonsilence and if the duration of trailing silence
in seconds is larger than this value, we assume an endpoint is
detected.
rule3_min_utterance_length:
Used only when enable_endpoint_detection is True. If the utterance
length in seconds is larger than this value, we assume an endpoint
is detected.
decoding_method:
The only valid value is greedy_search.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
"""
self = cls.__new__(cls)
_assert_file_exists(tokens)
_assert_file_exists(encoder)
_assert_file_exists(decoder)
assert num_threads > 0, num_threads
paraformer_config = OnlineParaformerModelConfig(
encoder=encoder,
decoder=decoder,
)
model_config = OnlineModelConfig(
paraformer=paraformer_config,
tokens=tokens,
num_threads=num_threads,
provider=provider,
model_type="paraformer",
)
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
)
endpoint_config = EndpointConfig(
rule1_min_trailing_silence=rule1_min_trailing_silence,
rule2_min_trailing_silence=rule2_min_trailing_silence,
rule3_min_utterance_length=rule3_min_utterance_length,
)
recognizer_config = OnlineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
endpoint_config=endpoint_config,
enable_endpoint=enable_endpoint_detection,
decoding_method=decoding_method,
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
return self
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
if contexts_list is None:
return self.recognizer.create_stream()