add blank_penalty for offline transducer (#542)
This commit is contained in:
@@ -383,6 +383,19 @@ def add_hotwords_args(parser: argparse.ArgumentParser):
|
||||
""",
|
||||
)
|
||||
|
||||
def add_blank_penalty_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``--blank-penalty`` command-line option on ``parser``.

    The option is parsed as a float and defaults to ``0.0``. After parsing it
    is available as ``args.blank_penalty``.

    Args:
      parser:
        The argument parser to which the option is added (modified in place).
    """
    parser.add_argument(
        "--blank-penalty",
        type=float,
        default=0.0,
        help="""
        The penalty applied on blank symbol during decoding.
        Note: It is a positive value that would be applied to logits like
        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
        [batch_size, vocab] and blank id is 0).
        """,
    )
|
||||
|
||||
|
||||
def check_args(args):
|
||||
if not Path(args.tokens).is_file():
|
||||
@@ -414,6 +427,7 @@ def get_args():
|
||||
add_feature_config_args(parser)
|
||||
add_decoding_args(parser)
|
||||
add_hotwords_args(parser)
|
||||
add_blank_penalty_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
@@ -862,6 +876,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
max_active_paths=args.max_active_paths,
|
||||
hotwords_file=args.hotwords_file,
|
||||
hotwords_score=args.hotwords_score,
|
||||
blank_penalty=args.blank_penalty,
|
||||
provider=args.provider,
|
||||
)
|
||||
elif args.paraformer:
|
||||
|
||||
@@ -231,6 +231,18 @@ def get_args():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--blank-penalty",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="""
|
||||
The penalty applied on blank symbol during decoding.
|
||||
Note: It is a positive value that would be applied to logits like
|
||||
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||
[batch_size, vocab] and blank id is 0).
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
@@ -335,6 +347,7 @@ def main():
|
||||
decoding_method=args.decoding_method,
|
||||
hotwords_file=args.hotwords_file,
|
||||
hotwords_score=args.hotwords_score,
|
||||
blank_penalty=args.blank_penalty,
|
||||
debug=args.debug,
|
||||
)
|
||||
elif args.paraformer:
|
||||
|
||||
@@ -177,6 +177,18 @@ def get_args():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--blank-penalty",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="""
|
||||
The penalty applied on blank symbol during decoding.
|
||||
Note: It is a positive value that would be applied to logits like
|
||||
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||
[batch_size, vocab] and blank id is 0).
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
@@ -237,6 +249,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
sample_rate=args.sample_rate,
|
||||
feature_dim=args.feature_dim,
|
||||
decoding_method=args.decoding_method,
|
||||
blank_penalty=args.blank_penalty,
|
||||
debug=args.debug,
|
||||
)
|
||||
elif args.paraformer:
|
||||
|
||||
Reference in New Issue
Block a user