# Copyright (c) 2023 Xiaomi Corporation

from pathlib import Path
from typing import List, Optional

from _sherpa_onnx import (
    EndpointConfig,
    FeatureExtractorConfig,
    OnlineLMConfig,
    OnlineModelConfig,
    OnlineParaformerModelConfig,
)
from _sherpa_onnx import OnlineRecognizer as _Recognizer
from _sherpa_onnx import (
    OnlineRecognizerConfig,
    OnlineStream,
    OnlineTransducerModelConfig,
    OnlineWenetCtcModelConfig,
    OnlineZipformer2CtcModelConfig,
)


def _assert_file_exists(f: str):
    """Fail with an informative AssertionError unless ``f`` is an existing file."""
    assert Path(f).is_file(), f"{f} does not exist"


class OnlineRecognizer(object):
    """A class for streaming speech recognition.

    Please refer to the following files for usages
     - https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/python/tests/test_online_recognizer.py
     - https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/online-decode-files.py
    """

    @classmethod
    def from_transducer(
        cls,
        tokens: str,
        encoder: str,
        decoder: str,
        joiner: str,
        num_threads: int = 2,
        sample_rate: float = 16000,
        feature_dim: int = 80,
        enable_endpoint_detection: bool = False,
        rule1_min_trailing_silence: float = 2.4,
        rule2_min_trailing_silence: float = 1.2,
        rule3_min_utterance_length: float = 20.0,
        decoding_method: str = "greedy_search",
        max_active_paths: int = 4,
        hotwords_score: float = 1.5,
        blank_penalty: float = 0.0,
        hotwords_file: str = "",
        provider: str = "cpu",
        model_type: str = "",
        lm: str = "",
        lm_scale: float = 0.1,
    ):
        """Build an ``OnlineRecognizer`` from a streaming transducer model.

        Please refer to the sherpa-onnx documentation to download pre-trained
        models for different languages, e.g., Chinese, English, etc.

        Args:
          tokens:
            Path to ``tokens.txt``. Each line contains two columns::

                symbol integer_id
          encoder:
            Path to ``encoder.onnx``.
          decoder:
            Path to ``decoder.onnx``.
          joiner:
            Path to ``joiner.onnx``.
          num_threads:
            Number of threads for neural network computation.
          sample_rate:
            Sample rate of the training data used to train the model.
          feature_dim:
            Dimension of the feature used to train the model.
          enable_endpoint_detection:
            True to enable endpoint detection. False to disable it.
          rule1_min_trailing_silence:
            Used only when endpoint detection is enabled. An endpoint is
            assumed when the trailing silence (in seconds) exceeds this value.
          rule2_min_trailing_silence:
            Used only when endpoint detection is enabled. After some
            non-silence has been decoded, an endpoint is assumed when the
            trailing silence (in seconds) exceeds this value.
          rule3_min_utterance_length:
            Used only when endpoint detection is enabled. An endpoint is
            assumed when the utterance length (in seconds) exceeds this value.
          decoding_method:
            Valid values are greedy_search, modified_beam_search.
          max_active_paths:
            Used only when decoding_method is modified_beam_search. It is the
            maximum number of active paths kept during beam search.
          hotwords_score:
            The hotword score of each token for biasing word/phrase. Used
            only if hotwords_file is given with modified_beam_search as the
            decoding method.
          blank_penalty:
            The penalty applied on the blank symbol during decoding.
          hotwords_file:
            The file containing hotwords, one word/phrase per line; for each
            phrase, the bpe/cjkchar tokens are separated by spaces.
          provider:
            onnxruntime execution providers. Valid values are: cpu, cuda,
            coreml.
          model_type:
            Online transducer model type. Valid values are: conformer, lstm,
            zipformer, zipformer2. All other values lead to loading the model
            twice.
          lm:
            Path to an optional language model for rescoring; requires
            modified_beam_search.
          lm_scale:
            Scale applied to the language model scores.
        """
        self = cls.__new__(cls)

        for path in (tokens, encoder, decoder, joiner):
            _assert_file_exists(path)
        assert num_threads > 0, num_threads

        trans_cfg = OnlineTransducerModelConfig(
            encoder=encoder,
            decoder=decoder,
            joiner=joiner,
        )
        model_cfg = OnlineModelConfig(
            transducer=trans_cfg,
            tokens=tokens,
            num_threads=num_threads,
            provider=provider,
            model_type=model_type,
        )
        feat_cfg = FeatureExtractorConfig(
            sampling_rate=sample_rate,
            feature_dim=feature_dim,
        )
        endpoint_cfg = EndpointConfig(
            rule1_min_trailing_silence=rule1_min_trailing_silence,
            rule2_min_trailing_silence=rule2_min_trailing_silence,
            rule3_min_utterance_length=rule3_min_utterance_length,
        )

        # Both hotword biasing and LM rescoring are only meaningful with
        # modified_beam_search; reject other decoding methods up front.
        if len(hotwords_file) > 0 and decoding_method != "modified_beam_search":
            raise ValueError(
                "Please use --decoding-method=modified_beam_search when using "
                f"--hotwords-file. Currently given: {decoding_method}"
            )

        if lm and decoding_method != "modified_beam_search":
            raise ValueError(
                "Please use --decoding-method=modified_beam_search when using "
                f"--lm. Currently given: {decoding_method}"
            )

        lm_cfg = OnlineLMConfig(
            model=lm,
            scale=lm_scale,
        )

        recognizer_cfg = OnlineRecognizerConfig(
            feat_config=feat_cfg,
            model_config=model_cfg,
            lm_config=lm_cfg,
            endpoint_config=endpoint_cfg,
            enable_endpoint=enable_endpoint_detection,
            decoding_method=decoding_method,
            max_active_paths=max_active_paths,
            hotwords_score=hotwords_score,
            hotwords_file=hotwords_file,
            blank_penalty=blank_penalty,
        )

        self.recognizer = _Recognizer(recognizer_cfg)
        self.config = recognizer_cfg
        return self

    @classmethod
    def from_paraformer(
        cls,
        tokens: str,
        encoder: str,
        decoder: str,
        num_threads: int = 2,
        sample_rate: float = 16000,
        feature_dim: int = 80,
        enable_endpoint_detection: bool = False,
        rule1_min_trailing_silence: float = 2.4,
        rule2_min_trailing_silence: float = 1.2,
        rule3_min_utterance_length: float = 20.0,
        decoding_method: str = "greedy_search",
        provider: str = "cpu",
    ):
        """Build an ``OnlineRecognizer`` from a streaming paraformer model.

        Please refer to the sherpa-onnx documentation to download pre-trained
        models for different languages, e.g., Chinese, English, etc.

        Args:
          tokens:
            Path to ``tokens.txt``. Each line contains two columns::

                symbol integer_id
          encoder:
            Path to ``encoder.onnx``.
          decoder:
            Path to ``decoder.onnx``.
          num_threads:
            Number of threads for neural network computation.
          sample_rate:
            Sample rate of the training data used to train the model.
          feature_dim:
            Dimension of the feature used to train the model.
          enable_endpoint_detection:
            True to enable endpoint detection. False to disable it.
          rule1_min_trailing_silence:
            Used only when endpoint detection is enabled. An endpoint is
            assumed when the trailing silence (in seconds) exceeds this value.
          rule2_min_trailing_silence:
            Used only when endpoint detection is enabled. After some
            non-silence has been decoded, an endpoint is assumed when the
            trailing silence (in seconds) exceeds this value.
          rule3_min_utterance_length:
            Used only when endpoint detection is enabled. An endpoint is
            assumed when the utterance length (in seconds) exceeds this value.
          decoding_method:
            The only valid value is greedy_search.
          provider:
            onnxruntime execution providers. Valid values are: cpu, cuda,
            coreml.
        """
        self = cls.__new__(cls)

        for path in (tokens, encoder, decoder):
            _assert_file_exists(path)
        assert num_threads > 0, num_threads

        paraformer_cfg = OnlineParaformerModelConfig(
            encoder=encoder,
            decoder=decoder,
        )
        model_cfg = OnlineModelConfig(
            paraformer=paraformer_cfg,
            tokens=tokens,
            num_threads=num_threads,
            provider=provider,
            model_type="paraformer",
        )
        feat_cfg = FeatureExtractorConfig(
            sampling_rate=sample_rate,
            feature_dim=feature_dim,
        )
        endpoint_cfg = EndpointConfig(
            rule1_min_trailing_silence=rule1_min_trailing_silence,
            rule2_min_trailing_silence=rule2_min_trailing_silence,
            rule3_min_utterance_length=rule3_min_utterance_length,
        )
        recognizer_cfg = OnlineRecognizerConfig(
            feat_config=feat_cfg,
            model_config=model_cfg,
            endpoint_config=endpoint_cfg,
            enable_endpoint=enable_endpoint_detection,
            decoding_method=decoding_method,
        )

        self.recognizer = _Recognizer(recognizer_cfg)
        self.config = recognizer_cfg
        return self

    @classmethod
    def from_zipformer2_ctc(
        cls,
        tokens: str,
        model: str,
        num_threads: int = 2,
        sample_rate: float = 16000,
        feature_dim: int = 80,
        enable_endpoint_detection: bool = False,
        rule1_min_trailing_silence: float = 2.4,
        rule2_min_trailing_silence: float = 1.2,
        rule3_min_utterance_length: float = 20.0,
        decoding_method: str = "greedy_search",
        provider: str = "cpu",
    ):
        """Build an ``OnlineRecognizer`` from a streaming zipformer2 CTC model.

        Please refer to the sherpa-onnx documentation to download pre-trained
        models for different languages, e.g., Chinese, English, etc.

        Args:
          tokens:
            Path to ``tokens.txt``. Each line contains two columns::

                symbol integer_id
          model:
            Path to ``model.onnx``.
          num_threads:
            Number of threads for neural network computation.
          sample_rate:
            Sample rate of the training data used to train the model.
          feature_dim:
            Dimension of the feature used to train the model.
          enable_endpoint_detection:
            True to enable endpoint detection. False to disable it.
          rule1_min_trailing_silence:
            Used only when endpoint detection is enabled. An endpoint is
            assumed when the trailing silence (in seconds) exceeds this value.
          rule2_min_trailing_silence:
            Used only when endpoint detection is enabled. After some
            non-silence has been decoded, an endpoint is assumed when the
            trailing silence (in seconds) exceeds this value.
          rule3_min_utterance_length:
            Used only when endpoint detection is enabled. An endpoint is
            assumed when the utterance length (in seconds) exceeds this value.
          decoding_method:
            The only valid value is greedy_search.
          provider:
            onnxruntime execution providers. Valid values are: cpu, cuda,
            coreml.
        """
        self = cls.__new__(cls)

        for path in (tokens, model):
            _assert_file_exists(path)
        assert num_threads > 0, num_threads

        ctc_cfg = OnlineZipformer2CtcModelConfig(model=model)
        model_cfg = OnlineModelConfig(
            zipformer2_ctc=ctc_cfg,
            tokens=tokens,
            num_threads=num_threads,
            provider=provider,
        )
        feat_cfg = FeatureExtractorConfig(
            sampling_rate=sample_rate,
            feature_dim=feature_dim,
        )
        endpoint_cfg = EndpointConfig(
            rule1_min_trailing_silence=rule1_min_trailing_silence,
            rule2_min_trailing_silence=rule2_min_trailing_silence,
            rule3_min_utterance_length=rule3_min_utterance_length,
        )
        recognizer_cfg = OnlineRecognizerConfig(
            feat_config=feat_cfg,
            model_config=model_cfg,
            endpoint_config=endpoint_cfg,
            enable_endpoint=enable_endpoint_detection,
            decoding_method=decoding_method,
        )

        self.recognizer = _Recognizer(recognizer_cfg)
        self.config = recognizer_cfg
        return self

    @classmethod
    def from_wenet_ctc(
        cls,
        tokens: str,
        model: str,
        chunk_size: int = 16,
        num_left_chunks: int = 4,
        num_threads: int = 2,
        sample_rate: float = 16000,
        feature_dim: int = 80,
        enable_endpoint_detection: bool = False,
        rule1_min_trailing_silence: float = 2.4,
        rule2_min_trailing_silence: float = 1.2,
        rule3_min_utterance_length: float = 20.0,
        decoding_method: str = "greedy_search",
        provider: str = "cpu",
    ):
        """Build an ``OnlineRecognizer`` from a streaming WeNet CTC model.

        Please refer to the sherpa-onnx documentation to download pre-trained
        models for different languages, e.g., Chinese, English, etc.

        Args:
          tokens:
            Path to ``tokens.txt``. Each line contains two columns::

                symbol integer_id
          model:
            Path to ``model.onnx``.
          chunk_size:
            The --chunk-size parameter from WeNet.
          num_left_chunks:
            The --num-left-chunks parameter from WeNet.
          num_threads:
            Number of threads for neural network computation.
          sample_rate:
            Sample rate of the training data used to train the model.
          feature_dim:
            Dimension of the feature used to train the model.
          enable_endpoint_detection:
            True to enable endpoint detection. False to disable it.
          rule1_min_trailing_silence:
            Used only when endpoint detection is enabled. An endpoint is
            assumed when the trailing silence (in seconds) exceeds this value.
          rule2_min_trailing_silence:
            Used only when endpoint detection is enabled. After some
            non-silence has been decoded, an endpoint is assumed when the
            trailing silence (in seconds) exceeds this value.
          rule3_min_utterance_length:
            Used only when endpoint detection is enabled. An endpoint is
            assumed when the utterance length (in seconds) exceeds this value.
          decoding_method:
            The only valid value is greedy_search.
          provider:
            onnxruntime execution providers. Valid values are: cpu, cuda,
            coreml.
        """
        self = cls.__new__(cls)

        for path in (tokens, model):
            _assert_file_exists(path)
        assert num_threads > 0, num_threads

        wenet_cfg = OnlineWenetCtcModelConfig(
            model=model,
            chunk_size=chunk_size,
            num_left_chunks=num_left_chunks,
        )
        model_cfg = OnlineModelConfig(
            wenet_ctc=wenet_cfg,
            tokens=tokens,
            num_threads=num_threads,
            provider=provider,
        )
        feat_cfg = FeatureExtractorConfig(
            sampling_rate=sample_rate,
            feature_dim=feature_dim,
        )
        endpoint_cfg = EndpointConfig(
            rule1_min_trailing_silence=rule1_min_trailing_silence,
            rule2_min_trailing_silence=rule2_min_trailing_silence,
            rule3_min_utterance_length=rule3_min_utterance_length,
        )
        recognizer_cfg = OnlineRecognizerConfig(
            feat_config=feat_cfg,
            model_config=model_cfg,
            endpoint_config=endpoint_cfg,
            enable_endpoint=enable_endpoint_detection,
            decoding_method=decoding_method,
        )

        self.recognizer = _Recognizer(recognizer_cfg)
        self.config = recognizer_cfg
        return self

    def create_stream(self, hotwords: Optional[str] = None):
        """Create a new ``OnlineStream``, optionally with per-stream hotwords."""
        if hotwords is not None:
            return self.recognizer.create_stream(hotwords)
        return self.recognizer.create_stream()

    def decode_stream(self, s: OnlineStream):
        """Decode the pending features of a single stream."""
        self.recognizer.decode_stream(s)

    def decode_streams(self, ss: List[OnlineStream]):
        """Decode the pending features of a batch of streams."""
        self.recognizer.decode_streams(ss)

    def is_ready(self, s: OnlineStream) -> bool:
        """Return True if the stream is ready for decoding."""
        return self.recognizer.is_ready(s)

    def get_result(self, s: OnlineStream) -> str:
        """Return the current recognition text with surrounding whitespace stripped."""
        return self.recognizer.get_result(s).text.strip()

    def get_result_as_json_string(self, s: OnlineStream) -> str:
        """Return the current recognition result as a JSON string."""
        return self.recognizer.get_result(s).as_json_string()

    def tokens(self, s: OnlineStream) -> List[str]:
        """Return the tokens of the current recognition result."""
        return self.recognizer.get_result(s).tokens

    def timestamps(self, s: OnlineStream) -> List[float]:
        """Return the timestamps of the current recognition result."""
        return self.recognizer.get_result(s).timestamps

    def start_time(self, s: OnlineStream) -> float:
        """Return the start time of the current recognition result."""
        return self.recognizer.get_result(s).start_time

    def ys_probs(self, s: OnlineStream) -> List[float]:
        """Return the ys_probs of the current recognition result."""
        return self.recognizer.get_result(s).ys_probs

    def lm_probs(self, s: OnlineStream) -> List[float]:
        """Return the lm_probs of the current recognition result."""
        return self.recognizer.get_result(s).lm_probs

    def context_scores(self, s: OnlineStream) -> List[float]:
        """Return the context (hotword) scores of the current recognition result."""
        return self.recognizer.get_result(s).context_scores

    def is_endpoint(self, s: OnlineStream) -> bool:
        """Return True if an endpoint has been detected on the stream."""
        return self.recognizer.is_endpoint(s)

    def reset(self, s: OnlineStream) -> bool:
        """Reset the stream's recognition state."""
        return self.recognizer.reset(s)