add blank_penalty for offline transducer (#542)

This commit is contained in:
chiiyeh
2024-01-25 15:00:09 +08:00
committed by GitHub
parent a9e7747736
commit 3bb3849ec5
13 changed files with 97 additions and 14 deletions

View File

@@ -383,6 +383,19 @@ def add_hotwords_args(parser: argparse.ArgumentParser):
""",
)
def add_blank_penalty_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
def check_args(args):
if not Path(args.tokens).is_file():
@@ -414,6 +427,7 @@ def get_args():
add_feature_config_args(parser)
add_decoding_args(parser)
add_hotwords_args(parser)
add_blank_penalty_args(parser)
parser.add_argument(
"--port",
@@ -862,6 +876,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
max_active_paths=args.max_active_paths,
hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score,
blank_penalty=args.blank_penalty,
provider=args.provider,
)
elif args.paraformer:

View File

@@ -231,6 +231,18 @@ def get_args():
""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
parser.add_argument(
"--decoding-method",
type=str,
@@ -335,6 +347,7 @@ def main():
decoding_method=args.decoding_method,
hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score,
blank_penalty=args.blank_penalty,
debug=args.debug,
)
elif args.paraformer:

View File

@@ -177,6 +177,18 @@ def get_args():
""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
parser.add_argument(
"--decoding-method",
type=str,
@@ -237,6 +249,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
decoding_method=args.decoding_method,
blank_penalty=args.blank_penalty,
debug=args.debug,
)
elif args.paraformer: