add blank_penalty for online transducer (#548)

This commit is contained in:
chiiyeh
2024-01-26 12:12:13 +08:00
committed by GitHub
parent 466a6855c8
commit e7b18a2139
13 changed files with 94 additions and 14 deletions

View File

@@ -216,6 +216,18 @@ def get_args():
""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
parser.add_argument(
"sound_files",
type=str,
@@ -290,6 +302,7 @@ def main():
lm_scale=args.lm_scale,
hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score,
blank_penalty=args.blank_penalty,
)
elif args.zipformer2_ctc:
recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(

View File

@@ -102,6 +102,17 @@ def get_args():
""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
return parser.parse_args()
@@ -130,6 +141,7 @@ def create_recognizer(args):
provider=args.provider,
hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score,
blank_penalty=args.blank_penalty,
)
return recognizer

View File

@@ -111,6 +111,17 @@ def get_args():
""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
return parser.parse_args()
@@ -136,6 +147,7 @@ def create_recognizer(args):
provider=args.provider,
hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score,
blank_penalty=args.blank_penalty,
)
return recognizer

View File

@@ -241,6 +241,18 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser):
""",
)
def add_blank_penalty_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
def add_endpointing_args(parser: argparse.ArgumentParser):
parser.add_argument(
@@ -284,6 +296,7 @@ def get_args():
add_decoding_args(parser)
add_endpointing_args(parser)
add_hotwords_args(parser)
add_blank_penalty_args(parser)
parser.add_argument(
"--port",
@@ -390,6 +403,7 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
max_active_paths=args.num_active_paths,
hotwords_score=args.hotwords_score,
hotwords_file=args.hotwords_file,
blank_penalty=args.blank_penalty,
enable_endpoint_detection=args.use_endpoint != 0,
rule1_min_trailing_silence=args.rule1_min_trailing_silence,
rule2_min_trailing_silence=args.rule2_min_trailing_silence,