Add C++ runtime and Python APIs for Moonshine models (#1473)

This commit is contained in:
Fangjun Kuang
2024-10-26 14:34:07 +08:00
committed by GitHub
parent 0f2732e4e8
commit 669f5ef441
33 changed files with 1572 additions and 36 deletions

50
.github/scripts/test-offline-moonshine.sh vendored Executable file
View File

@@ -0,0 +1,50 @@
#!/usr/bin/env bash
# Download the Moonshine int8 ONNX models (tiny and base) and decode the
# bundled test waves with the binary named by the EXE environment variable.
set -e

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

export GIT_CLONE_PROTECTION_ACTIVE=false

echo "EXE is $EXE"
echo "PATH: $PATH"

which $EXE

names=(
tiny
base
)

for name in ${names[@]}; do
  log "------------------------------------------------------------"
  log "Run $name"
  log "------------------------------------------------------------"

  # Fix: removed a leftover sherpa-onnx-whisper-$name.tar.bz2 assignment to
  # repo_url that was dead code — it was immediately overwritten by the
  # Moonshine URL below.
  repo_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-$name-en-int8.tar.bz2
  curl -SL -O $repo_url
  tar xvf sherpa-onnx-moonshine-$name-en-int8.tar.bz2
  rm sherpa-onnx-moonshine-$name-en-int8.tar.bz2
  repo=sherpa-onnx-moonshine-$name-en-int8

  log "Start testing ${repo_url}"

  log "test int8 onnx"

  time $EXE \
    --moonshine-preprocessor=$repo/preprocess.onnx \
    --moonshine-encoder=$repo/encode.int8.onnx \
    --moonshine-uncached-decoder=$repo/uncached_decode.int8.onnx \
    --moonshine-cached-decoder=$repo/cached_decode.int8.onnx \
    --tokens=$repo/tokens.txt \
    --num-threads=2 \
    $repo/test_wavs/0.wav \
    $repo/test_wavs/1.wav \
    $repo/test_wavs/8k.wav

  rm -rf $repo
done

View File

@@ -8,6 +8,16 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
} }
log "test offline Moonshine"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
python3 ./python-api-examples/offline-moonshine-decode-files.py
rm -rf sherpa-onnx-moonshine-tiny-en-int8
log "test offline speaker diarization" log "test offline speaker diarization"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2

View File

@@ -149,6 +149,19 @@ jobs:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: install/* path: install/*
- name: Test offline Moonshine
if: matrix.build_type != 'Debug'
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
readelf -d build/bin/sherpa-onnx-offline
.github/scripts/test-offline-moonshine.sh
du -h -d1 .
- name: Test offline CTC - name: Test offline CTC
shell: bash shell: bash
run: | run: |

View File

@@ -121,6 +121,15 @@ jobs:
otool -L build/bin/sherpa-onnx otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx otool -l build/bin/sherpa-onnx
- name: Test offline Moonshine
if: matrix.build_type != 'Debug'
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-moonshine.sh
- name: Test C++ API - name: Test C++ API
shell: bash shell: bash
run: | run: |
@@ -243,8 +252,6 @@ jobs:
.github/scripts/test-offline-whisper.sh .github/scripts/test-offline-whisper.sh
- name: Test online transducer - name: Test online transducer
shell: bash shell: bash
run: | run: |

View File

@@ -93,6 +93,14 @@ jobs:
name: release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }} name: release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
path: build/install/* path: build/install/*
- name: Test offline Moonshine for windows x64
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline.exe
.github/scripts/test-offline-moonshine.sh
- name: Test C++ API - name: Test C++ API
shell: bash shell: bash
run: | run: |

View File

@@ -93,6 +93,14 @@ jobs:
name: release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }} name: release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
path: build/install/* path: build/install/*
- name: Test offline Moonshine for windows x86
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline.exe
.github/scripts/test-offline-moonshine.sh
- name: Test C++ API - name: Test C++ API
shell: bash shell: bash
run: | run: |

View File

@@ -47,7 +47,19 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
--feature-dim=80 \ --feature-dim=80 \
/path/to/test.mp4 /path/to/test.mp4
(3) For Whisper models (3) For Moonshine models
./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \
--moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
--moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
--moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
--moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
--tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt \
--num-threads=2 \
/path/to/test.mp4
(4) For Whisper models
./python-api-examples/generate-subtitles.py \ ./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \ --silero-vad-model=/path/to/silero_vad.onnx \
@@ -58,7 +70,7 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
--num-threads=2 \ --num-threads=2 \
/path/to/test.mp4 /path/to/test.mp4
(4) For SenseVoice CTC models (5) For SenseVoice CTC models
./python-api-examples/generate-subtitles.py \ ./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \ --silero-vad-model=/path/to/silero_vad.onnx \
@@ -68,7 +80,7 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
/path/to/test.mp4 /path/to/test.mp4
(5) For WeNet CTC models (6) For WeNet CTC models
./python-api-examples/generate-subtitles.py \ ./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \ --silero-vad-model=/path/to/silero_vad.onnx \
@@ -83,6 +95,7 @@ to install sherpa-onnx and to download non-streaming pre-trained models
used in this file. used in this file.
""" """
import argparse import argparse
import datetime as dt
import shutil import shutil
import subprocess import subprocess
import sys import sys
@@ -157,7 +170,7 @@ def get_args():
parser.add_argument( parser.add_argument(
"--num-threads", "--num-threads",
type=int, type=int,
default=1, default=2,
help="Number of threads for neural network computation", help="Number of threads for neural network computation",
) )
@@ -208,6 +221,34 @@ def get_args():
""", """,
) )
parser.add_argument(
"--moonshine-preprocessor",
default="",
type=str,
help="Path to moonshine preprocessor model",
)
parser.add_argument(
"--moonshine-encoder",
default="",
type=str,
help="Path to moonshine encoder model",
)
parser.add_argument(
"--moonshine-uncached-decoder",
default="",
type=str,
help="Path to moonshine uncached decoder model",
)
parser.add_argument(
"--moonshine-cached-decoder",
default="",
type=str,
help="Path to moonshine cached decoder model",
)
parser.add_argument( parser.add_argument(
"--decoding-method", "--decoding-method",
type=str, type=str,
@@ -263,6 +304,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.wenet_ctc) == 0, args.wenet_ctc assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.encoder) assert_file_exists(args.encoder)
assert_file_exists(args.decoder) assert_file_exists(args.decoder)
@@ -284,6 +331,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.wenet_ctc) == 0, args.wenet_ctc assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.paraformer) assert_file_exists(args.paraformer)
@@ -300,6 +353,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.wenet_ctc) == 0, args.wenet_ctc assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.sense_voice) assert_file_exists(args.sense_voice)
recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice( recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
@@ -312,6 +371,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
elif args.wenet_ctc: elif args.wenet_ctc:
assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.wenet_ctc) assert_file_exists(args.wenet_ctc)
@@ -327,6 +392,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
elif args.whisper_encoder: elif args.whisper_encoder:
assert_file_exists(args.whisper_encoder) assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder) assert_file_exists(args.whisper_decoder)
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper( recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
encoder=args.whisper_encoder, encoder=args.whisper_encoder,
@@ -339,6 +410,22 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
task=args.whisper_task, task=args.whisper_task,
tail_paddings=args.whisper_tail_paddings, tail_paddings=args.whisper_tail_paddings,
) )
elif args.moonshine_preprocessor:
assert_file_exists(args.moonshine_preprocessor)
assert_file_exists(args.moonshine_encoder)
assert_file_exists(args.moonshine_uncached_decoder)
assert_file_exists(args.moonshine_cached_decoder)
recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
preprocessor=args.moonshine_preprocessor,
encoder=args.moonshine_encoder,
uncached_decoder=args.moonshine_uncached_decoder,
cached_decoder=args.moonshine_cached_decoder,
tokens=args.tokens,
num_threads=args.num_threads,
decoding_method=args.decoding_method,
debug=args.debug,
)
else: else:
raise ValueError("Please specify at least one model") raise ValueError("Please specify at least one model")
@@ -424,28 +511,32 @@ def main():
segment_list = [] segment_list = []
print("Started!") print("Started!")
start_t = dt.datetime.now()
num_processed_samples = 0
is_silence = False is_eof = False
# TODO(fangjun): Support multithreads # TODO(fangjun): Support multithreads
while True: while True:
# *2 because int16_t has two bytes # *2 because int16_t has two bytes
data = process.stdout.read(frames_per_read * 2) data = process.stdout.read(frames_per_read * 2)
if not data: if not data:
if is_silence: if is_eof:
break break
is_silence = True is_eof = True
# The converted audio file does not have a mute data of 1 second or more at the end, which will result in the loss of the last segment data # pad 1 second at the end of the file for the VAD
data = np.zeros(1 * args.sample_rate, dtype=np.int16) data = np.zeros(1 * args.sample_rate, dtype=np.int16)
samples = np.frombuffer(data, dtype=np.int16) samples = np.frombuffer(data, dtype=np.int16)
samples = samples.astype(np.float32) / 32768 samples = samples.astype(np.float32) / 32768
num_processed_samples += samples.shape[0]
buffer = np.concatenate([buffer, samples]) buffer = np.concatenate([buffer, samples])
while len(buffer) > window_size: while len(buffer) > window_size:
vad.accept_waveform(buffer[:window_size]) vad.accept_waveform(buffer[:window_size])
buffer = buffer[window_size:] buffer = buffer[window_size:]
if is_silence: if is_eof:
vad.flush() vad.flush()
streams = [] streams = []
@@ -471,6 +562,11 @@ def main():
seg.text = stream.result.text seg.text = stream.result.text
segment_list.append(seg) segment_list.append(seg)
end_t = dt.datetime.now()
elapsed_seconds = (end_t - start_t).total_seconds()
duration = num_processed_samples / 16000
rtf = elapsed_seconds / duration
srt_filename = Path(args.sound_file).with_suffix(".srt") srt_filename = Path(args.sound_file).with_suffix(".srt")
with open(srt_filename, "w", encoding="utf-8") as f: with open(srt_filename, "w", encoding="utf-8") as f:
for i, seg in enumerate(segment_list): for i, seg in enumerate(segment_list):
@@ -479,6 +575,9 @@ def main():
print("", file=f) print("", file=f)
print(f"Saved to {srt_filename}") print(f"Saved to {srt_filename}")
print(f"Audio duration:\t{duration:.3f} s")
print(f"Elapsed:\t{elapsed_seconds:.3f} s")
print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
print("Done!") print("Done!")

View File

@@ -66,7 +66,21 @@ python3 ./python-api-examples/non_streaming_server.py \
--wenet-ctc ./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \ --wenet-ctc ./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
--tokens ./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt --tokens ./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt
(5) Use a Whisper model (5) Use a Moonshine model
cd /path/to/sherpa-onnx
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
python3 ./python-api-examples/non_streaming_server.py \
--moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
--moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
--moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
--moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
--tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt
(6) Use a Whisper model
cd /path/to/sherpa-onnx cd /path/to/sherpa-onnx
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
@@ -78,7 +92,7 @@ python3 ./python-api-examples/non_streaming_server.py \
--whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \ --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
--tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt --tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt
(5) Use a tdnn model of the yesno recipe from icefall (7) Use a tdnn model of the yesno recipe from icefall
cd /path/to/sherpa-onnx cd /path/to/sherpa-onnx
@@ -92,7 +106,7 @@ python3 ./python-api-examples/non_streaming_server.py \
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \ --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt
(6) Use a Non-streaming SenseVoice model (8) Use a Non-streaming SenseVoice model
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
@@ -254,6 +268,36 @@ def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser):
) )
def add_moonshine_model_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--moonshine-preprocessor",
default="",
type=str,
help="Path to moonshine preprocessor model",
)
parser.add_argument(
"--moonshine-encoder",
default="",
type=str,
help="Path to moonshine encoder model",
)
parser.add_argument(
"--moonshine-uncached-decoder",
default="",
type=str,
help="Path to moonshine uncached decoder model",
)
parser.add_argument(
"--moonshine-cached-decoder",
default="",
type=str,
help="Path to moonshine cached decoder model",
)
def add_whisper_model_args(parser: argparse.ArgumentParser): def add_whisper_model_args(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--whisper-encoder", "--whisper-encoder",
@@ -311,6 +355,7 @@ def add_model_args(parser: argparse.ArgumentParser):
add_wenet_ctc_model_args(parser) add_wenet_ctc_model_args(parser)
add_tdnn_ctc_model_args(parser) add_tdnn_ctc_model_args(parser)
add_whisper_model_args(parser) add_whisper_model_args(parser)
add_moonshine_model_args(parser)
parser.add_argument( parser.add_argument(
"--tokens", "--tokens",
@@ -876,6 +921,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model assert len(args.tdnn_model) == 0, args.tdnn_model
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.encoder) assert_file_exists(args.encoder)
assert_file_exists(args.decoder) assert_file_exists(args.decoder)
@@ -903,6 +954,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model assert len(args.tdnn_model) == 0, args.tdnn_model
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.paraformer) assert_file_exists(args.paraformer)
@@ -921,6 +978,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model assert len(args.tdnn_model) == 0, args.tdnn_model
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.sense_voice) assert_file_exists(args.sense_voice)
recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice( recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
@@ -934,6 +997,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model assert len(args.tdnn_model) == 0, args.tdnn_model
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.nemo_ctc) assert_file_exists(args.nemo_ctc)
@@ -950,6 +1019,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model assert len(args.tdnn_model) == 0, args.tdnn_model
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.wenet_ctc) assert_file_exists(args.wenet_ctc)
@@ -966,6 +1041,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.tdnn_model) == 0, args.tdnn_model assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.whisper_encoder) assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder) assert_file_exists(args.whisper_decoder)
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper( recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
encoder=args.whisper_encoder, encoder=args.whisper_encoder,
@@ -980,6 +1061,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
) )
elif args.tdnn_model: elif args.tdnn_model:
assert_file_exists(args.tdnn_model) assert_file_exists(args.tdnn_model)
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc( recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
model=args.tdnn_model, model=args.tdnn_model,
@@ -990,6 +1077,21 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
decoding_method=args.decoding_method, decoding_method=args.decoding_method,
provider=args.provider, provider=args.provider,
) )
elif args.moonshine_preprocessor:
assert_file_exists(args.moonshine_preprocessor)
assert_file_exists(args.moonshine_encoder)
assert_file_exists(args.moonshine_uncached_decoder)
assert_file_exists(args.moonshine_cached_decoder)
recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
preprocessor=args.moonshine_preprocessor,
encoder=args.moonshine_encoder,
uncached_decoder=args.moonshine_uncached_decoder,
cached_decoder=args.moonshine_cached_decoder,
tokens=args.tokens,
num_threads=args.num_threads,
decoding_method=args.decoding_method,
)
else: else:
raise ValueError("Please specify at least one model") raise ValueError("Please specify at least one model")

View File

@@ -0,0 +1,82 @@
#!/usr/bin/env python3
"""
This file shows how to use a non-streaming Moonshine model from
https://github.com/usefulsensors/moonshine
to decode files.
Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
For instance,
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
"""
import datetime as dt
from pathlib import Path
import sherpa_onnx
import soundfile as sf
def create_recognizer():
    """Build an offline Moonshine recognizer and pick a sample wave file.

    Returns:
      A tuple ``(recognizer, test_wav)`` where ``test_wav`` is the path of a
      test wave shipped with the model.

    Raises:
      ValueError: if the model files have not been downloaded yet.
    """
    model_dir = "./sherpa-onnx-moonshine-tiny-en-int8"
    preprocessor = f"{model_dir}/preprocess.onnx"
    encoder = f"{model_dir}/encode.int8.onnx"
    uncached_decoder = f"{model_dir}/uncached_decode.int8.onnx"
    cached_decoder = f"{model_dir}/cached_decode.int8.onnx"
    tokens = f"{model_dir}/tokens.txt"
    test_wav = f"{model_dir}/test_wavs/0.wav"

    files_present = Path(preprocessor).is_file() and Path(test_wav).is_file()
    if not files_present:
        raise ValueError(
            """Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
        )

    recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
        preprocessor=preprocessor,
        encoder=encoder,
        uncached_decoder=uncached_decoder,
        cached_decoder=cached_decoder,
        tokens=tokens,
        debug=True,
    )
    return recognizer, test_wav
def main():
    """Decode one Moonshine test wave and report timing statistics (RTF)."""
    recognizer, wave_filename = create_recognizer()

    samples, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    samples = samples[:, 0]  # only use the first channel
    # samples is now a 1-D float32 numpy array normalized to the range [-1, 1];
    # sample_rate does not need to be 16000 Hz — the recognizer resamples.

    begin = dt.datetime.now()

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, samples)
    recognizer.decode_stream(stream)

    elapsed_seconds = (dt.datetime.now() - begin).total_seconds()
    duration = samples.shape[-1] / sample_rate
    rtf = elapsed_seconds / duration

    print(stream.result)
    print(wave_filename)
    print("Text:", stream.result.text)
    print(f"Audio duration:\t{duration:.3f} s")
    print(f"Elapsed:\t{elapsed_seconds:.3f} s")
    print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,77 @@
#!/usr/bin/env python3
"""
This file shows how to use a non-streaming whisper model from
https://github.com/openai/whisper
to decode files.
Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
For instance,
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2
rm sherpa-onnx-whisper-tiny.en.tar.bz2
"""
import datetime as dt
from pathlib import Path
import sherpa_onnx
import soundfile as sf
def create_recognizer():
    """Build an offline Whisper tiny.en recognizer and pick a sample wave.

    Returns:
      A tuple ``(recognizer, test_wav)`` where ``test_wav`` is the path of a
      test wave shipped with the model.

    Raises:
      ValueError: if the model files have not been downloaded yet.
    """
    model_dir = "./sherpa-onnx-whisper-tiny.en"
    encoder = f"{model_dir}/tiny.en-encoder.int8.onnx"
    decoder = f"{model_dir}/tiny.en-decoder.int8.onnx"
    tokens = f"{model_dir}/tiny.en-tokens.txt"
    test_wav = f"{model_dir}/test_wavs/0.wav"

    if not (Path(encoder).is_file() and Path(test_wav).is_file()):
        raise ValueError(
            """Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
        )

    recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
        encoder=encoder,
        decoder=decoder,
        tokens=tokens,
        debug=True,
    )
    return recognizer, test_wav
def main():
    """Decode one Whisper test wave and report timing statistics (RTF)."""
    recognizer, wave_filename = create_recognizer()

    samples, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    samples = samples[:, 0]  # only use the first channel
    # samples is now a 1-D float32 numpy array normalized to the range [-1, 1];
    # sample_rate does not need to be 16000 Hz — the recognizer resamples.

    begin = dt.datetime.now()

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, samples)
    recognizer.decode_stream(stream)

    elapsed_seconds = (dt.datetime.now() - begin).total_seconds()
    duration = samples.shape[-1] / sample_rate
    rtf = elapsed_seconds / duration

    print(stream.result)
    print(wave_filename)
    print("Text:", stream.result.text)
    print(f"Audio duration:\t{duration:.3f} s")
    print(f"Elapsed:\t{elapsed_seconds:.3f} s")
    print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")


if __name__ == "__main__":
    main()

View File

@@ -35,7 +35,18 @@ Note that you need a non-streaming model for this script.
--sample-rate=16000 \ --sample-rate=16000 \
--feature-dim=80 --feature-dim=80
(3) For Whisper models (3) For Moonshine models
./python-api-examples/vad-with-non-streaming-asr.py \
--silero-vad-model=/path/to/silero_vad.onnx \
--moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
--moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
--moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
--moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
--tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt \
--num-threads=2
(4) For Whisper models
./python-api-examples/vad-with-non-streaming-asr.py \ ./python-api-examples/vad-with-non-streaming-asr.py \
--silero-vad-model=/path/to/silero_vad.onnx \ --silero-vad-model=/path/to/silero_vad.onnx \
@@ -45,7 +56,7 @@ Note that you need a non-streaming model for this script.
--whisper-task=transcribe \ --whisper-task=transcribe \
--num-threads=2 --num-threads=2
(4) For SenseVoice CTC models (5) For SenseVoice CTC models
./python-api-examples/vad-with-non-streaming-asr.py \ ./python-api-examples/vad-with-non-streaming-asr.py \
--silero-vad-model=/path/to/silero_vad.onnx \ --silero-vad-model=/path/to/silero_vad.onnx \
@@ -192,6 +203,34 @@ def get_args():
""", """,
) )
parser.add_argument(
"--moonshine-preprocessor",
default="",
type=str,
help="Path to moonshine preprocessor model",
)
parser.add_argument(
"--moonshine-encoder",
default="",
type=str,
help="Path to moonshine encoder model",
)
parser.add_argument(
"--moonshine-uncached-decoder",
default="",
type=str,
help="Path to moonshine uncached decoder model",
)
parser.add_argument(
"--moonshine-cached-decoder",
default="",
type=str,
help="Path to moonshine cached decoder model",
)
parser.add_argument( parser.add_argument(
"--blank-penalty", "--blank-penalty",
type=float, type=float,
@@ -251,6 +290,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.sense_voice) == 0, args.sense_voice assert len(args.sense_voice) == 0, args.sense_voice
assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.encoder) assert_file_exists(args.encoder)
assert_file_exists(args.decoder) assert_file_exists(args.decoder)
@@ -272,6 +317,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.sense_voice) == 0, args.sense_voice assert len(args.sense_voice) == 0, args.sense_voice
assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.paraformer) assert_file_exists(args.paraformer)
@@ -287,6 +338,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
elif args.sense_voice: elif args.sense_voice:
assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.sense_voice) assert_file_exists(args.sense_voice)
recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice( recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
@@ -299,6 +356,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
elif args.whisper_encoder: elif args.whisper_encoder:
assert_file_exists(args.whisper_encoder) assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder) assert_file_exists(args.whisper_decoder)
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper( recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
encoder=args.whisper_encoder, encoder=args.whisper_encoder,
@@ -311,6 +374,22 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
task=args.whisper_task, task=args.whisper_task,
tail_paddings=args.whisper_tail_paddings, tail_paddings=args.whisper_tail_paddings,
) )
elif args.moonshine_preprocessor:
assert_file_exists(args.moonshine_preprocessor)
assert_file_exists(args.moonshine_encoder)
assert_file_exists(args.moonshine_uncached_decoder)
assert_file_exists(args.moonshine_cached_decoder)
recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
preprocessor=args.moonshine_preprocessor,
encoder=args.moonshine_encoder,
uncached_decoder=args.moonshine_uncached_decoder,
cached_decoder=args.moonshine_cached_decoder,
tokens=args.tokens,
num_threads=args.num_threads,
decoding_method=args.decoding_method,
debug=args.debug,
)
else: else:
raise ValueError("Please specify at least one model") raise ValueError("Please specify at least one model")

View File

@@ -29,6 +29,9 @@ set(sources
offline-lm-config.cc offline-lm-config.cc
offline-lm.cc offline-lm.cc
offline-model-config.cc offline-model-config.cc
offline-moonshine-greedy-search-decoder.cc
offline-moonshine-model-config.cc
offline-moonshine-model.cc
offline-nemo-enc-dec-ctc-model-config.cc offline-nemo-enc-dec-ctc-model-config.cc
offline-nemo-enc-dec-ctc-model.cc offline-nemo-enc-dec-ctc-model.cc
offline-paraformer-greedy-search-decoder.cc offline-paraformer-greedy-search-decoder.cc

View File

@@ -19,6 +19,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
zipformer_ctc.Register(po); zipformer_ctc.Register(po);
wenet_ctc.Register(po); wenet_ctc.Register(po);
sense_voice.Register(po); sense_voice.Register(po);
moonshine.Register(po);
po->Register("telespeech-ctc", &telespeech_ctc, po->Register("telespeech-ctc", &telespeech_ctc,
"Path to model.onnx for telespeech ctc"); "Path to model.onnx for telespeech ctc");
@@ -99,6 +100,10 @@ bool OfflineModelConfig::Validate() const {
return sense_voice.Validate(); return sense_voice.Validate();
} }
if (!moonshine.preprocessor.empty()) {
return moonshine.Validate();
}
if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) { if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist", SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
telespeech_ctc.c_str()); telespeech_ctc.c_str());
@@ -124,6 +129,7 @@ std::string OfflineModelConfig::ToString() const {
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", "; os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
os << "sense_voice=" << sense_voice.ToString() << ", "; os << "sense_voice=" << sense_voice.ToString() << ", ";
os << "moonshine=" << moonshine.ToString() << ", ";
os << "telespeech_ctc=\"" << telespeech_ctc << "\", "; os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
os << "tokens=\"" << tokens << "\", "; os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", "; os << "num_threads=" << num_threads << ", ";

View File

@@ -6,6 +6,7 @@
#include <string> #include <string>
#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h" #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h" #include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h" #include "sherpa-onnx/csrc/offline-sense-voice-model-config.h"
@@ -26,6 +27,7 @@ struct OfflineModelConfig {
OfflineZipformerCtcModelConfig zipformer_ctc; OfflineZipformerCtcModelConfig zipformer_ctc;
OfflineWenetCtcModelConfig wenet_ctc; OfflineWenetCtcModelConfig wenet_ctc;
OfflineSenseVoiceModelConfig sense_voice; OfflineSenseVoiceModelConfig sense_voice;
OfflineMoonshineModelConfig moonshine;
std::string telespeech_ctc; std::string telespeech_ctc;
std::string tokens; std::string tokens;
@@ -56,6 +58,7 @@ struct OfflineModelConfig {
const OfflineZipformerCtcModelConfig &zipformer_ctc, const OfflineZipformerCtcModelConfig &zipformer_ctc,
const OfflineWenetCtcModelConfig &wenet_ctc, const OfflineWenetCtcModelConfig &wenet_ctc,
const OfflineSenseVoiceModelConfig &sense_voice, const OfflineSenseVoiceModelConfig &sense_voice,
const OfflineMoonshineModelConfig &moonshine,
const std::string &telespeech_ctc, const std::string &telespeech_ctc,
const std::string &tokens, int32_t num_threads, bool debug, const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type, const std::string &provider, const std::string &model_type,
@@ -69,6 +72,7 @@ struct OfflineModelConfig {
zipformer_ctc(zipformer_ctc), zipformer_ctc(zipformer_ctc),
wenet_ctc(wenet_ctc), wenet_ctc(wenet_ctc),
sense_voice(sense_voice), sense_voice(sense_voice),
moonshine(moonshine),
telespeech_ctc(telespeech_ctc), telespeech_ctc(telespeech_ctc),
tokens(tokens), tokens(tokens),
num_threads(num_threads), num_threads(num_threads),

View File

@@ -0,0 +1,34 @@
// sherpa-onnx/csrc/offline-moonshine-decoder.h
//
// Copyright (c)  2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {
// Result of decoding a single utterance with a Moonshine decoder.
struct OfflineMoonshineDecoderResult {
  /// The decoded token IDs
  std::vector<int32_t> tokens;
};
// Abstract interface for decoding the Moonshine encoder output into
// token IDs. Concrete subclasses implement a specific search strategy,
// e.g., greedy search.
class OfflineMoonshineDecoder {
 public:
  virtual ~OfflineMoonshineDecoder() = default;

  /** Run the search given the output from the moonshine encoder model.
   *
   * @param encoder_out A 3-D tensor of shape (batch_size, T, dim)
   * @return Return a vector of size `N` containing the decoded results.
   */
  virtual std::vector<OfflineMoonshineDecoderResult> Decode(
      Ort::Value encoder_out) = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_

View File

@@ -0,0 +1,87 @@
// sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.cc
//
// Copyright (c)  2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h"
#include <algorithm>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
// Greedy (arg-max) decoding: run the uncached decoder once with <sos> to
// obtain the initial states, then repeatedly feed the newest token to the
// cached decoder until <eos> is produced or max_len steps are taken.
std::vector<OfflineMoonshineDecoderResult>
OfflineMoonshineGreedySearchDecoder::Decode(Ort::Value encoder_out) {
  auto encoder_out_shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();
  // Only single-utterance decoding is implemented.
  if (encoder_out_shape[0] != 1) {
    SHERPA_ONNX_LOGE("Support only batch size == 1. Given: %d\n",
                     static_cast<int32_t>(encoder_out_shape[0]));
    return {};
  }

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  // Upper bound on the number of decoding steps:
  // encoder_out_shape[1] * 384 is the number of audio samples
  // (384 samples per encoder frame is from the moonshine paper),
  // 16000 is the sample rate, so the expression below is the audio
  // duration in seconds times 6, i.e., at most 6 tokens per second.
  int32_t max_len =
      static_cast<int32_t>(encoder_out_shape[1] * 384 / 16000.0 * 6);
  int32_t sos = 1;  // start-of-sequence token ID hard-coded for moonshine
  int32_t eos = 2;  // end-of-sequence token ID hard-coded for moonshine
  int32_t seq_len = 1;  // number of tokens fed to the decoder so far

  std::vector<int32_t> tokens;
  std::array<int64_t, 2> token_shape = {1, 1};  // (batch, num_new_tokens)
  int64_t seq_len_shape = 1;

  // First step: feed only <sos> to the uncached decoder; it also returns
  // the initial self-attention states for the cached decoder.
  Ort::Value token_tensor = Ort::Value::CreateTensor(
      memory_info, &sos, 1, token_shape.data(), token_shape.size());
  Ort::Value seq_len_tensor =
      Ort::Value::CreateTensor(memory_info, &seq_len, 1, &seq_len_shape, 1);

  Ort::Value logits{nullptr};
  std::vector<Ort::Value> states;

  std::tie(logits, states) = model_->ForwardUnCachedDecoder(
      std::move(token_tensor), std::move(seq_len_tensor), View(&encoder_out));

  // logits has shape (batch, 1, vocab_size)
  int32_t vocab_size = logits.GetTensorTypeAndShapeInfo().GetShape()[2];

  for (int32_t i = 0; i != max_len; ++i) {
    const float *p = logits.GetTensorData<float>();

    // Arg-max over the vocabulary dimension.
    int32_t max_token_id = static_cast<int32_t>(
        std::distance(p, std::max_element(p, p + vocab_size)));
    if (max_token_id == eos) {
      break;
    }

    tokens.push_back(max_token_id);
    seq_len += 1;

    // NOTE: the tensor borrows the int32 stored inside `tokens`; it is
    // consumed by ForwardCachedDecoder() below before `tokens` is
    // modified again, so the pointer stays valid for the Run() call.
    token_tensor = Ort::Value::CreateTensor(
        memory_info, &tokens.back(), 1, token_shape.data(), token_shape.size());
    seq_len_tensor =
        Ort::Value::CreateTensor(memory_info, &seq_len, 1, &seq_len_shape, 1);

    // Subsequent steps reuse the cached states, passing only the newest
    // token.
    std::tie(logits, states) = model_->ForwardCachedDecoder(
        std::move(token_tensor), std::move(seq_len_tensor), View(&encoder_out),
        std::move(states));
  }

  OfflineMoonshineDecoderResult ans;
  ans.tokens = std::move(tokens);

  return {ans};
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,29 @@
// sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_
#include <vector>
#include "sherpa-onnx/csrc/offline-moonshine-decoder.h"
#include "sherpa-onnx/csrc/offline-moonshine-model.h"
namespace sherpa_onnx {
// Greedy-search implementation of OfflineMoonshineDecoder: at every step
// it picks the token with the highest logit.
class OfflineMoonshineGreedySearchDecoder : public OfflineMoonshineDecoder {
 public:
  // @param model Borrowed pointer; must outlive this decoder.
  explicit OfflineMoonshineGreedySearchDecoder(OfflineMoonshineModel *model)
      : model_(model) {}

  std::vector<OfflineMoonshineDecoderResult> Decode(
      Ort::Value encoder_out) override;

 private:
  OfflineMoonshineModel *model_;  // not owned
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_

View File

@@ -0,0 +1,88 @@
// sherpa-onnx/csrc/offline-moonshine-model-config.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
// Registers the four Moonshine model-file command-line options with the
// option parser; each option writes directly into the matching member.
void OfflineMoonshineModelConfig::Register(ParseOptions *po) {
  ParseOptions &opts = *po;

  opts.Register("moonshine-preprocessor", &preprocessor,
                "Path to onnx preprocessor of moonshine, e.g., "
                "preprocess.onnx");

  opts.Register("moonshine-encoder", &encoder,
                "Path to onnx encoder of moonshine, e.g., encode.onnx");

  opts.Register("moonshine-uncached-decoder", &uncached_decoder,
                "Path to onnx uncached_decoder of moonshine, e.g., "
                "uncached_decode.onnx");

  opts.Register("moonshine-cached-decoder", &cached_decoder,
                "Path to onnx cached_decoder of moonshine, e.g., "
                "cached_decode.onnx");
}
// Checks that all four model files are specified and exist on disk.
// Logs the first problem found and returns false; returns true only when
// every file path is non-empty and present.
bool OfflineMoonshineModelConfig::Validate() const {
  if (preprocessor.empty()) {
    SHERPA_ONNX_LOGE("Please provide --moonshine-preprocessor");
    return false;
  }

  if (!FileExists(preprocessor)) {
    SHERPA_ONNX_LOGE("moonshine preprocessor file '%s' does not exist",
                     preprocessor.c_str());
    return false;
  }

  if (encoder.empty()) {
    SHERPA_ONNX_LOGE("Please provide --moonshine-encoder");
    return false;
  }

  if (!FileExists(encoder)) {
    SHERPA_ONNX_LOGE("moonshine encoder file '%s' does not exist",
                     encoder.c_str());
    return false;
  }

  if (uncached_decoder.empty()) {
    SHERPA_ONNX_LOGE("Please provide --moonshine-uncached-decoder");
    return false;
  }

  if (!FileExists(uncached_decoder)) {
    SHERPA_ONNX_LOGE("moonshine uncached decoder file '%s' does not exist",
                     uncached_decoder.c_str());
    return false;
  }

  if (cached_decoder.empty()) {
    SHERPA_ONNX_LOGE("Please provide --moonshine-cached-decoder");
    return false;
  }

  if (!FileExists(cached_decoder)) {
    SHERPA_ONNX_LOGE("moonshine cached decoder file '%s' does not exist",
                     cached_decoder.c_str());
    return false;
  }

  return true;
}
// Returns a human-readable one-line description of this config, used for
// logging/debugging.
std::string OfflineMoonshineModelConfig::ToString() const {
  std::ostringstream buf;

  buf << "OfflineMoonshineModelConfig("
      << "preprocessor=\"" << preprocessor << "\", "
      << "encoder=\"" << encoder << "\", "
      << "uncached_decoder=\"" << uncached_decoder << "\", "
      << "cached_decoder=\"" << cached_decoder << "\")";

  return buf.str();
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,37 @@
// sherpa-onnx/csrc/offline-moonshine-model-config.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
// Paths to the four ONNX files that make up a Moonshine model.
struct OfflineMoonshineModelConfig {
  // e.g., preprocess.onnx; converts raw audio to features
  std::string preprocessor;
  // e.g., encode.onnx
  std::string encoder;
  // e.g., uncached_decode.onnx; used for the first decoding step
  std::string uncached_decoder;
  // e.g., cached_decode.onnx; used for subsequent steps with cached states
  std::string cached_decoder;

  OfflineMoonshineModelConfig() = default;

  OfflineMoonshineModelConfig(const std::string &preprocessor,
                              const std::string &encoder,
                              const std::string &uncached_decoder,
                              const std::string &cached_decoder)
      : preprocessor(preprocessor),
        encoder(encoder),
        uncached_decoder(uncached_decoder),
        cached_decoder(cached_decoder) {}

  // Adds the corresponding command-line options to *po.
  void Register(ParseOptions *po);

  // Returns true if all four files are specified and exist.
  bool Validate() const;

  // Human-readable description for logging.
  std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_

View File

@@ -0,0 +1,282 @@
// sherpa-onnx/csrc/offline-moonshine-model.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-moonshine-model.h"
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
// Pimpl of OfflineMoonshineModel. Owns the four onnxruntime sessions
// (preprocessor, encoder, uncached decoder, cached decoder) and runs them.
class OfflineMoonshineModel::Impl {
 public:
  // Loads all four model files from the file system.
  explicit Impl(const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    // Each buffer is scoped so the raw model bytes are freed as soon as
    // the corresponding session has been created.
    {
      auto buf = ReadFile(config.moonshine.preprocessor);
      InitPreprocessor(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(config.moonshine.encoder);
      InitEncoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(config.moonshine.uncached_decoder);
      InitUnCachedDecoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(config.moonshine.cached_decoder);
      InitCachedDecoder(buf.data(), buf.size());
    }
  }

#if __ANDROID_API__ >= 9
  // Same as above, but reads the model files from the Android asset
  // manager instead of the file system.
  Impl(AAssetManager *mgr, const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(mgr, config.moonshine.preprocessor);
      InitPreprocessor(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(mgr, config.moonshine.encoder);
      InitEncoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(mgr, config.moonshine.uncached_decoder);
      InitUnCachedDecoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(mgr, config.moonshine.cached_decoder);
      InitCachedDecoder(buf.data(), buf.size());
    }
  }
#endif

  // Runs the preprocessor session: raw audio in, features out.
  Ort::Value ForwardPreprocessor(Ort::Value audio) {
    auto features = preprocessor_sess_->Run(
        {}, preprocessor_input_names_ptr_.data(), &audio, 1,
        preprocessor_output_names_ptr_.data(),
        preprocessor_output_names_ptr_.size());

    return std::move(features[0]);
  }

  // Runs the encoder session on the preprocessor output.
  Ort::Value ForwardEncoder(Ort::Value features, Ort::Value features_len) {
    std::array<Ort::Value, 2> encoder_inputs{std::move(features),
                                             std::move(features_len)};

    auto encoder_out = encoder_sess_->Run(
        {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
        encoder_inputs.size(), encoder_output_names_ptr_.data(),
        encoder_output_names_ptr_.size());

    return std::move(encoder_out[0]);
  }

  // Runs the uncached decoder (first decoding step).
  // Returns {logits, initial decoder states}.
  std::pair<Ort::Value, std::vector<Ort::Value>> ForwardUnCachedDecoder(
      Ort::Value tokens, Ort::Value seq_len, Ort::Value encoder_out) {
    // NOTE: the model's input order is (tokens, encoder_out, seq_len),
    // which differs from this method's parameter order.
    std::array<Ort::Value, 3> uncached_decoder_input = {
        std::move(tokens),
        std::move(encoder_out),
        std::move(seq_len),
    };

    auto uncached_decoder_out = uncached_decoder_sess_->Run(
        {}, uncached_decoder_input_names_ptr_.data(),
        uncached_decoder_input.data(), uncached_decoder_input.size(),
        uncached_decoder_output_names_ptr_.data(),
        uncached_decoder_output_names_ptr_.size());

    // Output 0 is the logits; outputs 1..n are the decoder states.
    std::vector<Ort::Value> states;
    states.reserve(uncached_decoder_out.size() - 1);

    int32_t i = -1;
    for (auto &s : uncached_decoder_out) {
      ++i;
      if (i == 0) {
        continue;
      }

      states.push_back(std::move(s));
    }

    return {std::move(uncached_decoder_out[0]), std::move(states)};
  }

  // Runs the cached decoder (subsequent decoding steps) with the states
  // from the previous step. Returns {logits, next states}.
  std::pair<Ort::Value, std::vector<Ort::Value>> ForwardCachedDecoder(
      Ort::Value tokens, Ort::Value seq_len, Ort::Value encoder_out,
      std::vector<Ort::Value> states) {
    // Input order is (tokens, encoder_out, seq_len, state...), matching
    // the uncached decoder plus the cached states.
    std::vector<Ort::Value> cached_decoder_input;
    cached_decoder_input.reserve(3 + states.size());
    cached_decoder_input.push_back(std::move(tokens));
    cached_decoder_input.push_back(std::move(encoder_out));
    cached_decoder_input.push_back(std::move(seq_len));
    for (auto &s : states) {
      cached_decoder_input.push_back(std::move(s));
    }

    auto cached_decoder_out = cached_decoder_sess_->Run(
        {}, cached_decoder_input_names_ptr_.data(), cached_decoder_input.data(),
        cached_decoder_input.size(), cached_decoder_output_names_ptr_.data(),
        cached_decoder_output_names_ptr_.size());

    // Output 0 is the logits; outputs 1..n are the updated states.
    std::vector<Ort::Value> next_states;
    next_states.reserve(cached_decoder_out.size() - 1);

    int32_t i = -1;
    for (auto &s : cached_decoder_out) {
      ++i;
      if (i == 0) {
        continue;
      }

      next_states.push_back(std::move(s));
    }

    return {std::move(cached_decoder_out[0]), std::move(next_states)};
  }

  // Returns the default CPU allocator used for tensor allocations.
  OrtAllocator *Allocator() const { return allocator_; }

 private:
  // Each Init*() creates the session from the in-memory model bytes and
  // caches its input/output names (both as std::string and as the raw
  // const char* array that Ort::Session::Run() expects).
  void InitPreprocessor(void *model_data, size_t model_data_length) {
    preprocessor_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(preprocessor_sess_.get(), &preprocessor_input_names_,
                  &preprocessor_input_names_ptr_);

    GetOutputNames(preprocessor_sess_.get(), &preprocessor_output_names_,
                   &preprocessor_output_names_ptr_);
  }

  void InitEncoder(void *model_data, size_t model_data_length) {
    encoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(encoder_sess_.get(), &encoder_input_names_,
                  &encoder_input_names_ptr_);

    GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
                   &encoder_output_names_ptr_);
  }

  void InitUnCachedDecoder(void *model_data, size_t model_data_length) {
    uncached_decoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(uncached_decoder_sess_.get(), &uncached_decoder_input_names_,
                  &uncached_decoder_input_names_ptr_);

    GetOutputNames(uncached_decoder_sess_.get(),
                   &uncached_decoder_output_names_,
                   &uncached_decoder_output_names_ptr_);
  }

  void InitCachedDecoder(void *model_data, size_t model_data_length) {
    cached_decoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(cached_decoder_sess_.get(), &cached_decoder_input_names_,
                  &cached_decoder_input_names_ptr_);

    GetOutputNames(cached_decoder_sess_.get(), &cached_decoder_output_names_,
                   &cached_decoder_output_names_ptr_);
  }

 private:
  OfflineModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> preprocessor_sess_;
  std::unique_ptr<Ort::Session> encoder_sess_;
  std::unique_ptr<Ort::Session> uncached_decoder_sess_;
  std::unique_ptr<Ort::Session> cached_decoder_sess_;

  std::vector<std::string> preprocessor_input_names_;
  std::vector<const char *> preprocessor_input_names_ptr_;

  std::vector<std::string> preprocessor_output_names_;
  std::vector<const char *> preprocessor_output_names_ptr_;

  std::vector<std::string> encoder_input_names_;
  std::vector<const char *> encoder_input_names_ptr_;

  std::vector<std::string> encoder_output_names_;
  std::vector<const char *> encoder_output_names_ptr_;

  std::vector<std::string> uncached_decoder_input_names_;
  std::vector<const char *> uncached_decoder_input_names_ptr_;

  std::vector<std::string> uncached_decoder_output_names_;
  std::vector<const char *> uncached_decoder_output_names_ptr_;

  std::vector<std::string> cached_decoder_input_names_;
  std::vector<const char *> cached_decoder_input_names_ptr_;

  std::vector<std::string> cached_decoder_output_names_;
  std::vector<const char *> cached_decoder_output_names_ptr_;
};
// Public API of OfflineMoonshineModel: every method simply forwards to
// the pimpl defined above.
OfflineMoonshineModel::OfflineMoonshineModel(const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

#if __ANDROID_API__ >= 9
// Android variant: model files are loaded from the APK asset manager.
OfflineMoonshineModel::OfflineMoonshineModel(AAssetManager *mgr,
                                             const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(mgr, config)) {}
#endif

// Out-of-line so the unique_ptr<Impl> destructor sees the complete type.
OfflineMoonshineModel::~OfflineMoonshineModel() = default;

Ort::Value OfflineMoonshineModel::ForwardPreprocessor(Ort::Value audio) const {
  return impl_->ForwardPreprocessor(std::move(audio));
}

Ort::Value OfflineMoonshineModel::ForwardEncoder(
    Ort::Value features, Ort::Value features_len) const {
  return impl_->ForwardEncoder(std::move(features), std::move(features_len));
}

std::pair<Ort::Value, std::vector<Ort::Value>>
OfflineMoonshineModel::ForwardUnCachedDecoder(Ort::Value token,
                                              Ort::Value seq_len,
                                              Ort::Value encoder_out) const {
  return impl_->ForwardUnCachedDecoder(std::move(token), std::move(seq_len),
                                       std::move(encoder_out));
}

std::pair<Ort::Value, std::vector<Ort::Value>>
OfflineMoonshineModel::ForwardCachedDecoder(
    Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out,
    std::vector<Ort::Value> states) const {
  return impl_->ForwardCachedDecoder(std::move(token), std::move(seq_len),
                                     std::move(encoder_out), std::move(states));
}

OrtAllocator *OfflineMoonshineModel::Allocator() const {
  return impl_->Allocator();
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,93 @@
// sherpa-onnx/csrc/offline-moonshine-model.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"
namespace sherpa_onnx {
// please see
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/moonshine/test.py
// Wrapper around the four ONNX sessions of a Moonshine ASR model
// (preprocessor -> encoder -> uncached/cached decoder).
class OfflineMoonshineModel {
 public:
  explicit OfflineMoonshineModel(const OfflineModelConfig &config);

#if __ANDROID_API__ >= 9
  // Loads the model files from the Android asset manager.
  OfflineMoonshineModel(AAssetManager *mgr, const OfflineModelConfig &config);
#endif

  ~OfflineMoonshineModel();

  /** Run the preprocessor model.
   *
   * @param audio A float32 tensor of shape (batch_size, num_samples)
   *
   * @return Return a float32 tensor of shape (batch_size, T, dim) that
   *         can be used as the input of ForwardEncoder()
   */
  Ort::Value ForwardPreprocessor(Ort::Value audio) const;

  /** Run the encoder model.
   *
   * @param features A float32 tensor of shape (batch_size, T, dim)
   * @param features_len A int32 tensor of shape (batch_size,)
   * @returns A float32 tensor of shape (batch_size, T, dim).
   */
  Ort::Value ForwardEncoder(Ort::Value features, Ort::Value features_len) const;

  /** Run the uncached decoder (first decoding step only).
   *
   * @param token A int32 tensor of shape (batch_size, num_tokens)
   * @param seq_len A int32 tensor of shape (batch_size,) containing number
   *                of predicted tokens so far
   * @param encoder_out A float32 tensor of shape (batch_size, T, dim)
   *
   * @returns Return a pair:
   *
   *          - logits, a float32 tensor of shape (batch_size, 1, dim),
   *            where dim is presumably the vocabulary size -- see the
   *            greedy-search decoder, which arg-maxes over this axis
   *          - states, a list of states
   */
  std::pair<Ort::Value, std::vector<Ort::Value>> ForwardUnCachedDecoder(
      Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out) const;

  /** Run the cached decoder (all steps after the first one).
   *
   * @param token A int32 tensor of shape (batch_size, num_tokens)
   * @param seq_len A int32 tensor of shape (batch_size,) containing number
   *                of predicted tokens so far
   * @param encoder_out A float32 tensor of shape (batch_size, T, dim)
   * @param states A list of previous states
   *
   * @returns Return a pair:
   *          - logits, a float32 tensor of shape (batch_size, 1, dim)
   *          - states, a list of new states
   */
  std::pair<Ort::Value, std::vector<Ort::Value>> ForwardCachedDecoder(
      Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out,
      std::vector<Ort::Value> states) const;

  /** Return an allocator for allocating memory
   */
  OrtAllocator *Allocator() const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_

View File

@@ -20,6 +20,7 @@
#include "onnxruntime_cxx_api.h" // NOLINT #include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
@@ -51,6 +52,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerWhisperImpl>(config); return std::make_unique<OfflineRecognizerWhisperImpl>(config);
} }
if (!config.model_config.moonshine.preprocessor.empty()) {
return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
}
// TODO(fangjun): Refactor it. We only need to use model type for the // TODO(fangjun): Refactor it. We only need to use model type for the
// following models: // following models:
// 1. transducer and nemo_transducer // 1. transducer and nemo_transducer
@@ -67,7 +72,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
model_type == "telespeech_ctc") { model_type == "telespeech_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(config); return std::make_unique<OfflineRecognizerCtcImpl>(config);
} else if (model_type == "whisper") { } else if (model_type == "whisper") {
// unreachable
return std::make_unique<OfflineRecognizerWhisperImpl>(config); return std::make_unique<OfflineRecognizerWhisperImpl>(config);
} else if (model_type == "moonshine") {
// unreachable
return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
} else { } else {
SHERPA_ONNX_LOGE( SHERPA_ONNX_LOGE(
"Invalid model_type: %s. Trying to load the model to get its type", "Invalid model_type: %s. Trying to load the model to get its type",
@@ -225,6 +234,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config); return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
} }
if (!config.model_config.moonshine.preprocessor.empty()) {
return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
}
// TODO(fangjun): Refactor it. We only need to use model type for the // TODO(fangjun): Refactor it. We only need to use model type for the
// following models: // following models:
// 1. transducer and nemo_transducer // 1. transducer and nemo_transducer
@@ -242,6 +255,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config); return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
} else if (model_type == "whisper") { } else if (model_type == "whisper") {
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config); return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
} else if (model_type == "moonshine") {
return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
} else { } else {
SHERPA_ONNX_LOGE( SHERPA_ONNX_LOGE(
"Invalid model_type: %s. Trying to load the model to get its type", "Invalid model_type: %s. Trying to load the model to get its type",

View File

@@ -0,0 +1,150 @@
// sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_
#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-moonshine-decoder.h"
#include "sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/offline-moonshine-model.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/transpose.h"
namespace sherpa_onnx {
// Maps the decoder's token IDs to their symbols and assembles the final
// recognition result. IDs that are absent from the symbol table are
// silently dropped.
static OfflineRecognitionResult Convert(
    const OfflineMoonshineDecoderResult &src, const SymbolTable &sym_table) {
  OfflineRecognitionResult result;
  result.tokens.reserve(src.tokens.size());

  std::string text;
  for (auto token_id : src.tokens) {
    if (sym_table.Contains(token_id)) {
      const auto &sym = sym_table[token_id];
      text.append(sym);
      result.tokens.push_back(sym);
    }
  }

  result.text = text;
  return result;
}
// Offline (non-streaming) recognizer for Moonshine models. Moonshine uses a
// neural-network preprocessor instead of a classic fbank front end, so its
// streams collect raw audio samples and feature extraction happens inside
// DecodeStream().
class OfflineRecognizerMoonshineImpl : public OfflineRecognizerImpl {
 public:
  // Loads the Moonshine ONNX models described by config.model_config and the
  // token table from config.model_config.tokens.
  explicit OfflineRecognizerMoonshineImpl(const OfflineRecognizerConfig &config)
      : OfflineRecognizerImpl(config),
        config_(config),
        symbol_table_(config_.model_config.tokens),
        model_(std::make_unique<OfflineMoonshineModel>(config.model_config)) {
    Init();
  }

#if __ANDROID_API__ >= 9
  // Android variant: reads the model files and the token table through the
  // asset manager instead of the filesystem.
  OfflineRecognizerMoonshineImpl(AAssetManager *mgr,
                                 const OfflineRecognizerConfig &config)
      : OfflineRecognizerImpl(mgr, config),
        config_(config),
        symbol_table_(mgr, config_.model_config.tokens),
        model_(
            std::make_unique<OfflineMoonshineModel>(mgr, config.model_config)) {
    Init();
  }
#endif

  // Creates the decoder. Only "greedy_search" is implemented for Moonshine;
  // any other decoding method logs an error and aborts the process.
  void Init() {
    if (config_.decoding_method == "greedy_search") {
      decoder_ =
          std::make_unique<OfflineMoonshineGreedySearchDecoder>(model_.get());
    } else {
      SHERPA_ONNX_LOGE(
          "Only greedy_search is supported at present for moonshine. Given %s",
          config_.decoding_method.c_str());
      exit(-1);
    }
  }

  // Returns a stream tagged for Moonshine; such a stream stores raw audio
  // samples rather than fbank/MFCC features.
  std::unique_ptr<OfflineStream> CreateStream() const override {
    MoonshineTag tag;
    return std::make_unique<OfflineStream>(tag);
  }

  // Decodes the streams one by one.
  void DecodeStreams(OfflineStream **ss, int32_t n) const override {
    // batch decoding is not implemented yet
    for (int32_t i = 0; i != n; ++i) {
      DecodeStream(ss[i]);
    }
  }

  OfflineRecognizerConfig GetConfig() const override { return config_; }

 private:
  // Decodes a single stream:
  //   audio -> preprocessor -> encoder -> greedy search -> text.
  // On an onnxruntime failure the error is logged and the stream's result is
  // left unset (empty).
  void DecodeStream(OfflineStream *s) const {
    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    // For a Moonshine-tagged stream, GetFrames() returns the raw audio
    // samples that were accepted so far.
    std::vector<float> audio = s->GetFrames();

    try {
      // Shape (1, num_samples); batch size is fixed at 1 here.
      std::array<int64_t, 2> shape{1, static_cast<int64_t>(audio.size())};

      Ort::Value audio_tensor = Ort::Value::CreateTensor(
          memory_info, audio.data(), audio.size(), shape.data(), shape.size());

      Ort::Value features =
          model_->ForwardPreprocessor(std::move(audio_tensor));

      // Number of feature frames produced by the preprocessor (dim 1 of its
      // output shape). Narrowed from int64_t so the length tensor below is
      // built with int32 data.
      int32_t features_len = features.GetTensorTypeAndShapeInfo().GetShape()[1];

      int64_t features_shape = 1;

      Ort::Value features_len_tensor = Ort::Value::CreateTensor(
          memory_info, &features_len, 1, &features_shape, 1);

      Ort::Value encoder_out = model_->ForwardEncoder(
          std::move(features), std::move(features_len_tensor));

      auto results = decoder_->Decode(std::move(encoder_out));

      // Single stream => single result.
      auto r = Convert(results[0], symbol_table_);
      r.text = ApplyInverseTextNormalization(std::move(r.text));
      s->SetResult(r);
    } catch (const Ort::Exception &ex) {
      SHERPA_ONNX_LOGE(
          "\n\nCaught exception:\n\n%s\n\nReturn an empty result. Number of "
          "audio samples: %d",
          ex.what(), static_cast<int32_t>(audio.size()));
      return;
    }
  }

 private:
  OfflineRecognizerConfig config_;
  SymbolTable symbol_table_;
  std::unique_ptr<OfflineMoonshineModel> model_;
  std::unique_ptr<OfflineMoonshineDecoder> decoder_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_

View File

@@ -133,6 +133,10 @@ class OfflineStream::Impl {
fbank_ = std::make_unique<knf::OnlineFbank>(opts_); fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
} }
explicit Impl(MoonshineTag /*tag*/) : is_moonshine_(true) {
config_.sampling_rate = 16000;
}
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
if (config_.normalize_samples) { if (config_.normalize_samples) {
AcceptWaveformImpl(sampling_rate, waveform, n); AcceptWaveformImpl(sampling_rate, waveform, n);
@@ -164,7 +168,9 @@ class OfflineStream::Impl {
std::vector<float> samples; std::vector<float> samples;
resampler->Resample(waveform, n, true, &samples); resampler->Resample(waveform, n, true, &samples);
if (fbank_) { if (is_moonshine_) {
samples_.insert(samples_.end(), samples.begin(), samples.end());
} else if (fbank_) {
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(), fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size()); samples.size());
fbank_->InputFinished(); fbank_->InputFinished();
@@ -181,7 +187,9 @@ class OfflineStream::Impl {
return; return;
} // if (sampling_rate != config_.sampling_rate) } // if (sampling_rate != config_.sampling_rate)
if (fbank_) { if (is_moonshine_) {
samples_.insert(samples_.end(), waveform, waveform + n);
} else if (fbank_) {
fbank_->AcceptWaveform(sampling_rate, waveform, n); fbank_->AcceptWaveform(sampling_rate, waveform, n);
fbank_->InputFinished(); fbank_->InputFinished();
} else if (mfcc_) { } else if (mfcc_) {
@@ -194,10 +202,18 @@ class OfflineStream::Impl {
} }
int32_t FeatureDim() const { int32_t FeatureDim() const {
if (is_moonshine_) {
return samples_.size();
}
return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins; return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins;
} }
std::vector<float> GetFrames() const { std::vector<float> GetFrames() const {
if (is_moonshine_) {
return samples_;
}
int32_t n = fbank_ ? fbank_->NumFramesReady() int32_t n = fbank_ ? fbank_->NumFramesReady()
: mfcc_ ? mfcc_->NumFramesReady() : mfcc_ ? mfcc_->NumFramesReady()
: whisper_fbank_->NumFramesReady(); : whisper_fbank_->NumFramesReady();
@@ -300,6 +316,10 @@ class OfflineStream::Impl {
OfflineRecognitionResult r_; OfflineRecognitionResult r_;
ContextGraphPtr context_graph_; ContextGraphPtr context_graph_;
bool is_ced_ = false; bool is_ced_ = false;
bool is_moonshine_ = false;
// used only when is_moonshine_ == true
std::vector<float> samples_;
}; };
OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/, OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/,
@@ -311,6 +331,9 @@ OfflineStream::OfflineStream(WhisperTag tag)
OfflineStream::OfflineStream(CEDTag tag) : impl_(std::make_unique<Impl>(tag)) {} OfflineStream::OfflineStream(CEDTag tag) : impl_(std::make_unique<Impl>(tag)) {}
OfflineStream::OfflineStream(MoonshineTag tag)
: impl_(std::make_unique<Impl>(tag)) {}
OfflineStream::~OfflineStream() = default; OfflineStream::~OfflineStream() = default;
void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform, void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,

View File

@@ -34,7 +34,7 @@ struct OfflineRecognitionResult {
// event target of the audio. // event target of the audio.
std::string event; std::string event;
/// timestamps.size() == tokens.size() /// timestamps.size() == tokens.size()
/// timestamps[i] records the time in seconds when tokens[i] is decoded. /// timestamps[i] records the time in seconds when tokens[i] is decoded.
std::vector<float> timestamps; std::vector<float> timestamps;
@@ -49,6 +49,10 @@ struct WhisperTag {
struct CEDTag {}; struct CEDTag {};
// It uses a neural network model, a preprocessor, to convert
// audio samples to features
struct MoonshineTag {};
class OfflineStream { class OfflineStream {
public: public:
explicit OfflineStream(const FeatureExtractorConfig &config = {}, explicit OfflineStream(const FeatureExtractorConfig &config = {},
@@ -56,6 +60,7 @@ class OfflineStream {
explicit OfflineStream(WhisperTag tag); explicit OfflineStream(WhisperTag tag);
explicit OfflineStream(CEDTag tag); explicit OfflineStream(CEDTag tag);
explicit OfflineStream(MoonshineTag tag);
~OfflineStream(); ~OfflineStream();
/** /**
@@ -72,7 +77,10 @@ class OfflineStream {
void AcceptWaveform(int32_t sampling_rate, const float *waveform, void AcceptWaveform(int32_t sampling_rate, const float *waveform,
int32_t n) const; int32_t n) const;
/// Return feature dim of this extractor /// Return feature dim of this extractor.
///
/// Note: if it is Moonshine, then it returns the number of audio samples
/// currently received.
int32_t FeatureDim() const; int32_t FeatureDim() const;
// Get all the feature frames of this stream in a 1-D array, which is // Get all the feature frames of this stream in a 1-D array, which is

View File

@@ -23,7 +23,6 @@ class OfflineWhisperModel::Impl {
explicit Impl(const OfflineModelConfig &config) explicit Impl(const OfflineModelConfig &config)
: config_(config), : config_(config),
env_(ORT_LOGGING_LEVEL_ERROR), env_(ORT_LOGGING_LEVEL_ERROR),
debug_(config.debug),
sess_opts_(GetSessionOptions(config)), sess_opts_(GetSessionOptions(config)),
allocator_{} { allocator_{} {
{ {
@@ -40,7 +39,6 @@ class OfflineWhisperModel::Impl {
explicit Impl(const SpokenLanguageIdentificationConfig &config) explicit Impl(const SpokenLanguageIdentificationConfig &config)
: lid_config_(config), : lid_config_(config),
env_(ORT_LOGGING_LEVEL_ERROR), env_(ORT_LOGGING_LEVEL_ERROR),
debug_(config_.debug),
sess_opts_(GetSessionOptions(config)), sess_opts_(GetSessionOptions(config)),
allocator_{} { allocator_{} {
{ {
@@ -60,7 +58,6 @@ class OfflineWhisperModel::Impl {
env_(ORT_LOGGING_LEVEL_ERROR), env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)), sess_opts_(GetSessionOptions(config)),
allocator_{} { allocator_{} {
debug_ = config_.debug;
{ {
auto buf = ReadFile(mgr, config.whisper.encoder); auto buf = ReadFile(mgr, config.whisper.encoder);
InitEncoder(buf.data(), buf.size()); InitEncoder(buf.data(), buf.size());
@@ -77,7 +74,6 @@ class OfflineWhisperModel::Impl {
env_(ORT_LOGGING_LEVEL_ERROR), env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)), sess_opts_(GetSessionOptions(config)),
allocator_{} { allocator_{} {
debug_ = config_.debug;
{ {
auto buf = ReadFile(mgr, config.whisper.encoder); auto buf = ReadFile(mgr, config.whisper.encoder);
InitEncoder(buf.data(), buf.size()); InitEncoder(buf.data(), buf.size());
@@ -164,7 +160,7 @@ class OfflineWhisperModel::Impl {
} }
} }
if (debug_) { if (config_.debug) {
SHERPA_ONNX_LOGE("Detected language: %s", SHERPA_ONNX_LOGE("Detected language: %s",
GetID2Lang().at(lang_id).c_str()); GetID2Lang().at(lang_id).c_str());
} }
@@ -237,7 +233,7 @@ class OfflineWhisperModel::Impl {
// get meta data // get meta data
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
if (debug_) { if (config_.debug) {
std::ostringstream os; std::ostringstream os;
os << "---encoder---\n"; os << "---encoder---\n";
PrintModelMetadata(os, meta_data); PrintModelMetadata(os, meta_data);
@@ -294,7 +290,6 @@ class OfflineWhisperModel::Impl {
private: private:
OfflineModelConfig config_; OfflineModelConfig config_;
SpokenLanguageIdentificationConfig lid_config_; SpokenLanguageIdentificationConfig lid_config_;
bool debug_ = false;
Ort::Env env_; Ort::Env env_;
Ort::SessionOptions sess_opts_; Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_; Ort::AllocatorWithDefaultOptions allocator_;

View File

@@ -43,7 +43,20 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/in
--decoding-method=greedy_search \ --decoding-method=greedy_search \
/path/to/foo.wav [bar.wav foobar.wav ...] /path/to/foo.wav [bar.wav foobar.wav ...]
(3) Whisper models (3) Moonshine models
See https://k2-fsa.github.io/sherpa/onnx/moonshine/index.html
./bin/sherpa-onnx-offline \
--moonshine-preprocessor=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/preprocess.onnx \
--moonshine-encoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/encode.int8.onnx \
--moonshine-uncached-decoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/uncached_decode.int8.onnx \
--moonshine-cached-decoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/cached_decode.int8.onnx \
--tokens=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/tokens.txt \
--num-threads=1 \
/path/to/foo.wav [bar.wav foobar.wav ...]
(4) Whisper models
See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html
@@ -54,7 +67,7 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html
--num-threads=1 \ --num-threads=1 \
/path/to/foo.wav [bar.wav foobar.wav ...] /path/to/foo.wav [bar.wav foobar.wav ...]
(4) NeMo CTC models (5) NeMo CTC models
See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html
@@ -68,7 +81,7 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.htm
./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \ ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \
./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav
(5) TDNN CTC model for the yesno recipe from icefall (6) TDNN CTC model for the yesno recipe from icefall
See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html
// //

View File

@@ -109,6 +109,8 @@ const std::string SymbolTable::operator[](int32_t id) const {
// for byte-level BPE // for byte-level BPE
// id 0 is blank, id 1 is sos/eos, id 2 is unk // id 0 is blank, id 1 is sos/eos, id 2 is unk
//
// Note: For moonshine models, 0 is <unk>, 1 is <s>, 2 is </s>
if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' && if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' &&
sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') { sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') {
std::ostringstream os; std::ostringstream os;

View File

@@ -11,6 +11,7 @@ set(srcs
offline-ctc-fst-decoder-config.cc offline-ctc-fst-decoder-config.cc
offline-lm-config.cc offline-lm-config.cc
offline-model-config.cc offline-model-config.cc
offline-moonshine-model-config.cc
offline-nemo-enc-dec-ctc-model-config.cc offline-nemo-enc-dec-ctc-model-config.cc
offline-paraformer-model-config.cc offline-paraformer-model-config.cc
offline-punctuation.cc offline-punctuation.cc

View File

@@ -8,6 +8,7 @@
#include <vector> #include <vector>
#include "sherpa-onnx/csrc/offline-model-config.h" #include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h"
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
#include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h" #include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h"
@@ -28,6 +29,7 @@ void PybindOfflineModelConfig(py::module *m) {
PybindOfflineZipformerCtcModelConfig(m); PybindOfflineZipformerCtcModelConfig(m);
PybindOfflineWenetCtcModelConfig(m); PybindOfflineWenetCtcModelConfig(m);
PybindOfflineSenseVoiceModelConfig(m); PybindOfflineSenseVoiceModelConfig(m);
PybindOfflineMoonshineModelConfig(m);
using PyClass = OfflineModelConfig; using PyClass = OfflineModelConfig;
py::class_<PyClass>(*m, "OfflineModelConfig") py::class_<PyClass>(*m, "OfflineModelConfig")
@@ -39,7 +41,8 @@ void PybindOfflineModelConfig(py::module *m) {
const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &, const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &,
const OfflineZipformerCtcModelConfig &, const OfflineZipformerCtcModelConfig &,
const OfflineWenetCtcModelConfig &, const OfflineWenetCtcModelConfig &,
const OfflineSenseVoiceModelConfig &, const std::string &, const OfflineSenseVoiceModelConfig &,
const OfflineMoonshineModelConfig &, const std::string &,
const std::string &, int32_t, bool, const std::string &, const std::string &, int32_t, bool, const std::string &,
const std::string &, const std::string &, const std::string &>(), const std::string &, const std::string &, const std::string &>(),
py::arg("transducer") = OfflineTransducerModelConfig(), py::arg("transducer") = OfflineTransducerModelConfig(),
@@ -50,6 +53,7 @@ void PybindOfflineModelConfig(py::module *m) {
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(), py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(), py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
py::arg("sense_voice") = OfflineSenseVoiceModelConfig(), py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
py::arg("moonshine") = OfflineMoonshineModelConfig(),
py::arg("telespeech_ctc") = "", py::arg("tokens"), py::arg("telespeech_ctc") = "", py::arg("tokens"),
py::arg("num_threads"), py::arg("debug") = false, py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "", py::arg("provider") = "cpu", py::arg("model_type") = "",
@@ -62,6 +66,7 @@ void PybindOfflineModelConfig(py::module *m) {
.def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc) .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc) .def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
.def_readwrite("sense_voice", &PyClass::sense_voice) .def_readwrite("sense_voice", &PyClass::sense_voice)
.def_readwrite("moonshine", &PyClass::moonshine)
.def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc) .def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc)
.def_readwrite("tokens", &PyClass::tokens) .def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads) .def_readwrite("num_threads", &PyClass::num_threads)

View File

@@ -0,0 +1,28 @@
// sherpa-onnx/python/csrc/offline-moonshine-model-config.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
#include <string>
#include <vector>
#include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h"
namespace sherpa_onnx {
// Exposes OfflineMoonshineModelConfig to Python: a constructor taking the
// four model paths, read/write access to each field, and __str__.
void PybindOfflineMoonshineModelConfig(py::module *m) {
  using PyClass = OfflineMoonshineModelConfig;

  py::class_<PyClass> cls(*m, "OfflineMoonshineModelConfig");

  cls.def(py::init<const std::string &, const std::string &,
                   const std::string &, const std::string &>(),
          py::arg("preprocessor"), py::arg("encoder"),
          py::arg("uncached_decoder"), py::arg("cached_decoder"));

  cls.def_readwrite("preprocessor", &PyClass::preprocessor);
  cls.def_readwrite("encoder", &PyClass::encoder);
  cls.def_readwrite("uncached_decoder", &PyClass::uncached_decoder);
  cls.def_readwrite("cached_decoder", &PyClass::cached_decoder);
  cls.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/offline-moonshine-model-config.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {

// Registers the OfflineMoonshineModelConfig binding on the given pybind11
// module. Implemented in offline-moonshine-model-config.cc.
void PybindOfflineMoonshineModelConfig(py::module *m);

}  // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_

View File

@@ -8,13 +8,14 @@ from _sherpa_onnx import (
OfflineCtcFstDecoderConfig, OfflineCtcFstDecoderConfig,
OfflineLMConfig, OfflineLMConfig,
OfflineModelConfig, OfflineModelConfig,
OfflineMoonshineModelConfig,
OfflineNemoEncDecCtcModelConfig, OfflineNemoEncDecCtcModelConfig,
OfflineParaformerModelConfig, OfflineParaformerModelConfig,
OfflineSenseVoiceModelConfig,
) )
from _sherpa_onnx import OfflineRecognizer as _Recognizer from _sherpa_onnx import OfflineRecognizer as _Recognizer
from _sherpa_onnx import ( from _sherpa_onnx import (
OfflineRecognizerConfig, OfflineRecognizerConfig,
OfflineSenseVoiceModelConfig,
OfflineStream, OfflineStream,
OfflineTdnnModelConfig, OfflineTdnnModelConfig,
OfflineTransducerModelConfig, OfflineTransducerModelConfig,
@@ -503,12 +504,12 @@ class OfflineRecognizer(object):
e.g., tiny, tiny.en, base, base.en, etc. e.g., tiny, tiny.en, base, base.en, etc.
Args: Args:
encoder_model: encoder:
Path to the encoder model, e.g., tiny-encoder.onnx,
tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
decoder_model:
Path to the encoder model, e.g., tiny-encoder.onnx, Path to the encoder model, e.g., tiny-encoder.onnx,
tiny-encoder.int8.onnx, tiny-encoder.ort, etc. tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
decoder:
Path to the decoder model, e.g., tiny-decoder.onnx,
tiny-decoder.int8.onnx, tiny-decoder.ort, etc.
tokens: tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns:: columns::
@@ -570,6 +571,87 @@ class OfflineRecognizer(object):
self.config = recognizer_config self.config = recognizer_config
return self return self
@classmethod
def from_moonshine(
    cls,
    preprocessor: str,
    encoder: str,
    uncached_decoder: str,
    cached_decoder: str,
    tokens: str,
    num_threads: int = 1,
    decoding_method: str = "greedy_search",
    debug: bool = False,
    provider: str = "cpu",
    rule_fsts: str = "",
    rule_fars: str = "",
):
    """Create an offline recognizer from a Moonshine model.

    Please refer to
    `<https://k2-fsa.github.io/sherpa/onnx/moonshine/index.html>`_
    to download pre-trained models for different kinds of moonshine models,
    e.g., tiny, base, etc.

    Args:
      preprocessor:
        Path to the preprocessor model, e.g., preprocess.onnx.
      encoder:
        Path to the encoder model, e.g., encode.int8.onnx.
      uncached_decoder:
        Path to the uncached decoder model, e.g., uncached_decode.int8.onnx.
      cached_decoder:
        Path to the cached decoder model, e.g., cached_decode.int8.onnx.
      tokens:
        Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
        columns::

            symbol integer_id

      num_threads:
        Number of threads for neural network computation.
      decoding_method:
        Valid values: greedy_search.
      debug:
        True to show debug messages.
      provider:
        onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
      rule_fsts:
        If not empty, it specifies fsts for inverse text normalization.
        If there are multiple fsts, they are separated by a comma.
      rule_fars:
        If not empty, it specifies fst archives for inverse text
        normalization. If there are multiple archives, they are separated
        by a comma.
    """
    self = cls.__new__(cls)

    moonshine_config = OfflineMoonshineModelConfig(
        preprocessor=preprocessor,
        encoder=encoder,
        uncached_decoder=uncached_decoder,
        cached_decoder=cached_decoder,
    )

    self.config = OfflineRecognizerConfig(
        model_config=OfflineModelConfig(
            moonshine=moonshine_config,
            tokens=tokens,
            num_threads=num_threads,
            debug=debug,
            provider=provider,
        ),
        # Moonshine consumes raw audio samples, so this feature extractor
        # config is unused; the values are placeholders.
        feat_config=FeatureExtractorConfig(
            sampling_rate=16000,
            feature_dim=80,
        ),
        decoding_method=decoding_method,
        rule_fsts=rule_fsts,
        rule_fars=rule_fars,
    )

    self.recognizer = _Recognizer(self.config)
    return self
@classmethod @classmethod
def from_tdnn_ctc( def from_tdnn_ctc(
cls, cls,