Refactor online recognizer (#250)

* Refactor online recognizer.

Make it easier to support other streaming models.

Note that it is a breaking change for the Python API.
`sherpa_onnx.OnlineRecognizer()` used before should be
replaced by `sherpa_onnx.OnlineRecognizer.from_transducer()`.
This commit is contained in:
Fangjun Kuang
2023-08-09 20:27:31 +08:00
committed by GitHub
parent 6061318e3f
commit 79c2ce5dd4
40 changed files with 670 additions and 480 deletions

View File

@@ -5,6 +5,7 @@ from typing import List, Optional
from _sherpa_onnx import (
EndpointConfig,
FeatureExtractorConfig,
OnlineModelConfig,
OnlineRecognizer as _Recognizer,
OnlineRecognizerConfig,
OnlineStream,
@@ -24,8 +25,9 @@ class OnlineRecognizer(object):
- https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/online-decode-files.py
"""
def __init__(
self,
@classmethod
def from_transducer(
cls,
tokens: str,
encoder: str,
decoder: str,
@@ -95,6 +97,7 @@ class OnlineRecognizer(object):
Online transducer model type. Valid values are: conformer, lstm,
zipformer, zipformer2. All other values lead to loading the model twice.
"""
self = cls.__new__(cls)
_assert_file_exists(tokens)
_assert_file_exists(encoder)
_assert_file_exists(decoder)
@@ -102,10 +105,14 @@ class OnlineRecognizer(object):
assert num_threads > 0, num_threads
model_config = OnlineTransducerModelConfig(
encoder_filename=encoder,
decoder_filename=decoder,
joiner_filename=joiner,
transducer_config = OnlineTransducerModelConfig(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
model_config = OnlineModelConfig(
transducer=transducer_config,
tokens=tokens,
num_threads=num_threads,
provider=provider,
@@ -135,6 +142,7 @@ class OnlineRecognizer(object):
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
return self
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
if contexts_list is None: