Add C++ runtime and Python APIs for Moonshine models (#1473)
This commit is contained in:
@@ -47,7 +47,19 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
|
||||
--feature-dim=80 \
|
||||
/path/to/test.mp4
|
||||
|
||||
(3) For Whisper models
|
||||
(3) For Moonshine models
|
||||
|
||||
./python-api-examples/generate-subtitles.py \
|
||||
--silero-vad-model=/path/to/silero_vad.onnx \
|
||||
--moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
|
||||
--moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
|
||||
--moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
|
||||
--moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
|
||||
--tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt \
|
||||
--num-threads=2 \
|
||||
/path/to/test.mp4
|
||||
|
||||
(4) For Whisper models
|
||||
|
||||
./python-api-examples/generate-subtitles.py \
|
||||
--silero-vad-model=/path/to/silero_vad.onnx \
|
||||
@@ -58,7 +70,7 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
|
||||
--num-threads=2 \
|
||||
/path/to/test.mp4
|
||||
|
||||
(4) For SenseVoice CTC models
|
||||
(5) For SenseVoice CTC models
|
||||
|
||||
./python-api-examples/generate-subtitles.py \
|
||||
--silero-vad-model=/path/to/silero_vad.onnx \
|
||||
@@ -68,7 +80,7 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
|
||||
/path/to/test.mp4
|
||||
|
||||
|
||||
(5) For WeNet CTC models
|
||||
(6) For WeNet CTC models
|
||||
|
||||
./python-api-examples/generate-subtitles.py \
|
||||
--silero-vad-model=/path/to/silero_vad.onnx \
|
||||
@@ -83,6 +95,7 @@ to install sherpa-onnx and to download non-streaming pre-trained models
|
||||
used in this file.
|
||||
"""
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
@@ -157,7 +170,7 @@ def get_args():
|
||||
parser.add_argument(
|
||||
"--num-threads",
|
||||
type=int,
|
||||
default=1,
|
||||
default=2,
|
||||
help="Number of threads for neural network computation",
|
||||
)
|
||||
|
||||
@@ -208,6 +221,34 @@ def get_args():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--moonshine-preprocessor",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to moonshine preprocessor model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--moonshine-encoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to moonshine encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--moonshine-uncached-decoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to moonshine uncached decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--moonshine-cached-decoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to moonshine cached decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
@@ -263,6 +304,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
assert_file_exists(args.encoder)
|
||||
assert_file_exists(args.decoder)
|
||||
@@ -284,6 +331,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
assert_file_exists(args.paraformer)
|
||||
|
||||
@@ -300,6 +353,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
assert_file_exists(args.sense_voice)
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
|
||||
@@ -312,6 +371,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
elif args.wenet_ctc:
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
assert_file_exists(args.wenet_ctc)
|
||||
|
||||
@@ -327,6 +392,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
elif args.whisper_encoder:
|
||||
assert_file_exists(args.whisper_encoder)
|
||||
assert_file_exists(args.whisper_decoder)
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
|
||||
encoder=args.whisper_encoder,
|
||||
@@ -339,6 +410,22 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
task=args.whisper_task,
|
||||
tail_paddings=args.whisper_tail_paddings,
|
||||
)
|
||||
elif args.moonshine_preprocessor:
|
||||
assert_file_exists(args.moonshine_preprocessor)
|
||||
assert_file_exists(args.moonshine_encoder)
|
||||
assert_file_exists(args.moonshine_uncached_decoder)
|
||||
assert_file_exists(args.moonshine_cached_decoder)
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
|
||||
preprocessor=args.moonshine_preprocessor,
|
||||
encoder=args.moonshine_encoder,
|
||||
uncached_decoder=args.moonshine_uncached_decoder,
|
||||
cached_decoder=args.moonshine_cached_decoder,
|
||||
tokens=args.tokens,
|
||||
num_threads=args.num_threads,
|
||||
decoding_method=args.decoding_method,
|
||||
debug=args.debug,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Please specify at least one model")
|
||||
|
||||
@@ -424,28 +511,32 @@ def main():
|
||||
segment_list = []
|
||||
|
||||
print("Started!")
|
||||
start_t = dt.datetime.now()
|
||||
num_processed_samples = 0
|
||||
|
||||
is_silence = False
|
||||
is_eof = False
|
||||
# TODO(fangjun): Support multithreads
|
||||
while True:
|
||||
# *2 because int16_t has two bytes
|
||||
data = process.stdout.read(frames_per_read * 2)
|
||||
if not data:
|
||||
if is_silence:
|
||||
if is_eof:
|
||||
break
|
||||
is_silence = True
|
||||
# The converted audio file does not have a mute data of 1 second or more at the end, which will result in the loss of the last segment data
|
||||
is_eof = True
|
||||
# pad 1 second at the end of the file for the VAD
|
||||
data = np.zeros(1 * args.sample_rate, dtype=np.int16)
|
||||
|
||||
samples = np.frombuffer(data, dtype=np.int16)
|
||||
samples = samples.astype(np.float32) / 32768
|
||||
|
||||
num_processed_samples += samples.shape[0]
|
||||
|
||||
buffer = np.concatenate([buffer, samples])
|
||||
while len(buffer) > window_size:
|
||||
vad.accept_waveform(buffer[:window_size])
|
||||
buffer = buffer[window_size:]
|
||||
|
||||
if is_silence:
|
||||
if is_eof:
|
||||
vad.flush()
|
||||
|
||||
streams = []
|
||||
@@ -471,6 +562,11 @@ def main():
|
||||
seg.text = stream.result.text
|
||||
segment_list.append(seg)
|
||||
|
||||
end_t = dt.datetime.now()
|
||||
elapsed_seconds = (end_t - start_t).total_seconds()
|
||||
duration = num_processed_samples / 16000
|
||||
rtf = elapsed_seconds / duration
|
||||
|
||||
srt_filename = Path(args.sound_file).with_suffix(".srt")
|
||||
with open(srt_filename, "w", encoding="utf-8") as f:
|
||||
for i, seg in enumerate(segment_list):
|
||||
@@ -479,6 +575,9 @@ def main():
|
||||
print("", file=f)
|
||||
|
||||
print(f"Saved to {srt_filename}")
|
||||
print(f"Audio duration:\t{duration:.3f} s")
|
||||
print(f"Elapsed:\t{elapsed_seconds:.3f} s")
|
||||
print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
|
||||
print("Done!")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user