Give an informative log for whisper on exceptions. (#473)

This commit is contained in:
Fangjun Kuang
2023-12-08 14:33:59 +08:00
committed by GitHub
parent 868c339e5e
commit 0e23f82691
7 changed files with 77 additions and 15 deletions

View File

@@ -180,6 +180,17 @@ def get_args():
         """,
     )
 
+    parser.add_argument(
+        "--whisper-tail-paddings",
+        default=-1,
+        type=int,
+        help="""Number of tail padding frames.
+        We have removed the 30-second constraint from whisper, so you need to
+        choose the amount of tail padding frames by yourself.
+        Use -1 to use a default value for tail padding.
+        """,
+    )
+
     parser.add_argument(
         "--decoding-method",
         type=str,
@@ -294,6 +305,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
             debug=args.debug,
             language=args.whisper_language,
             task=args.whisper_task,
+            tail_paddings=args.whisper_tail_paddings,
         )
     else:
         raise ValueError("Please specify at least one model")

View File

@@ -277,6 +277,17 @@ def add_whisper_model_args(parser: argparse.ArgumentParser):
         """,
     )
 
+    parser.add_argument(
+        "--whisper-tail-paddings",
+        default=-1,
+        type=int,
+        help="""Number of tail padding frames.
+        We have removed the 30-second constraint from whisper, so you need to
+        choose the amount of tail padding frames by yourself.
+        Use -1 to use a default value for tail padding.
+        """,
+    )
+
 
 def add_model_args(parser: argparse.ArgumentParser):
     add_transducer_model_args(parser)
@@ -913,6 +924,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
             decoding_method=args.decoding_method,
             language=args.whisper_language,
             task=args.whisper_task,
+            tail_paddings=args.whisper_tail_paddings,
         )
     elif args.tdnn_model:
         assert_file_exists(args.tdnn_model)

View File

@@ -220,6 +220,17 @@ def get_args():
         """,
     )
 
+    parser.add_argument(
+        "--whisper-tail-paddings",
+        default=-1,
+        type=int,
+        help="""Number of tail padding frames.
+        We have removed the 30-second constraint from whisper, so you need to
+        choose the amount of tail padding frames by yourself.
+        Use -1 to use a default value for tail padding.
+        """,
+    )
+
     parser.add_argument(
         "--decoding-method",
         type=str,
@@ -391,6 +402,7 @@ def main():
             debug=args.debug,
             language=args.whisper_language,
             task=args.whisper_task,
+            tail_paddings=args.whisper_tail_paddings,
         )
     elif args.tdnn_model:
         assert_file_exists(args.tdnn_model)

View File

@@ -195,6 +195,17 @@ def add_second_pass_whisper_model_args(parser: argparse.ArgumentParser):
         """,
     )
 
+    parser.add_argument(
+        "--second-whisper-tail-paddings",
+        default=-1,
+        type=int,
+        help="""Number of tail padding frames.
+        We have removed the 30-second constraint from whisper, so you need to
+        choose the amount of tail padding frames by yourself.
+        Use -1 to use a default value for tail padding.
+        """,
+    )
+
 
 def add_second_pass_non_streaming_model_args(parser: argparse.ArgumentParser):
     add_second_pass_transducer_model_args(parser)
@@ -314,6 +325,7 @@ def create_second_pass_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
             decoding_method="greedy_search",
             language=args.second_whisper_language,
             task=args.second_whisper_task,
+            tail_paddings=args.second_whisper_tail_paddings,
         )
     else:
         raise ValueError("Please specify at least one model for the second pass")

View File

@@ -166,6 +166,17 @@ def get_args():
         """,
     )
 
+    parser.add_argument(
+        "--whisper-tail-paddings",
+        default=-1,
+        type=int,
+        help="""Number of tail padding frames.
+        We have removed the 30-second constraint from whisper, so you need to
+        choose the amount of tail padding frames by yourself.
+        Use -1 to use a default value for tail padding.
+        """,
+    )
+
     parser.add_argument(
         "--decoding-method",
         type=str,
@@ -256,6 +267,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
             debug=args.debug,
             language=args.whisper_language,
             task=args.whisper_task,
+            tail_paddings=args.whisper_tail_paddings,
         )
     else:
         raise ValueError("Please specify at least one model")

View File

@@ -116,18 +116,12 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
 
     NormalizeFeatures(f.data(), num_frames, feat_dim);
 
-    // note that 50 is an experience value.
-    // see also ../../scripts/whisper/test.py
-    //
-    // You can replace 50 by other values, say, 100.
+    // note that 1000 is an experience-value.
+    // You can replace 1000 by other values, say, 100.
     //
     // Since we have removed the 30 seconds constraint, we need
     // tail_padding_frames so that whisper is able to detect the eot token.
-    int32_t tail_padding_frames = 50;
-    if (model_->IsMultiLingual()) {
-      // 300 is an experience value. If it throws, please use a larger value.
-      tail_padding_frames = 300;
-    }
+    int32_t tail_padding_frames = 1000;
 
     if (config_.model_config.whisper.tail_paddings > 0) {
       tail_padding_frames = config_.model_config.whisper.tail_paddings;
@@ -140,11 +134,13 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
     Ort::Value mel = Ort::Value::CreateTensor<float>(
         model_->Allocator(), shape.data(), shape.size());
-    float *p_mel = mel.GetTensorMutableData<float>();
-    std::copy(f.data(), f.data() + actual_frames * feat_dim, p_mel);
-    memset(p_mel + f.size(), 0,
-           (actual_frames - num_frames) * feat_dim * sizeof(float));
+
+    float *p_mel = mel.GetTensorMutableData<float>();
+    std::copy(f.data(), f.data() + num_frames * feat_dim, p_mel);
+
+    std::fill_n(p_mel + num_frames * feat_dim,
+                (actual_frames - num_frames) * feat_dim, 0);
 
     mel = Transpose12(model_->Allocator(), &mel);
 
     try {
@@ -156,8 +152,12 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
       auto r = Convert(results[0], symbol_table_);
       s->SetResult(r);
     } catch (const Ort::Exception &ex) {
-      SHERPA_ONNX_LOGE("\n\nCaught exception:\n\n%s\n\nReturn an empty result",
-                       ex.what());
+      SHERPA_ONNX_LOGE(
+          "\n\nCaught exception:\n\n%s\n\nReturn an empty result. Number of "
+          "input frames: %d, Current tail "
+          "paddings: %d. If you see a lot of such exceptions, please consider "
+          "using a larger --whisper-tail-paddings",
+          ex.what(), num_frames, tail_padding_frames);
       return;
     }
   }

View File

@@ -261,6 +261,7 @@ class OfflineRecognizer(object):
         decoding_method: str = "greedy_search",
         debug: bool = False,
         provider: str = "cpu",
+        tail_paddings: int = -1,
     ):
         """
         Please refer to
@@ -305,6 +306,7 @@ class OfflineRecognizer(object):
                 decoder=decoder,
                 language=language,
                 task=task,
+                tail_paddings=tail_paddings,
             ),
             tokens=tokens,
             num_threads=num_threads,