diff --git a/python-api-examples/generate-subtitles.py b/python-api-examples/generate-subtitles.py index 1e36dd07..752c9b29 100755 --- a/python-api-examples/generate-subtitles.py +++ b/python-api-examples/generate-subtitles.py @@ -79,8 +79,17 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v --num-threads=2 \ /path/to/test.mp4 +(6) For FireRedAsr models -(6) For WeNet CTC models +./python-api-examples/generate-subtitles.py \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --tokens=./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/tokens.txt \ + --fire-red-asr-encoder=./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/encoder.int8.onnx \ + --fire-red-asr-decoder=./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/decoder.int8.onnx \ + --num-threads=2 \ + /path/to/test.mp4 + +(7) For WeNet CTC models ./python-api-examples/generate-subtitles.py \ --silero-vad-model=/path/to/silero_vad.onnx \ @@ -174,6 +183,20 @@ def get_args(): help="Number of threads for neural network computation", ) + parser.add_argument( + "--fire-red-asr-encoder", + default="", + type=str, + help="Path to FireRedAsr encoder model", + ) + + parser.add_argument( + "--fire-red-asr-decoder", + default="", + type=str, + help="Path to FireRedAsr decoder model", + ) + parser.add_argument( "--whisper-encoder", default="", @@ -304,6 +327,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: assert len(args.wenet_ctc) == 0, args.wenet_ctc assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.fire_red_asr_encoder) == 0, args.fire_red_asr_encoder + assert len(args.fire_red_asr_decoder) == 0, args.fire_red_asr_decoder assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor assert len(args.moonshine_encoder) == 0, args.moonshine_encoder assert ( @@ -331,6 +356,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: assert len(args.wenet_ctc) == 0, args.wenet_ctc assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.fire_red_asr_encoder) == 0, args.fire_red_asr_encoder + assert len(args.fire_red_asr_decoder) == 0, args.fire_red_asr_decoder assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor assert len(args.moonshine_encoder) == 0, args.moonshine_encoder assert ( @@ -353,6 +380,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: assert len(args.wenet_ctc) == 0, args.wenet_ctc assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.fire_red_asr_encoder) == 0, args.fire_red_asr_encoder + assert len(args.fire_red_asr_decoder) == 0, args.fire_red_asr_decoder assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor assert len(args.moonshine_encoder) == 0, args.moonshine_encoder assert ( @@ -371,6 +400,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: elif args.wenet_ctc: assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.fire_red_asr_encoder) == 0, args.fire_red_asr_encoder + assert len(args.fire_red_asr_decoder) == 0, args.fire_red_asr_decoder assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor assert len(args.moonshine_encoder) == 0, args.moonshine_encoder assert ( @@ -392,6 +423,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: elif args.whisper_encoder: assert_file_exists(args.whisper_encoder) assert_file_exists(args.whisper_decoder) + assert len(args.fire_red_asr_encoder) == 0, args.fire_red_asr_encoder + assert len(args.fire_red_asr_decoder) == 0, args.fire_red_asr_decoder assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor assert len(args.moonshine_encoder) == 0, args.moonshine_encoder assert ( @@ -411,6 +444,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: tail_paddings=args.whisper_tail_paddings, ) elif args.moonshine_preprocessor: + assert len(args.fire_red_asr_encoder) == 0, args.fire_red_asr_encoder + assert len(args.fire_red_asr_decoder) == 0, args.fire_red_asr_decoder assert_file_exists(args.moonshine_preprocessor) assert_file_exists(args.moonshine_encoder) assert_file_exists(args.moonshine_uncached_decoder) @@ -426,6 +461,15 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: decoding_method=args.decoding_method, debug=args.debug, ) + elif args.fire_red_asr_encoder: + recognizer = sherpa_onnx.OfflineRecognizer.from_fire_red_asr( + encoder=args.fire_red_asr_encoder, + decoder=args.fire_red_asr_decoder, + tokens=args.tokens, + num_threads=args.num_threads, + decoding_method=args.decoding_method, + debug=args.debug, + ) else: raise ValueError("Please specify at least one model")