Support replacing homophonic phrases (#2153)

This commit is contained in:
Fangjun Kuang
2025-04-27 15:31:11 +08:00
committed by GitHub
parent e3280027f9
commit f64c58342b
42 changed files with 834 additions and 134 deletions

View File

@@ -5,6 +5,7 @@ from typing import List, Optional
from _sherpa_onnx import (
FeatureExtractorConfig,
HomophoneReplacerConfig,
OfflineCtcFstDecoderConfig,
OfflineDolphinModelConfig,
OfflineFireRedAsrModelConfig,
@@ -64,6 +65,9 @@ class OfflineRecognizer(object):
rule_fars: str = "",
lm: str = "",
lm_scale: float = 0.1,
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
@@ -181,6 +185,11 @@ class OfflineRecognizer(object):
blank_penalty=blank_penalty,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
@@ -201,6 +210,9 @@ class OfflineRecognizer(object):
use_itn: bool = False,
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
@@ -263,6 +275,11 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
@@ -281,6 +298,9 @@ class OfflineRecognizer(object):
provider: str = "cpu",
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
@@ -336,6 +356,11 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
@@ -354,6 +379,9 @@ class OfflineRecognizer(object):
provider: str = "cpu",
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
@@ -411,6 +439,9 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir, lexicon=hr_lexicon, rule_fsts=hr_rule_fsts
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
@@ -429,6 +460,9 @@ class OfflineRecognizer(object):
provider: str = "cpu",
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
@@ -483,6 +517,11 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
@@ -501,6 +540,9 @@ class OfflineRecognizer(object):
provider: str = "cpu",
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
@@ -557,6 +599,11 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
@@ -577,6 +624,9 @@ class OfflineRecognizer(object):
tail_paddings: int = -1,
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
@@ -647,6 +697,11 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
@@ -664,6 +719,9 @@ class OfflineRecognizer(object):
provider: str = "cpu",
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
@@ -719,6 +777,11 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
@@ -738,6 +801,9 @@ class OfflineRecognizer(object):
provider: str = "cpu",
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
@@ -800,6 +866,11 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
@@ -818,6 +889,9 @@ class OfflineRecognizer(object):
provider: str = "cpu",
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
@@ -873,6 +947,11 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
@@ -891,6 +970,9 @@ class OfflineRecognizer(object):
provider: str = "cpu",
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
@@ -947,6 +1029,11 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config

View File

@@ -3,25 +3,26 @@ from pathlib import Path
from typing import List, Optional
from _sherpa_onnx import (
CudaConfig,
EndpointConfig,
FeatureExtractorConfig,
HomophoneReplacerConfig,
OnlineCtcFstDecoderConfig,
OnlineLMConfig,
OnlineModelConfig,
OnlineNeMoCtcModelConfig,
OnlineParaformerModelConfig,
)
from _sherpa_onnx import OnlineRecognizer as _Recognizer
from _sherpa_onnx import (
CudaConfig,
TensorrtConfig,
ProviderConfig,
OnlineRecognizerConfig,
OnlineRecognizerResult,
OnlineStream,
OnlineTransducerModelConfig,
OnlineWenetCtcModelConfig,
OnlineNeMoCtcModelConfig,
OnlineZipformer2CtcModelConfig,
OnlineCtcFstDecoderConfig,
ProviderConfig,
TensorrtConfig,
)
@@ -82,9 +83,12 @@ class OnlineRecognizer(object):
trt_detailed_build_log: bool = False,
trt_engine_cache_enable: bool = True,
trt_timing_cache_enable: bool = True,
trt_engine_cache_path: str ="",
trt_timing_cache_path: str ="",
trt_engine_cache_path: str = "",
trt_timing_cache_path: str = "",
trt_dump_subgraphs: bool = False,
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
@@ -228,27 +232,27 @@ class OnlineRecognizer(object):
)
cuda_config = CudaConfig(
cudnn_conv_algo_search=cudnn_conv_algo_search,
cudnn_conv_algo_search=cudnn_conv_algo_search,
)
trt_config = TensorrtConfig(
trt_max_workspace_size=trt_max_workspace_size,
trt_max_partition_iterations=trt_max_partition_iterations,
trt_min_subgraph_size=trt_min_subgraph_size,
trt_fp16_enable=trt_fp16_enable,
trt_detailed_build_log=trt_detailed_build_log,
trt_engine_cache_enable=trt_engine_cache_enable,
trt_timing_cache_enable=trt_timing_cache_enable,
trt_engine_cache_path=trt_engine_cache_path,
trt_timing_cache_path=trt_timing_cache_path,
trt_dump_subgraphs=trt_dump_subgraphs,
trt_max_workspace_size=trt_max_workspace_size,
trt_max_partition_iterations=trt_max_partition_iterations,
trt_min_subgraph_size=trt_min_subgraph_size,
trt_fp16_enable=trt_fp16_enable,
trt_detailed_build_log=trt_detailed_build_log,
trt_engine_cache_enable=trt_engine_cache_enable,
trt_timing_cache_enable=trt_timing_cache_enable,
trt_engine_cache_path=trt_engine_cache_path,
trt_timing_cache_path=trt_timing_cache_path,
trt_dump_subgraphs=trt_dump_subgraphs,
)
provider_config = ProviderConfig(
trt_config=trt_config,
cuda_config=cuda_config,
provider=provider,
device=device,
trt_config=trt_config,
cuda_config=cuda_config,
provider=provider,
device=device,
)
model_config = OnlineModelConfig(
@@ -311,6 +315,11 @@ class OnlineRecognizer(object):
rule_fsts=rule_fsts,
rule_fars=rule_fars,
reset_encoder=reset_encoder,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
@@ -336,6 +345,9 @@ class OnlineRecognizer(object):
rule_fsts: str = "",
rule_fars: str = "",
device: int = 0,
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
@@ -402,8 +414,8 @@ class OnlineRecognizer(object):
)
provider_config = ProviderConfig(
provider=provider,
device=device,
provider=provider,
device=device,
)
model_config = OnlineModelConfig(
@@ -434,6 +446,11 @@ class OnlineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
@@ -460,6 +477,9 @@ class OnlineRecognizer(object):
rule_fsts: str = "",
rule_fars: str = "",
device: int = 0,
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
@@ -526,8 +546,8 @@ class OnlineRecognizer(object):
zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model)
provider_config = ProviderConfig(
provider=provider,
device=device,
provider=provider,
device=device,
)
model_config = OnlineModelConfig(
@@ -563,6 +583,11 @@ class OnlineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
@@ -587,6 +612,9 @@ class OnlineRecognizer(object):
rule_fsts: str = "",
rule_fars: str = "",
device: int = 0,
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
@@ -650,8 +678,8 @@ class OnlineRecognizer(object):
)
provider_config = ProviderConfig(
provider=provider,
device=device,
provider=provider,
device=device,
)
model_config = OnlineModelConfig(
@@ -681,6 +709,11 @@ class OnlineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
@@ -707,6 +740,9 @@ class OnlineRecognizer(object):
rule_fsts: str = "",
rule_fars: str = "",
device: int = 0,
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
@@ -775,8 +811,8 @@ class OnlineRecognizer(object):
)
provider_config = ProviderConfig(
provider=provider,
device=device,
provider=provider,
device=device,
)
model_config = OnlineModelConfig(
@@ -806,6 +842,11 @@ class OnlineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)