Add C++ runtime and Python APIs for Moonshine models (#1473)

This commit is contained in:
Fangjun Kuang
2024-10-26 14:34:07 +08:00
committed by GitHub
parent 0f2732e4e8
commit 669f5ef441
33 changed files with 1572 additions and 36 deletions

View File

@@ -47,7 +47,19 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
--feature-dim=80 \
/path/to/test.mp4
(3) For Whisper models
(3) For Moonshine models
./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \
--moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
--moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
--moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
--moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
--tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt \
--num-threads=2 \
/path/to/test.mp4
(4) For Whisper models
./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \
@@ -58,7 +70,7 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
--num-threads=2 \
/path/to/test.mp4
(4) For SenseVoice CTC models
(5) For SenseVoice CTC models
./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \
@@ -68,7 +80,7 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
/path/to/test.mp4
(5) For WeNet CTC models
(6) For WeNet CTC models
./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \
@@ -83,6 +95,7 @@ to install sherpa-onnx and to download non-streaming pre-trained models
used in this file.
"""
import argparse
import datetime as dt
import shutil
import subprocess
import sys
@@ -157,7 +170,7 @@ def get_args():
parser.add_argument(
"--num-threads",
type=int,
default=1,
default=2,
help="Number of threads for neural network computation",
)
@@ -208,6 +221,34 @@ def get_args():
""",
)
parser.add_argument(
"--moonshine-preprocessor",
default="",
type=str,
help="Path to moonshine preprocessor model",
)
parser.add_argument(
"--moonshine-encoder",
default="",
type=str,
help="Path to moonshine encoder model",
)
parser.add_argument(
"--moonshine-uncached-decoder",
default="",
type=str,
help="Path to moonshine uncached decoder model",
)
parser.add_argument(
"--moonshine-cached-decoder",
default="",
type=str,
help="Path to moonshine cached decoder model",
)
parser.add_argument(
"--decoding-method",
type=str,
@@ -263,6 +304,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
@@ -284,6 +331,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.paraformer)
@@ -300,6 +353,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.sense_voice)
recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
@@ -312,6 +371,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
elif args.wenet_ctc:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.wenet_ctc)
@@ -327,6 +392,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
elif args.whisper_encoder:
assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder)
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
encoder=args.whisper_encoder,
@@ -339,6 +410,22 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
task=args.whisper_task,
tail_paddings=args.whisper_tail_paddings,
)
elif args.moonshine_preprocessor:
assert_file_exists(args.moonshine_preprocessor)
assert_file_exists(args.moonshine_encoder)
assert_file_exists(args.moonshine_uncached_decoder)
assert_file_exists(args.moonshine_cached_decoder)
recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
preprocessor=args.moonshine_preprocessor,
encoder=args.moonshine_encoder,
uncached_decoder=args.moonshine_uncached_decoder,
cached_decoder=args.moonshine_cached_decoder,
tokens=args.tokens,
num_threads=args.num_threads,
decoding_method=args.decoding_method,
debug=args.debug,
)
else:
raise ValueError("Please specify at least one model")
@@ -424,28 +511,32 @@ def main():
segment_list = []
print("Started!")
start_t = dt.datetime.now()
num_processed_samples = 0
is_silence = False
is_eof = False
# TODO(fangjun): Support multithreads
while True:
# *2 because int16_t has two bytes
data = process.stdout.read(frames_per_read * 2)
if not data:
if is_silence:
if is_eof:
break
is_silence = True
# The converted audio file does not have a mute data of 1 second or more at the end, which will result in the loss of the last segment data
is_eof = True
# pad 1 second at the end of the file for the VAD
data = np.zeros(1 * args.sample_rate, dtype=np.int16)
samples = np.frombuffer(data, dtype=np.int16)
samples = samples.astype(np.float32) / 32768
num_processed_samples += samples.shape[0]
buffer = np.concatenate([buffer, samples])
while len(buffer) > window_size:
vad.accept_waveform(buffer[:window_size])
buffer = buffer[window_size:]
if is_silence:
if is_eof:
vad.flush()
streams = []
@@ -471,6 +562,11 @@ def main():
seg.text = stream.result.text
segment_list.append(seg)
end_t = dt.datetime.now()
elapsed_seconds = (end_t - start_t).total_seconds()
duration = num_processed_samples / 16000
rtf = elapsed_seconds / duration
srt_filename = Path(args.sound_file).with_suffix(".srt")
with open(srt_filename, "w", encoding="utf-8") as f:
for i, seg in enumerate(segment_list):
@@ -479,6 +575,9 @@ def main():
print("", file=f)
print(f"Saved to {srt_filename}")
print(f"Audio duration:\t{duration:.3f} s")
print(f"Elapsed:\t{elapsed_seconds:.3f} s")
print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
print("Done!")