Add C++ runtime and Python APIs for Moonshine models (#1473)
This commit is contained in:
50
.github/scripts/test-offline-moonshine.sh
vendored
Executable file
50
.github/scripts/test-offline-moonshine.sh
vendored
Executable file
@@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env bash
# Download Moonshine int8 ASR models and decode their bundled test waves
# with the offline decoder binary named by $EXE.
#
# Required environment:
#   EXE - name of the sherpa-onnx offline decoder executable; it must be
#         discoverable via $PATH (the CI workflow exports both).

set -e

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

export GIT_CLONE_PROTECTION_ACTIVE=false

echo "EXE is $EXE"
echo "PATH: $PATH"

# Fail fast if the decoder binary is not on PATH.
# (command -v is the portable replacement for `which`.)
command -v "$EXE"

names=(
  tiny
  base
)

for name in "${names[@]}"; do
  log "------------------------------------------------------------"
  log "Run $name"
  log "------------------------------------------------------------"

  # NOTE: the original script first assigned a *whisper* tarball URL here
  # and then immediately overwrote it — that dead assignment is removed.
  repo_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-$name-en-int8.tar.bz2
  curl -SL -O "$repo_url"
  tar xvf "sherpa-onnx-moonshine-$name-en-int8.tar.bz2"
  rm "sherpa-onnx-moonshine-$name-en-int8.tar.bz2"
  repo=sherpa-onnx-moonshine-$name-en-int8
  log "Start testing ${repo_url}"

  log "test int8 onnx"

  time "$EXE" \
    --moonshine-preprocessor="$repo/preprocess.onnx" \
    --moonshine-encoder="$repo/encode.int8.onnx" \
    --moonshine-uncached-decoder="$repo/uncached_decode.int8.onnx" \
    --moonshine-cached-decoder="$repo/cached_decode.int8.onnx" \
    --tokens="$repo/tokens.txt" \
    --num-threads=2 \
    "$repo/test_wavs/0.wav" \
    "$repo/test_wavs/1.wav" \
    "$repo/test_wavs/8k.wav"

  rm -rf "$repo"
done
|
||||
10
.github/scripts/test-python.sh
vendored
10
.github/scripts/test-python.sh
vendored
@@ -8,6 +8,16 @@ log() {
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
log "test offline Moonshine"
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
|
||||
tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
|
||||
rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
|
||||
|
||||
python3 ./python-api-examples/offline-moonshine-decode-files.py
|
||||
|
||||
rm -rf sherpa-onnx-moonshine-tiny-en-int8
|
||||
|
||||
log "test offline speaker diarization"
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
||||
|
||||
13
.github/workflows/linux.yaml
vendored
13
.github/workflows/linux.yaml
vendored
@@ -149,6 +149,19 @@ jobs:
|
||||
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
|
||||
path: install/*
|
||||
|
||||
- name: Test offline Moonshine
|
||||
if: matrix.build_type != 'Debug'
|
||||
shell: bash
|
||||
run: |
|
||||
du -h -d1 .
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline
|
||||
|
||||
readelf -d build/bin/sherpa-onnx-offline
|
||||
|
||||
.github/scripts/test-offline-moonshine.sh
|
||||
du -h -d1 .
|
||||
|
||||
- name: Test offline CTC
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
11
.github/workflows/macos.yaml
vendored
11
.github/workflows/macos.yaml
vendored
@@ -121,6 +121,15 @@ jobs:
|
||||
otool -L build/bin/sherpa-onnx
|
||||
otool -l build/bin/sherpa-onnx
|
||||
|
||||
- name: Test offline Moonshine
|
||||
if: matrix.build_type != 'Debug'
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline
|
||||
|
||||
.github/scripts/test-offline-moonshine.sh
|
||||
|
||||
- name: Test C++ API
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -243,8 +252,6 @@ jobs:
|
||||
|
||||
.github/scripts/test-offline-whisper.sh
|
||||
|
||||
|
||||
|
||||
- name: Test online transducer
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
8
.github/workflows/windows-x64.yaml
vendored
8
.github/workflows/windows-x64.yaml
vendored
@@ -93,6 +93,14 @@ jobs:
|
||||
name: release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
|
||||
path: build/install/*
|
||||
|
||||
- name: Test offline Moonshine for windows x64
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin/Release:$PATH
|
||||
export EXE=sherpa-onnx-offline.exe
|
||||
|
||||
.github/scripts/test-offline-moonshine.sh
|
||||
|
||||
- name: Test C++ API
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
8
.github/workflows/windows-x86.yaml
vendored
8
.github/workflows/windows-x86.yaml
vendored
@@ -93,6 +93,14 @@ jobs:
|
||||
name: release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
|
||||
path: build/install/*
|
||||
|
||||
- name: Test offline Moonshine for windows x86
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin/Release:$PATH
|
||||
export EXE=sherpa-onnx-offline.exe
|
||||
|
||||
.github/scripts/test-offline-moonshine.sh
|
||||
|
||||
- name: Test C++ API
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
@@ -47,7 +47,19 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
|
||||
--feature-dim=80 \
|
||||
/path/to/test.mp4
|
||||
|
||||
(3) For Whisper models
|
||||
(3) For Moonshine models
|
||||
|
||||
./python-api-examples/generate-subtitles.py \
|
||||
--silero-vad-model=/path/to/silero_vad.onnx \
|
||||
--moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
|
||||
--moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
|
||||
--moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
|
||||
--moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
|
||||
--tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt \
|
||||
--num-threads=2 \
|
||||
/path/to/test.mp4
|
||||
|
||||
(4) For Whisper models
|
||||
|
||||
./python-api-examples/generate-subtitles.py \
|
||||
--silero-vad-model=/path/to/silero_vad.onnx \
|
||||
@@ -58,7 +70,7 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
|
||||
--num-threads=2 \
|
||||
/path/to/test.mp4
|
||||
|
||||
(4) For SenseVoice CTC models
|
||||
(5) For SenseVoice CTC models
|
||||
|
||||
./python-api-examples/generate-subtitles.py \
|
||||
--silero-vad-model=/path/to/silero_vad.onnx \
|
||||
@@ -68,7 +80,7 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
|
||||
/path/to/test.mp4
|
||||
|
||||
|
||||
(5) For WeNet CTC models
|
||||
(6) For WeNet CTC models
|
||||
|
||||
./python-api-examples/generate-subtitles.py \
|
||||
--silero-vad-model=/path/to/silero_vad.onnx \
|
||||
@@ -83,6 +95,7 @@ to install sherpa-onnx and to download non-streaming pre-trained models
|
||||
used in this file.
|
||||
"""
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
@@ -157,7 +170,7 @@ def get_args():
|
||||
parser.add_argument(
|
||||
"--num-threads",
|
||||
type=int,
|
||||
default=1,
|
||||
default=2,
|
||||
help="Number of threads for neural network computation",
|
||||
)
|
||||
|
||||
@@ -208,6 +221,34 @@ def get_args():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--moonshine-preprocessor",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to moonshine preprocessor model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--moonshine-encoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to moonshine encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--moonshine-uncached-decoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to moonshine uncached decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--moonshine-cached-decoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to moonshine cached decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
@@ -263,6 +304,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
assert_file_exists(args.encoder)
|
||||
assert_file_exists(args.decoder)
|
||||
@@ -284,6 +331,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
assert_file_exists(args.paraformer)
|
||||
|
||||
@@ -300,6 +353,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
assert_file_exists(args.sense_voice)
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
|
||||
@@ -312,6 +371,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
elif args.wenet_ctc:
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
assert_file_exists(args.wenet_ctc)
|
||||
|
||||
@@ -327,6 +392,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
elif args.whisper_encoder:
|
||||
assert_file_exists(args.whisper_encoder)
|
||||
assert_file_exists(args.whisper_decoder)
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
|
||||
encoder=args.whisper_encoder,
|
||||
@@ -339,6 +410,22 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
task=args.whisper_task,
|
||||
tail_paddings=args.whisper_tail_paddings,
|
||||
)
|
||||
elif args.moonshine_preprocessor:
|
||||
assert_file_exists(args.moonshine_preprocessor)
|
||||
assert_file_exists(args.moonshine_encoder)
|
||||
assert_file_exists(args.moonshine_uncached_decoder)
|
||||
assert_file_exists(args.moonshine_cached_decoder)
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
|
||||
preprocessor=args.moonshine_preprocessor,
|
||||
encoder=args.moonshine_encoder,
|
||||
uncached_decoder=args.moonshine_uncached_decoder,
|
||||
cached_decoder=args.moonshine_cached_decoder,
|
||||
tokens=args.tokens,
|
||||
num_threads=args.num_threads,
|
||||
decoding_method=args.decoding_method,
|
||||
debug=args.debug,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Please specify at least one model")
|
||||
|
||||
@@ -424,28 +511,32 @@ def main():
|
||||
segment_list = []
|
||||
|
||||
print("Started!")
|
||||
start_t = dt.datetime.now()
|
||||
num_processed_samples = 0
|
||||
|
||||
is_silence = False
|
||||
is_eof = False
|
||||
# TODO(fangjun): Support multithreads
|
||||
while True:
|
||||
# *2 because int16_t has two bytes
|
||||
data = process.stdout.read(frames_per_read * 2)
|
||||
if not data:
|
||||
if is_silence:
|
||||
if is_eof:
|
||||
break
|
||||
is_silence = True
|
||||
# The converted audio file does not have a mute data of 1 second or more at the end, which will result in the loss of the last segment data
|
||||
is_eof = True
|
||||
# pad 1 second at the end of the file for the VAD
|
||||
data = np.zeros(1 * args.sample_rate, dtype=np.int16)
|
||||
|
||||
samples = np.frombuffer(data, dtype=np.int16)
|
||||
samples = samples.astype(np.float32) / 32768
|
||||
|
||||
num_processed_samples += samples.shape[0]
|
||||
|
||||
buffer = np.concatenate([buffer, samples])
|
||||
while len(buffer) > window_size:
|
||||
vad.accept_waveform(buffer[:window_size])
|
||||
buffer = buffer[window_size:]
|
||||
|
||||
if is_silence:
|
||||
if is_eof:
|
||||
vad.flush()
|
||||
|
||||
streams = []
|
||||
@@ -471,6 +562,11 @@ def main():
|
||||
seg.text = stream.result.text
|
||||
segment_list.append(seg)
|
||||
|
||||
end_t = dt.datetime.now()
|
||||
elapsed_seconds = (end_t - start_t).total_seconds()
|
||||
duration = num_processed_samples / 16000
|
||||
rtf = elapsed_seconds / duration
|
||||
|
||||
srt_filename = Path(args.sound_file).with_suffix(".srt")
|
||||
with open(srt_filename, "w", encoding="utf-8") as f:
|
||||
for i, seg in enumerate(segment_list):
|
||||
@@ -479,6 +575,9 @@ def main():
|
||||
print("", file=f)
|
||||
|
||||
print(f"Saved to {srt_filename}")
|
||||
print(f"Audio duration:\t{duration:.3f} s")
|
||||
print(f"Elapsed:\t{elapsed_seconds:.3f} s")
|
||||
print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
|
||||
print("Done!")
|
||||
|
||||
|
||||
|
||||
@@ -66,7 +66,21 @@ python3 ./python-api-examples/non_streaming_server.py \
|
||||
--wenet-ctc ./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
|
||||
--tokens ./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt
|
||||
|
||||
(5) Use a Whisper model
|
||||
(5) Use a Moonshine model
|
||||
|
||||
cd /path/to/sherpa-onnx
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
|
||||
tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
|
||||
rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
|
||||
|
||||
python3 ./python-api-examples/non_streaming_server.py \
|
||||
--moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
|
||||
--moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
|
||||
--moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
|
||||
--moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
|
||||
--tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt
|
||||
|
||||
(6) Use a Whisper model
|
||||
|
||||
cd /path/to/sherpa-onnx
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
|
||||
@@ -78,7 +92,7 @@ python3 ./python-api-examples/non_streaming_server.py \
|
||||
--whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
|
||||
--tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt
|
||||
|
||||
(5) Use a tdnn model of the yesno recipe from icefall
|
||||
(7) Use a tdnn model of the yesno recipe from icefall
|
||||
|
||||
cd /path/to/sherpa-onnx
|
||||
|
||||
@@ -92,7 +106,7 @@ python3 ./python-api-examples/non_streaming_server.py \
|
||||
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
|
||||
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt
|
||||
|
||||
(6) Use a Non-streaming SenseVoice model
|
||||
(8) Use a Non-streaming SenseVoice model
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
|
||||
tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
|
||||
@@ -254,6 +268,36 @@ def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser):
|
||||
)
|
||||
|
||||
|
||||
def add_moonshine_model_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--moonshine-preprocessor",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to moonshine preprocessor model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--moonshine-encoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to moonshine encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--moonshine-uncached-decoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to moonshine uncached decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--moonshine-cached-decoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to moonshine cached decoder model",
|
||||
)
|
||||
|
||||
|
||||
def add_whisper_model_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--whisper-encoder",
|
||||
@@ -311,6 +355,7 @@ def add_model_args(parser: argparse.ArgumentParser):
|
||||
add_wenet_ctc_model_args(parser)
|
||||
add_tdnn_ctc_model_args(parser)
|
||||
add_whisper_model_args(parser)
|
||||
add_moonshine_model_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
@@ -876,6 +921,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
assert_file_exists(args.encoder)
|
||||
assert_file_exists(args.decoder)
|
||||
@@ -903,6 +954,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
assert_file_exists(args.paraformer)
|
||||
|
||||
@@ -921,6 +978,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
assert_file_exists(args.sense_voice)
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
|
||||
@@ -934,6 +997,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
assert_file_exists(args.nemo_ctc)
|
||||
|
||||
@@ -950,6 +1019,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
assert_file_exists(args.wenet_ctc)
|
||||
|
||||
@@ -966,6 +1041,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||
assert_file_exists(args.whisper_encoder)
|
||||
assert_file_exists(args.whisper_decoder)
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
|
||||
encoder=args.whisper_encoder,
|
||||
@@ -980,6 +1061,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
)
|
||||
elif args.tdnn_model:
|
||||
assert_file_exists(args.tdnn_model)
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
|
||||
model=args.tdnn_model,
|
||||
@@ -990,6 +1077,21 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
decoding_method=args.decoding_method,
|
||||
provider=args.provider,
|
||||
)
|
||||
elif args.moonshine_preprocessor:
|
||||
assert_file_exists(args.moonshine_preprocessor)
|
||||
assert_file_exists(args.moonshine_encoder)
|
||||
assert_file_exists(args.moonshine_uncached_decoder)
|
||||
assert_file_exists(args.moonshine_cached_decoder)
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
|
||||
preprocessor=args.moonshine_preprocessor,
|
||||
encoder=args.moonshine_encoder,
|
||||
uncached_decoder=args.moonshine_uncached_decoder,
|
||||
cached_decoder=args.moonshine_cached_decoder,
|
||||
tokens=args.tokens,
|
||||
num_threads=args.num_threads,
|
||||
decoding_method=args.decoding_method,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Please specify at least one model")
|
||||
|
||||
|
||||
82
python-api-examples/offline-moonshine-decode-files.py
Normal file
82
python-api-examples/offline-moonshine-decode-files.py
Normal file
@@ -0,0 +1,82 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
This file shows how to use a non-streaming Moonshine model from
|
||||
https://github.com/usefulsensors/moonshine
|
||||
to decode files.
|
||||
|
||||
Please download model files from
|
||||
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||
|
||||
For instance,
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
|
||||
tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
|
||||
rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
|
||||
"""
|
||||
|
||||
import datetime as dt
|
||||
from pathlib import Path
|
||||
|
||||
import sherpa_onnx
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
def create_recognizer():
    """Create a Moonshine tiny-en int8 OfflineRecognizer from local files.

    Returns:
      A tuple ``(recognizer, test_wav)`` where ``recognizer`` is a
      ``sherpa_onnx.OfflineRecognizer`` and ``test_wav`` is the path of a
      bundled test wave file.

    Raises:
      ValueError: if any required model or test file is missing.
    """
    model_dir = Path("./sherpa-onnx-moonshine-tiny-en-int8")
    preprocessor = model_dir / "preprocess.onnx"
    encoder = model_dir / "encode.int8.onnx"
    uncached_decoder = model_dir / "uncached_decode.int8.onnx"
    cached_decoder = model_dir / "cached_decode.int8.onnx"
    tokens = model_dir / "tokens.txt"
    test_wav = model_dir / "test_wavs" / "0.wav"

    # Check *every* required file (the original only checked the
    # preprocessor and the test wave), so a partially extracted archive
    # is reported here instead of failing inside the native library.
    required = (preprocessor, encoder, uncached_decoder, cached_decoder, tokens, test_wav)
    if not all(p.is_file() for p in required):
        raise ValueError(
            """Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
        )
    return (
        sherpa_onnx.OfflineRecognizer.from_moonshine(
            preprocessor=str(preprocessor),
            encoder=str(encoder),
            uncached_decoder=str(uncached_decoder),
            cached_decoder=str(cached_decoder),
            tokens=str(tokens),
            debug=True,
        ),
        str(test_wav),
    )
||||
|
||||
|
||||
def main():
    """Decode one bundled test wave with a Moonshine recognizer and print timing."""
    recognizer, wave_filename = create_recognizer()

    samples, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    # Keep only the first channel; samples becomes a 1-D float32 array
    # normalized to the range [-1, 1].
    samples = samples[:, 0]

    # sample_rate does not need to be 16000 Hz
    begin = dt.datetime.now()

    s = recognizer.create_stream()
    s.accept_waveform(sample_rate, samples)
    recognizer.decode_stream(s)

    elapsed_seconds = (dt.datetime.now() - begin).total_seconds()
    duration = samples.shape[-1] / sample_rate
    rtf = elapsed_seconds / duration

    print(s.result)
    print(wave_filename)
    print("Text:", s.result.text)
    print(f"Audio duration:\t{duration:.3f} s")
    print(f"Elapsed:\t{elapsed_seconds:.3f} s")
    print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")


if __name__ == "__main__":
    main()
|
||||
77
python-api-examples/offline-whisper-decode-files.py
Normal file
77
python-api-examples/offline-whisper-decode-files.py
Normal file
@@ -0,0 +1,77 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
This file shows how to use a non-streaming whisper model from
|
||||
https://github.com/openai/whisper
|
||||
to decode files.
|
||||
|
||||
Please download model files from
|
||||
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||
|
||||
For instance,
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
|
||||
tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2
|
||||
rm sherpa-onnx-whisper-tiny.en.tar.bz2
|
||||
"""
|
||||
|
||||
import datetime as dt
|
||||
from pathlib import Path
|
||||
|
||||
import sherpa_onnx
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
def create_recognizer():
    """Create a Whisper tiny.en int8 OfflineRecognizer from local files.

    Returns:
      A tuple ``(recognizer, test_wav)`` where ``recognizer`` is a
      ``sherpa_onnx.OfflineRecognizer`` and ``test_wav`` is the path of a
      bundled test wave file.

    Raises:
      ValueError: if any required model or test file is missing.
    """
    model_dir = Path("./sherpa-onnx-whisper-tiny.en")
    encoder = model_dir / "tiny.en-encoder.int8.onnx"
    decoder = model_dir / "tiny.en-decoder.int8.onnx"
    tokens = model_dir / "tiny.en-tokens.txt"
    test_wav = model_dir / "test_wavs" / "0.wav"

    # Check *every* required file (the original only checked the encoder
    # and the test wave), so a partially extracted archive is reported
    # here instead of failing inside the native library.
    required = (encoder, decoder, tokens, test_wav)
    if not all(p.is_file() for p in required):
        raise ValueError(
            """Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
        )
    return (
        sherpa_onnx.OfflineRecognizer.from_whisper(
            encoder=str(encoder),
            decoder=str(decoder),
            tokens=str(tokens),
            debug=True,
        ),
        str(test_wav),
    )
|
||||
|
||||
|
||||
def main():
    """Decode one bundled test wave with a Whisper recognizer and print timing."""
    recognizer, wave_filename = create_recognizer()

    samples, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    # Keep only the first channel; samples becomes a 1-D float32 array
    # normalized to the range [-1, 1].
    samples = samples[:, 0]

    # sample_rate does not need to be 16000 Hz
    begin = dt.datetime.now()

    s = recognizer.create_stream()
    s.accept_waveform(sample_rate, samples)
    recognizer.decode_stream(s)

    elapsed_seconds = (dt.datetime.now() - begin).total_seconds()
    duration = samples.shape[-1] / sample_rate
    rtf = elapsed_seconds / duration

    print(s.result)
    print(wave_filename)
    print("Text:", s.result.text)
    print(f"Audio duration:\t{duration:.3f} s")
    print(f"Elapsed:\t{elapsed_seconds:.3f} s")
    print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")


if __name__ == "__main__":
    main()
|
||||
@@ -35,7 +35,18 @@ Note that you need a non-streaming model for this script.
|
||||
--sample-rate=16000 \
|
||||
--feature-dim=80
|
||||
|
||||
(3) For Whisper models
|
||||
(3) For Moonshine models
|
||||
|
||||
./python-api-examples/vad-with-non-streaming-asr.py \
|
||||
--silero-vad-model=/path/to/silero_vad.onnx \
|
||||
--moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
|
||||
--moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
|
||||
--moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
|
||||
--moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
|
||||
--tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt \
|
||||
--num-threads=2
|
||||
|
||||
(4) For Whisper models
|
||||
|
||||
./python-api-examples/vad-with-non-streaming-asr.py \
|
||||
--silero-vad-model=/path/to/silero_vad.onnx \
|
||||
@@ -45,7 +56,7 @@ Note that you need a non-streaming model for this script.
|
||||
--whisper-task=transcribe \
|
||||
--num-threads=2
|
||||
|
||||
(4) For SenseVoice CTC models
|
||||
(5) For SenseVoice CTC models
|
||||
|
||||
./python-api-examples/vad-with-non-streaming-asr.py \
|
||||
--silero-vad-model=/path/to/silero_vad.onnx \
|
||||
@@ -192,6 +203,34 @@ def get_args():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--moonshine-preprocessor",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to moonshine preprocessor model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--moonshine-encoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to moonshine encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--moonshine-uncached-decoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to moonshine uncached decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--moonshine-cached-decoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to moonshine cached decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--blank-penalty",
|
||||
type=float,
|
||||
@@ -251,6 +290,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
assert len(args.sense_voice) == 0, args.sense_voice
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
assert_file_exists(args.encoder)
|
||||
assert_file_exists(args.decoder)
|
||||
@@ -272,6 +317,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
assert len(args.sense_voice) == 0, args.sense_voice
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
assert_file_exists(args.paraformer)
|
||||
|
||||
@@ -287,6 +338,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
elif args.sense_voice:
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
assert_file_exists(args.sense_voice)
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
|
||||
@@ -299,6 +356,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
elif args.whisper_encoder:
|
||||
assert_file_exists(args.whisper_encoder)
|
||||
assert_file_exists(args.whisper_decoder)
|
||||
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||
assert (
|
||||
len(args.moonshine_uncached_decoder) == 0
|
||||
), args.moonshine_uncached_decoder
|
||||
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
|
||||
encoder=args.whisper_encoder,
|
||||
@@ -311,6 +374,22 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
task=args.whisper_task,
|
||||
tail_paddings=args.whisper_tail_paddings,
|
||||
)
|
||||
elif args.moonshine_preprocessor:
|
||||
assert_file_exists(args.moonshine_preprocessor)
|
||||
assert_file_exists(args.moonshine_encoder)
|
||||
assert_file_exists(args.moonshine_uncached_decoder)
|
||||
assert_file_exists(args.moonshine_cached_decoder)
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
|
||||
preprocessor=args.moonshine_preprocessor,
|
||||
encoder=args.moonshine_encoder,
|
||||
uncached_decoder=args.moonshine_uncached_decoder,
|
||||
cached_decoder=args.moonshine_cached_decoder,
|
||||
tokens=args.tokens,
|
||||
num_threads=args.num_threads,
|
||||
decoding_method=args.decoding_method,
|
||||
debug=args.debug,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Please specify at least one model")
|
||||
|
||||
|
||||
@@ -29,6 +29,9 @@ set(sources
|
||||
offline-lm-config.cc
|
||||
offline-lm.cc
|
||||
offline-model-config.cc
|
||||
offline-moonshine-greedy-search-decoder.cc
|
||||
offline-moonshine-model-config.cc
|
||||
offline-moonshine-model.cc
|
||||
offline-nemo-enc-dec-ctc-model-config.cc
|
||||
offline-nemo-enc-dec-ctc-model.cc
|
||||
offline-paraformer-greedy-search-decoder.cc
|
||||
|
||||
@@ -19,6 +19,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
||||
zipformer_ctc.Register(po);
|
||||
wenet_ctc.Register(po);
|
||||
sense_voice.Register(po);
|
||||
moonshine.Register(po);
|
||||
|
||||
po->Register("telespeech-ctc", &telespeech_ctc,
|
||||
"Path to model.onnx for telespeech ctc");
|
||||
@@ -99,6 +100,10 @@ bool OfflineModelConfig::Validate() const {
|
||||
return sense_voice.Validate();
|
||||
}
|
||||
|
||||
if (!moonshine.preprocessor.empty()) {
|
||||
return moonshine.Validate();
|
||||
}
|
||||
|
||||
if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
|
||||
SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
|
||||
telespeech_ctc.c_str());
|
||||
@@ -124,6 +129,7 @@ std::string OfflineModelConfig::ToString() const {
|
||||
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
|
||||
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
|
||||
os << "sense_voice=" << sense_voice.ToString() << ", ";
|
||||
os << "moonshine=" << moonshine.ToString() << ", ";
|
||||
os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
|
||||
os << "tokens=\"" << tokens << "\", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h"
|
||||
@@ -26,6 +27,7 @@ struct OfflineModelConfig {
|
||||
OfflineZipformerCtcModelConfig zipformer_ctc;
|
||||
OfflineWenetCtcModelConfig wenet_ctc;
|
||||
OfflineSenseVoiceModelConfig sense_voice;
|
||||
OfflineMoonshineModelConfig moonshine;
|
||||
std::string telespeech_ctc;
|
||||
|
||||
std::string tokens;
|
||||
@@ -56,6 +58,7 @@ struct OfflineModelConfig {
|
||||
const OfflineZipformerCtcModelConfig &zipformer_ctc,
|
||||
const OfflineWenetCtcModelConfig &wenet_ctc,
|
||||
const OfflineSenseVoiceModelConfig &sense_voice,
|
||||
const OfflineMoonshineModelConfig &moonshine,
|
||||
const std::string &telespeech_ctc,
|
||||
const std::string &tokens, int32_t num_threads, bool debug,
|
||||
const std::string &provider, const std::string &model_type,
|
||||
@@ -69,6 +72,7 @@ struct OfflineModelConfig {
|
||||
zipformer_ctc(zipformer_ctc),
|
||||
wenet_ctc(wenet_ctc),
|
||||
sense_voice(sense_voice),
|
||||
moonshine(moonshine),
|
||||
telespeech_ctc(telespeech_ctc),
|
||||
tokens(tokens),
|
||||
num_threads(num_threads),
|
||||
|
||||
34
sherpa-onnx/csrc/offline-moonshine-decoder.h
Normal file
34
sherpa-onnx/csrc/offline-moonshine-decoder.h
Normal file
@@ -0,0 +1,34 @@
|
||||
// sherpa-onnx/csrc/offline-moonshine-decoder.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineMoonshineDecoderResult {
|
||||
/// The decoded token IDs
|
||||
std::vector<int32_t> tokens;
|
||||
};
|
||||
|
||||
class OfflineMoonshineDecoder {
|
||||
public:
|
||||
virtual ~OfflineMoonshineDecoder() = default;
|
||||
|
||||
/** Run beam search given the output from the moonshine encoder model.
|
||||
*
|
||||
* @param encoder_out A 3-D tensor of shape (batch_size, T, dim)
|
||||
* @return Return a vector of size `N` containing the decoded results.
|
||||
*/
|
||||
virtual std::vector<OfflineMoonshineDecoderResult> Decode(
|
||||
Ort::Value encoder_out) = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_
|
||||
87
sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.cc
Normal file
87
sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.cc
Normal file
@@ -0,0 +1,87 @@
|
||||
// sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::vector<OfflineMoonshineDecoderResult>
|
||||
OfflineMoonshineGreedySearchDecoder::Decode(Ort::Value encoder_out) {
|
||||
auto encoder_out_shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||
if (encoder_out_shape[0] != 1) {
|
||||
SHERPA_ONNX_LOGE("Support only batch size == 1. Given: %d\n",
|
||||
static_cast<int32_t>(encoder_out_shape[0]));
|
||||
return {};
|
||||
}
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
// encoder_out_shape[1] * 384 is the number of audio samples
|
||||
// 16000 is the sample rate
|
||||
//
|
||||
//
|
||||
// 384 is from the moonshine paper
|
||||
int32_t max_len =
|
||||
static_cast<int32_t>(encoder_out_shape[1] * 384 / 16000.0 * 6);
|
||||
|
||||
int32_t sos = 1;
|
||||
int32_t eos = 2;
|
||||
int32_t seq_len = 1;
|
||||
|
||||
std::vector<int32_t> tokens;
|
||||
|
||||
std::array<int64_t, 2> token_shape = {1, 1};
|
||||
int64_t seq_len_shape = 1;
|
||||
|
||||
Ort::Value token_tensor = Ort::Value::CreateTensor(
|
||||
memory_info, &sos, 1, token_shape.data(), token_shape.size());
|
||||
|
||||
Ort::Value seq_len_tensor =
|
||||
Ort::Value::CreateTensor(memory_info, &seq_len, 1, &seq_len_shape, 1);
|
||||
|
||||
Ort::Value logits{nullptr};
|
||||
std::vector<Ort::Value> states;
|
||||
|
||||
std::tie(logits, states) = model_->ForwardUnCachedDecoder(
|
||||
std::move(token_tensor), std::move(seq_len_tensor), View(&encoder_out));
|
||||
|
||||
int32_t vocab_size = logits.GetTensorTypeAndShapeInfo().GetShape()[2];
|
||||
|
||||
for (int32_t i = 0; i != max_len; ++i) {
|
||||
const float *p = logits.GetTensorData<float>();
|
||||
|
||||
int32_t max_token_id = static_cast<int32_t>(
|
||||
std::distance(p, std::max_element(p, p + vocab_size)));
|
||||
if (max_token_id == eos) {
|
||||
break;
|
||||
}
|
||||
tokens.push_back(max_token_id);
|
||||
|
||||
seq_len += 1;
|
||||
|
||||
token_tensor = Ort::Value::CreateTensor(
|
||||
memory_info, &tokens.back(), 1, token_shape.data(), token_shape.size());
|
||||
|
||||
seq_len_tensor =
|
||||
Ort::Value::CreateTensor(memory_info, &seq_len, 1, &seq_len_shape, 1);
|
||||
|
||||
std::tie(logits, states) = model_->ForwardCachedDecoder(
|
||||
std::move(token_tensor), std::move(seq_len_tensor), View(&encoder_out),
|
||||
std::move(states));
|
||||
}
|
||||
|
||||
OfflineMoonshineDecoderResult ans;
|
||||
ans.tokens = std::move(tokens);
|
||||
|
||||
return {ans};
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
29
sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h
Normal file
29
sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h
Normal file
@@ -0,0 +1,29 @@
|
||||
// sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-moonshine-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-moonshine-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineMoonshineGreedySearchDecoder : public OfflineMoonshineDecoder {
|
||||
public:
|
||||
explicit OfflineMoonshineGreedySearchDecoder(OfflineMoonshineModel *model)
|
||||
: model_(model) {}
|
||||
|
||||
std::vector<OfflineMoonshineDecoderResult> Decode(
|
||||
Ort::Value encoder_out) override;
|
||||
|
||||
private:
|
||||
OfflineMoonshineModel *model_; // not owned
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_
|
||||
88
sherpa-onnx/csrc/offline-moonshine-model-config.cc
Normal file
88
sherpa-onnx/csrc/offline-moonshine-model-config.cc
Normal file
@@ -0,0 +1,88 @@
|
||||
// sherpa-onnx/csrc/offline-moonshine-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineMoonshineModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("moonshine-preprocessor", &preprocessor,
|
||||
"Path to onnx preprocessor of moonshine, e.g., preprocess.onnx");
|
||||
|
||||
po->Register("moonshine-encoder", &encoder,
|
||||
"Path to onnx encoder of moonshine, e.g., encode.onnx");
|
||||
|
||||
po->Register(
|
||||
"moonshine-uncached-decoder", &uncached_decoder,
|
||||
"Path to onnx uncached_decoder of moonshine, e.g., uncached_decode.onnx");
|
||||
|
||||
po->Register(
|
||||
"moonshine-cached-decoder", &cached_decoder,
|
||||
"Path to onnx cached_decoder of moonshine, e.g., cached_decode.onnx");
|
||||
}
|
||||
|
||||
bool OfflineMoonshineModelConfig::Validate() const {
|
||||
if (preprocessor.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --moonshine-preprocessor");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(preprocessor)) {
|
||||
SHERPA_ONNX_LOGE("moonshine preprocessor file '%s' does not exist",
|
||||
preprocessor.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (encoder.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --moonshine-encoder");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(encoder)) {
|
||||
SHERPA_ONNX_LOGE("moonshine encoder file '%s' does not exist",
|
||||
encoder.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (uncached_decoder.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --moonshine-uncached-decoder");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(uncached_decoder)) {
|
||||
SHERPA_ONNX_LOGE("moonshine uncached decoder file '%s' does not exist",
|
||||
uncached_decoder.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cached_decoder.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --moonshine-cached-decoder");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(cached_decoder)) {
|
||||
SHERPA_ONNX_LOGE("moonshine cached decoder file '%s' does not exist",
|
||||
cached_decoder.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string OfflineMoonshineModelConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflineMoonshineModelConfig(";
|
||||
os << "preprocessor=\"" << preprocessor << "\", ";
|
||||
os << "encoder=\"" << encoder << "\", ";
|
||||
os << "uncached_decoder=\"" << uncached_decoder << "\", ";
|
||||
os << "cached_decoder=\"" << cached_decoder << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
37
sherpa-onnx/csrc/offline-moonshine-model-config.h
Normal file
37
sherpa-onnx/csrc/offline-moonshine-model-config.h
Normal file
@@ -0,0 +1,37 @@
|
||||
// sherpa-onnx/csrc/offline-moonshine-model-config.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineMoonshineModelConfig {
|
||||
std::string preprocessor;
|
||||
std::string encoder;
|
||||
std::string uncached_decoder;
|
||||
std::string cached_decoder;
|
||||
|
||||
OfflineMoonshineModelConfig() = default;
|
||||
OfflineMoonshineModelConfig(const std::string &preprocessor,
|
||||
const std::string &encoder,
|
||||
const std::string &uncached_decoder,
|
||||
const std::string &cached_decoder)
|
||||
: preprocessor(preprocessor),
|
||||
encoder(encoder),
|
||||
uncached_decoder(uncached_decoder),
|
||||
cached_decoder(cached_decoder) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
|
||||
282
sherpa-onnx/csrc/offline-moonshine-model.cc
Normal file
282
sherpa-onnx/csrc/offline-moonshine-model.cc
Normal file
@@ -0,0 +1,282 @@
|
||||
// sherpa-onnx/csrc/offline-moonshine-model.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-moonshine-model.h"
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineMoonshineModel::Impl {
|
||||
public:
|
||||
explicit Impl(const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
{
|
||||
auto buf = ReadFile(config.moonshine.preprocessor);
|
||||
InitPreprocessor(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(config.moonshine.encoder);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(config.moonshine.uncached_decoder);
|
||||
InitUnCachedDecoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(config.moonshine.cached_decoder);
|
||||
InitCachedDecoder(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.moonshine.preprocessor);
|
||||
InitPreprocessor(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.moonshine.encoder);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.moonshine.uncached_decoder);
|
||||
InitUnCachedDecoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.moonshine.cached_decoder);
|
||||
InitCachedDecoder(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
Ort::Value ForwardPreprocessor(Ort::Value audio) {
|
||||
auto features = preprocessor_sess_->Run(
|
||||
{}, preprocessor_input_names_ptr_.data(), &audio, 1,
|
||||
preprocessor_output_names_ptr_.data(),
|
||||
preprocessor_output_names_ptr_.size());
|
||||
|
||||
return std::move(features[0]);
|
||||
}
|
||||
|
||||
Ort::Value ForwardEncoder(Ort::Value features, Ort::Value features_len) {
|
||||
std::array<Ort::Value, 2> encoder_inputs{std::move(features),
|
||||
std::move(features_len)};
|
||||
auto encoder_out = encoder_sess_->Run(
|
||||
{}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
|
||||
encoder_inputs.size(), encoder_output_names_ptr_.data(),
|
||||
encoder_output_names_ptr_.size());
|
||||
|
||||
return std::move(encoder_out[0]);
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> ForwardUnCachedDecoder(
|
||||
Ort::Value tokens, Ort::Value seq_len, Ort::Value encoder_out) {
|
||||
std::array<Ort::Value, 3> uncached_decoder_input = {
|
||||
std::move(tokens),
|
||||
std::move(encoder_out),
|
||||
std::move(seq_len),
|
||||
};
|
||||
|
||||
auto uncached_decoder_out = uncached_decoder_sess_->Run(
|
||||
{}, uncached_decoder_input_names_ptr_.data(),
|
||||
uncached_decoder_input.data(), uncached_decoder_input.size(),
|
||||
uncached_decoder_output_names_ptr_.data(),
|
||||
uncached_decoder_output_names_ptr_.size());
|
||||
|
||||
std::vector<Ort::Value> states;
|
||||
states.reserve(uncached_decoder_out.size() - 1);
|
||||
|
||||
int32_t i = -1;
|
||||
for (auto &s : uncached_decoder_out) {
|
||||
++i;
|
||||
if (i == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
states.push_back(std::move(s));
|
||||
}
|
||||
|
||||
return {std::move(uncached_decoder_out[0]), std::move(states)};
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> ForwardCachedDecoder(
|
||||
Ort::Value tokens, Ort::Value seq_len, Ort::Value encoder_out,
|
||||
std::vector<Ort::Value> states) {
|
||||
std::vector<Ort::Value> cached_decoder_input;
|
||||
cached_decoder_input.reserve(3 + states.size());
|
||||
cached_decoder_input.push_back(std::move(tokens));
|
||||
cached_decoder_input.push_back(std::move(encoder_out));
|
||||
cached_decoder_input.push_back(std::move(seq_len));
|
||||
|
||||
for (auto &s : states) {
|
||||
cached_decoder_input.push_back(std::move(s));
|
||||
}
|
||||
|
||||
auto cached_decoder_out = cached_decoder_sess_->Run(
|
||||
{}, cached_decoder_input_names_ptr_.data(), cached_decoder_input.data(),
|
||||
cached_decoder_input.size(), cached_decoder_output_names_ptr_.data(),
|
||||
cached_decoder_output_names_ptr_.size());
|
||||
|
||||
std::vector<Ort::Value> next_states;
|
||||
next_states.reserve(cached_decoder_out.size() - 1);
|
||||
|
||||
int32_t i = -1;
|
||||
for (auto &s : cached_decoder_out) {
|
||||
++i;
|
||||
if (i == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
next_states.push_back(std::move(s));
|
||||
}
|
||||
|
||||
return {std::move(cached_decoder_out[0]), std::move(next_states)};
|
||||
}
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
|
||||
private:
|
||||
void InitPreprocessor(void *model_data, size_t model_data_length) {
|
||||
preprocessor_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, model_data, model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(preprocessor_sess_.get(), &preprocessor_input_names_,
|
||||
&preprocessor_input_names_ptr_);
|
||||
|
||||
GetOutputNames(preprocessor_sess_.get(), &preprocessor_output_names_,
|
||||
&preprocessor_output_names_ptr_);
|
||||
}
|
||||
|
||||
void InitEncoder(void *model_data, size_t model_data_length) {
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, model_data, model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
||||
&encoder_input_names_ptr_);
|
||||
|
||||
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
|
||||
&encoder_output_names_ptr_);
|
||||
}
|
||||
|
||||
void InitUnCachedDecoder(void *model_data, size_t model_data_length) {
|
||||
uncached_decoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, model_data, model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(uncached_decoder_sess_.get(), &uncached_decoder_input_names_,
|
||||
&uncached_decoder_input_names_ptr_);
|
||||
|
||||
GetOutputNames(uncached_decoder_sess_.get(),
|
||||
&uncached_decoder_output_names_,
|
||||
&uncached_decoder_output_names_ptr_);
|
||||
}
|
||||
|
||||
void InitCachedDecoder(void *model_data, size_t model_data_length) {
|
||||
cached_decoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, model_data, model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(cached_decoder_sess_.get(), &cached_decoder_input_names_,
|
||||
&cached_decoder_input_names_ptr_);
|
||||
|
||||
GetOutputNames(cached_decoder_sess_.get(), &cached_decoder_output_names_,
|
||||
&cached_decoder_output_names_ptr_);
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineModelConfig config_;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
std::unique_ptr<Ort::Session> preprocessor_sess_;
|
||||
std::unique_ptr<Ort::Session> encoder_sess_;
|
||||
std::unique_ptr<Ort::Session> uncached_decoder_sess_;
|
||||
std::unique_ptr<Ort::Session> cached_decoder_sess_;
|
||||
|
||||
std::vector<std::string> preprocessor_input_names_;
|
||||
std::vector<const char *> preprocessor_input_names_ptr_;
|
||||
|
||||
std::vector<std::string> preprocessor_output_names_;
|
||||
std::vector<const char *> preprocessor_output_names_ptr_;
|
||||
|
||||
std::vector<std::string> encoder_input_names_;
|
||||
std::vector<const char *> encoder_input_names_ptr_;
|
||||
|
||||
std::vector<std::string> encoder_output_names_;
|
||||
std::vector<const char *> encoder_output_names_ptr_;
|
||||
|
||||
std::vector<std::string> uncached_decoder_input_names_;
|
||||
std::vector<const char *> uncached_decoder_input_names_ptr_;
|
||||
|
||||
std::vector<std::string> uncached_decoder_output_names_;
|
||||
std::vector<const char *> uncached_decoder_output_names_ptr_;
|
||||
|
||||
std::vector<std::string> cached_decoder_input_names_;
|
||||
std::vector<const char *> cached_decoder_input_names_ptr_;
|
||||
|
||||
std::vector<std::string> cached_decoder_output_names_;
|
||||
std::vector<const char *> cached_decoder_output_names_ptr_;
|
||||
};
|
||||
|
||||
OfflineMoonshineModel::OfflineMoonshineModel(const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineMoonshineModel::OfflineMoonshineModel(AAssetManager *mgr,
|
||||
const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
OfflineMoonshineModel::~OfflineMoonshineModel() = default;
|
||||
|
||||
Ort::Value OfflineMoonshineModel::ForwardPreprocessor(Ort::Value audio) const {
|
||||
return impl_->ForwardPreprocessor(std::move(audio));
|
||||
}
|
||||
|
||||
Ort::Value OfflineMoonshineModel::ForwardEncoder(
|
||||
Ort::Value features, Ort::Value features_len) const {
|
||||
return impl_->ForwardEncoder(std::move(features), std::move(features_len));
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>>
|
||||
OfflineMoonshineModel::ForwardUnCachedDecoder(Ort::Value token,
|
||||
Ort::Value seq_len,
|
||||
Ort::Value encoder_out) const {
|
||||
return impl_->ForwardUnCachedDecoder(std::move(token), std::move(seq_len),
|
||||
std::move(encoder_out));
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>>
|
||||
OfflineMoonshineModel::ForwardCachedDecoder(
|
||||
Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out,
|
||||
std::vector<Ort::Value> states) const {
|
||||
return impl_->ForwardCachedDecoder(std::move(token), std::move(seq_len),
|
||||
std::move(encoder_out), std::move(states));
|
||||
}
|
||||
|
||||
OrtAllocator *OfflineMoonshineModel::Allocator() const {
|
||||
return impl_->Allocator();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
93
sherpa-onnx/csrc/offline-moonshine-model.h
Normal file
93
sherpa-onnx/csrc/offline-moonshine-model.h
Normal file
@@ -0,0 +1,93 @@
|
||||
// sherpa-onnx/csrc/offline-moonshine-model.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// please see
|
||||
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/moonshine/test.py
|
||||
class OfflineMoonshineModel {
|
||||
public:
|
||||
explicit OfflineMoonshineModel(const OfflineModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineMoonshineModel(AAssetManager *mgr, const OfflineModelConfig &config);
|
||||
#endif
|
||||
|
||||
~OfflineMoonshineModel();
|
||||
|
||||
/** Run the preprocessor model.
|
||||
*
|
||||
* @param audio A float32 tensor of shape (batch_size, num_samples)
|
||||
*
|
||||
* @return Return a float32 tensor of shape (batch_size, T, dim) that
|
||||
* can be used as the input of ForwardEncoder()
|
||||
*/
|
||||
Ort::Value ForwardPreprocessor(Ort::Value audio) const;
|
||||
|
||||
/** Run the encoder model.
|
||||
*
|
||||
* @param features A float32 tensor of shape (batch_size, T, dim)
|
||||
* @param features_len A int32 tensor of shape (batch_size,)
|
||||
* @returns A float32 tensor of shape (batch_size, T, dim).
|
||||
*/
|
||||
Ort::Value ForwardEncoder(Ort::Value features, Ort::Value features_len) const;
|
||||
|
||||
/** Run the uncached decoder.
|
||||
*
|
||||
* @param token A int32 tensor of shape (batch_size, num_tokens)
|
||||
* @param seq_len A int32 tensor of shape (batch_size,) containing number
|
||||
* of predicted tokens so far
|
||||
* @param encoder_out A float32 tensor of shape (batch_size, T, dim)
|
||||
*
|
||||
* @returns Return a pair:
|
||||
*
|
||||
* - logits, a float32 tensor of shape (batch_size, 1, dim)
|
||||
* - states, a list of states
|
||||
*/
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> ForwardUnCachedDecoder(
|
||||
Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out) const;
|
||||
|
||||
/** Run the cached decoder.
|
||||
*
|
||||
* @param token A int32 tensor of shape (batch_size, num_tokens)
|
||||
* @param seq_len A int32 tensor of shape (batch_size,) containing number
|
||||
* of predicted tokens so far
|
||||
* @param encoder_out A float32 tensor of shape (batch_size, T, dim)
|
||||
* @param states A list of previous states
|
||||
*
|
||||
* @returns Return a pair:
|
||||
* - logits, a float32 tensor of shape (batch_size, 1, dim)
|
||||
* - states, a list of new states
|
||||
*/
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> ForwardCachedDecoder(
|
||||
Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out,
|
||||
std::vector<Ort::Value> states) const;
|
||||
|
||||
/** Return an allocator for allocating memory
|
||||
*/
|
||||
OrtAllocator *Allocator() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
|
||||
@@ -51,6 +52,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
||||
}
|
||||
|
||||
if (!config.model_config.moonshine.preprocessor.empty()) {
|
||||
return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
|
||||
}
|
||||
|
||||
// TODO(fangjun): Refactor it. We only need to use model type for the
|
||||
// following models:
|
||||
// 1. transducer and nemo_transducer
|
||||
@@ -67,7 +72,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
model_type == "telespeech_ctc") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||
} else if (model_type == "whisper") {
|
||||
// unreachable
|
||||
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
||||
} else if (model_type == "moonshine") {
|
||||
// unreachable
|
||||
return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Invalid model_type: %s. Trying to load the model to get its type",
|
||||
@@ -225,6 +234,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
|
||||
}
|
||||
|
||||
if (!config.model_config.moonshine.preprocessor.empty()) {
|
||||
return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
|
||||
}
|
||||
|
||||
// TODO(fangjun): Refactor it. We only need to use model type for the
|
||||
// following models:
|
||||
// 1. transducer and nemo_transducer
|
||||
@@ -242,6 +255,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
|
||||
} else if (model_type == "whisper") {
|
||||
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
|
||||
} else if (model_type == "moonshine") {
|
||||
return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Invalid model_type: %s. Trying to load the model to get its type",
|
||||
|
||||
150
sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h
Normal file
150
sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h
Normal file
@@ -0,0 +1,150 @@
|
||||
// sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-moonshine-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-moonshine-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/transpose.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static OfflineRecognitionResult Convert(
|
||||
const OfflineMoonshineDecoderResult &src, const SymbolTable &sym_table) {
|
||||
OfflineRecognitionResult r;
|
||||
r.tokens.reserve(src.tokens.size());
|
||||
|
||||
std::string text;
|
||||
for (auto i : src.tokens) {
|
||||
if (!sym_table.Contains(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto &s = sym_table[i];
|
||||
text += s;
|
||||
r.tokens.push_back(s);
|
||||
}
|
||||
|
||||
r.text = text;
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
class OfflineRecognizerMoonshineImpl : public OfflineRecognizerImpl {
|
||||
public:
|
||||
explicit OfflineRecognizerMoonshineImpl(const OfflineRecognizerConfig &config)
|
||||
: OfflineRecognizerImpl(config),
|
||||
config_(config),
|
||||
symbol_table_(config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineMoonshineModel>(config.model_config)) {
|
||||
Init();
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineRecognizerMoonshineImpl(AAssetManager *mgr,
|
||||
const OfflineRecognizerConfig &config)
|
||||
: OfflineRecognizerImpl(mgr, config),
|
||||
config_(config),
|
||||
symbol_table_(mgr, config_.model_config.tokens),
|
||||
model_(
|
||||
std::make_unique<OfflineMoonshineModel>(mgr, config.model_config)) {
|
||||
Init();
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
void Init() {
|
||||
if (config_.decoding_method == "greedy_search") {
|
||||
decoder_ =
|
||||
std::make_unique<OfflineMoonshineGreedySearchDecoder>(model_.get());
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Only greedy_search is supported at present for moonshine. Given %s",
|
||||
config_.decoding_method.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||
MoonshineTag tag;
|
||||
return std::make_unique<OfflineStream>(tag);
|
||||
}
|
||||
|
||||
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
|
||||
// batch decoding is not implemented yet
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
DecodeStream(ss[i]);
|
||||
}
|
||||
}
|
||||
|
||||
OfflineRecognizerConfig GetConfig() const override { return config_; }
|
||||
|
||||
private:
|
||||
void DecodeStream(OfflineStream *s) const {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::vector<float> audio = s->GetFrames();
|
||||
|
||||
try {
|
||||
std::array<int64_t, 2> shape{1, static_cast<int64_t>(audio.size())};
|
||||
|
||||
Ort::Value audio_tensor = Ort::Value::CreateTensor(
|
||||
memory_info, audio.data(), audio.size(), shape.data(), shape.size());
|
||||
|
||||
Ort::Value features =
|
||||
model_->ForwardPreprocessor(std::move(audio_tensor));
|
||||
|
||||
int32_t features_len = features.GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||
|
||||
int64_t features_shape = 1;
|
||||
|
||||
Ort::Value features_len_tensor = Ort::Value::CreateTensor(
|
||||
memory_info, &features_len, 1, &features_shape, 1);
|
||||
|
||||
Ort::Value encoder_out = model_->ForwardEncoder(
|
||||
std::move(features), std::move(features_len_tensor));
|
||||
|
||||
auto results = decoder_->Decode(std::move(encoder_out));
|
||||
|
||||
auto r = Convert(results[0], symbol_table_);
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
s->SetResult(r);
|
||||
} catch (const Ort::Exception &ex) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"\n\nCaught exception:\n\n%s\n\nReturn an empty result. Number of "
|
||||
"audio samples: %d",
|
||||
ex.what(), static_cast<int32_t>(audio.size()));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineRecognizerConfig config_;
|
||||
SymbolTable symbol_table_;
|
||||
std::unique_ptr<OfflineMoonshineModel> model_;
|
||||
std::unique_ptr<OfflineMoonshineDecoder> decoder_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_
|
||||
@@ -133,6 +133,10 @@ class OfflineStream::Impl {
|
||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||
}
|
||||
|
||||
explicit Impl(MoonshineTag /*tag*/) : is_moonshine_(true) {
|
||||
config_.sampling_rate = 16000;
|
||||
}
|
||||
|
||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
||||
if (config_.normalize_samples) {
|
||||
AcceptWaveformImpl(sampling_rate, waveform, n);
|
||||
@@ -164,7 +168,9 @@ class OfflineStream::Impl {
|
||||
std::vector<float> samples;
|
||||
resampler->Resample(waveform, n, true, &samples);
|
||||
|
||||
if (fbank_) {
|
||||
if (is_moonshine_) {
|
||||
samples_.insert(samples_.end(), samples.begin(), samples.end());
|
||||
} else if (fbank_) {
|
||||
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
||||
samples.size());
|
||||
fbank_->InputFinished();
|
||||
@@ -181,7 +187,9 @@ class OfflineStream::Impl {
|
||||
return;
|
||||
} // if (sampling_rate != config_.sampling_rate)
|
||||
|
||||
if (fbank_) {
|
||||
if (is_moonshine_) {
|
||||
samples_.insert(samples_.end(), waveform, waveform + n);
|
||||
} else if (fbank_) {
|
||||
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
fbank_->InputFinished();
|
||||
} else if (mfcc_) {
|
||||
@@ -194,10 +202,18 @@ class OfflineStream::Impl {
|
||||
}
|
||||
|
||||
int32_t FeatureDim() const {
|
||||
if (is_moonshine_) {
|
||||
return samples_.size();
|
||||
}
|
||||
|
||||
return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins;
|
||||
}
|
||||
|
||||
std::vector<float> GetFrames() const {
|
||||
if (is_moonshine_) {
|
||||
return samples_;
|
||||
}
|
||||
|
||||
int32_t n = fbank_ ? fbank_->NumFramesReady()
|
||||
: mfcc_ ? mfcc_->NumFramesReady()
|
||||
: whisper_fbank_->NumFramesReady();
|
||||
@@ -300,6 +316,10 @@ class OfflineStream::Impl {
|
||||
OfflineRecognitionResult r_;
|
||||
ContextGraphPtr context_graph_;
|
||||
bool is_ced_ = false;
|
||||
bool is_moonshine_ = false;
|
||||
|
||||
// used only when is_moonshine_== true
|
||||
std::vector<float> samples_;
|
||||
};
|
||||
|
||||
OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/,
|
||||
@@ -311,6 +331,9 @@ OfflineStream::OfflineStream(WhisperTag tag)
|
||||
|
||||
OfflineStream::OfflineStream(CEDTag tag) : impl_(std::make_unique<Impl>(tag)) {}
|
||||
|
||||
OfflineStream::OfflineStream(MoonshineTag tag)
|
||||
: impl_(std::make_unique<Impl>(tag)) {}
|
||||
|
||||
OfflineStream::~OfflineStream() = default;
|
||||
|
||||
void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
|
||||
|
||||
@@ -34,7 +34,7 @@ struct OfflineRecognitionResult {
|
||||
// event target of the audio.
|
||||
std::string event;
|
||||
|
||||
/// timestamps.size() == tokens.size()
|
||||
/// timestamps.size() == tokens.size()
|
||||
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
|
||||
std::vector<float> timestamps;
|
||||
|
||||
@@ -49,6 +49,10 @@ struct WhisperTag {
|
||||
|
||||
struct CEDTag {};
|
||||
|
||||
// It uses a neural network model, a preprocessor, to convert
|
||||
// audio samples to features
|
||||
struct MoonshineTag {};
|
||||
|
||||
class OfflineStream {
|
||||
public:
|
||||
explicit OfflineStream(const FeatureExtractorConfig &config = {},
|
||||
@@ -56,6 +60,7 @@ class OfflineStream {
|
||||
|
||||
explicit OfflineStream(WhisperTag tag);
|
||||
explicit OfflineStream(CEDTag tag);
|
||||
explicit OfflineStream(MoonshineTag tag);
|
||||
~OfflineStream();
|
||||
|
||||
/**
|
||||
@@ -72,7 +77,10 @@ class OfflineStream {
|
||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform,
|
||||
int32_t n) const;
|
||||
|
||||
/// Return feature dim of this extractor
|
||||
/// Return feature dim of this extractor.
|
||||
///
|
||||
/// Note: if it is Moonshine, then it returns the number of audio samples
|
||||
/// currently received.
|
||||
int32_t FeatureDim() const;
|
||||
|
||||
// Get all the feature frames of this stream in a 1-D array, which is
|
||||
|
||||
@@ -23,7 +23,6 @@ class OfflineWhisperModel::Impl {
|
||||
explicit Impl(const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
debug_(config.debug),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
{
|
||||
@@ -40,7 +39,6 @@ class OfflineWhisperModel::Impl {
|
||||
explicit Impl(const SpokenLanguageIdentificationConfig &config)
|
||||
: lid_config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
debug_(config_.debug),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
{
|
||||
@@ -60,7 +58,6 @@ class OfflineWhisperModel::Impl {
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
debug_ = config_.debug;
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.whisper.encoder);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
@@ -77,7 +74,6 @@ class OfflineWhisperModel::Impl {
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
debug_ = config_.debug;
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.whisper.encoder);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
@@ -164,7 +160,7 @@ class OfflineWhisperModel::Impl {
|
||||
}
|
||||
}
|
||||
|
||||
if (debug_) {
|
||||
if (config_.debug) {
|
||||
SHERPA_ONNX_LOGE("Detected language: %s",
|
||||
GetID2Lang().at(lang_id).c_str());
|
||||
}
|
||||
@@ -237,7 +233,7 @@ class OfflineWhisperModel::Impl {
|
||||
|
||||
// get meta data
|
||||
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
|
||||
if (debug_) {
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
os << "---encoder---\n";
|
||||
PrintModelMetadata(os, meta_data);
|
||||
@@ -294,7 +290,6 @@ class OfflineWhisperModel::Impl {
|
||||
private:
|
||||
OfflineModelConfig config_;
|
||||
SpokenLanguageIdentificationConfig lid_config_;
|
||||
bool debug_ = false;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
@@ -43,7 +43,20 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/in
|
||||
--decoding-method=greedy_search \
|
||||
/path/to/foo.wav [bar.wav foobar.wav ...]
|
||||
|
||||
(3) Whisper models
|
||||
(3) Moonshine models
|
||||
|
||||
See https://k2-fsa.github.io/sherpa/onnx/moonshine/index.html
|
||||
|
||||
./bin/sherpa-onnx-offline \
|
||||
--moonshine-preprocessor=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/preprocess.onnx \
|
||||
--moonshine-encoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/encode.int8.onnx \
|
||||
--moonshine-uncached-decoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/uncached_decode.int8.onnx \
|
||||
--moonshine-cached-decoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/cached_decode.int8.onnx \
|
||||
--tokens=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/tokens.txt \
|
||||
--num-threads=1 \
|
||||
/path/to/foo.wav [bar.wav foobar.wav ...]
|
||||
|
||||
(4) Whisper models
|
||||
|
||||
See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html
|
||||
|
||||
@@ -54,7 +67,7 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html
|
||||
--num-threads=1 \
|
||||
/path/to/foo.wav [bar.wav foobar.wav ...]
|
||||
|
||||
(4) NeMo CTC models
|
||||
(5) NeMo CTC models
|
||||
|
||||
See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html
|
||||
|
||||
@@ -68,7 +81,7 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.htm
|
||||
./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \
|
||||
./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav
|
||||
|
||||
(5) TDNN CTC model for the yesno recipe from icefall
|
||||
(6) TDNN CTC model for the yesno recipe from icefall
|
||||
|
||||
See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html
|
||||
//
|
||||
|
||||
@@ -109,6 +109,8 @@ const std::string SymbolTable::operator[](int32_t id) const {
|
||||
|
||||
// for byte-level BPE
|
||||
// id 0 is blank, id 1 is sos/eos, id 2 is unk
|
||||
//
|
||||
// Note: For moonshine models, 0 is <unk>, 1, is <s>, 2 is</s>
|
||||
if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' &&
|
||||
sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') {
|
||||
std::ostringstream os;
|
||||
|
||||
@@ -11,6 +11,7 @@ set(srcs
|
||||
offline-ctc-fst-decoder-config.cc
|
||||
offline-lm-config.cc
|
||||
offline-model-config.cc
|
||||
offline-moonshine-model-config.cc
|
||||
offline-nemo-enc-dec-ctc-model-config.cc
|
||||
offline-paraformer-model-config.cc
|
||||
offline-punctuation.cc
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h"
|
||||
@@ -28,6 +29,7 @@ void PybindOfflineModelConfig(py::module *m) {
|
||||
PybindOfflineZipformerCtcModelConfig(m);
|
||||
PybindOfflineWenetCtcModelConfig(m);
|
||||
PybindOfflineSenseVoiceModelConfig(m);
|
||||
PybindOfflineMoonshineModelConfig(m);
|
||||
|
||||
using PyClass = OfflineModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineModelConfig")
|
||||
@@ -39,7 +41,8 @@ void PybindOfflineModelConfig(py::module *m) {
|
||||
const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &,
|
||||
const OfflineZipformerCtcModelConfig &,
|
||||
const OfflineWenetCtcModelConfig &,
|
||||
const OfflineSenseVoiceModelConfig &, const std::string &,
|
||||
const OfflineSenseVoiceModelConfig &,
|
||||
const OfflineMoonshineModelConfig &, const std::string &,
|
||||
const std::string &, int32_t, bool, const std::string &,
|
||||
const std::string &, const std::string &, const std::string &>(),
|
||||
py::arg("transducer") = OfflineTransducerModelConfig(),
|
||||
@@ -50,6 +53,7 @@ void PybindOfflineModelConfig(py::module *m) {
|
||||
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
|
||||
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
|
||||
py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
|
||||
py::arg("moonshine") = OfflineMoonshineModelConfig(),
|
||||
py::arg("telespeech_ctc") = "", py::arg("tokens"),
|
||||
py::arg("num_threads"), py::arg("debug") = false,
|
||||
py::arg("provider") = "cpu", py::arg("model_type") = "",
|
||||
@@ -62,6 +66,7 @@ void PybindOfflineModelConfig(py::module *m) {
|
||||
.def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
|
||||
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
|
||||
.def_readwrite("sense_voice", &PyClass::sense_voice)
|
||||
.def_readwrite("moonshine", &PyClass::moonshine)
|
||||
.def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc)
|
||||
.def_readwrite("tokens", &PyClass::tokens)
|
||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||
|
||||
28
sherpa-onnx/python/csrc/offline-moonshine-model-config.cc
Normal file
28
sherpa-onnx/python/csrc/offline-moonshine-model-config.cc
Normal file
@@ -0,0 +1,28 @@
|
||||
// sherpa-onnx/python/csrc/offline-moonshine-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineMoonshineModelConfig(py::module *m) {
|
||||
using PyClass = OfflineMoonshineModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineMoonshineModelConfig")
|
||||
.def(py::init<const std::string &, const std::string &,
|
||||
const std::string &, const std::string &>(),
|
||||
py::arg("preprocessor"), py::arg("encoder"),
|
||||
py::arg("uncached_decoder"), py::arg("cached_decoder"))
|
||||
.def_readwrite("preprocessor", &PyClass::preprocessor)
|
||||
.def_readwrite("encoder", &PyClass::encoder)
|
||||
.def_readwrite("uncached_decoder", &PyClass::uncached_decoder)
|
||||
.def_readwrite("cached_decoder", &PyClass::cached_decoder)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/offline-moonshine-model-config.h
Normal file
16
sherpa-onnx/python/csrc/offline-moonshine-model-config.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/offline-moonshine-model-config.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineMoonshineModelConfig(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
|
||||
@@ -8,13 +8,14 @@ from _sherpa_onnx import (
|
||||
OfflineCtcFstDecoderConfig,
|
||||
OfflineLMConfig,
|
||||
OfflineModelConfig,
|
||||
OfflineMoonshineModelConfig,
|
||||
OfflineNemoEncDecCtcModelConfig,
|
||||
OfflineParaformerModelConfig,
|
||||
OfflineSenseVoiceModelConfig,
|
||||
)
|
||||
from _sherpa_onnx import OfflineRecognizer as _Recognizer
|
||||
from _sherpa_onnx import (
|
||||
OfflineRecognizerConfig,
|
||||
OfflineSenseVoiceModelConfig,
|
||||
OfflineStream,
|
||||
OfflineTdnnModelConfig,
|
||||
OfflineTransducerModelConfig,
|
||||
@@ -503,12 +504,12 @@ class OfflineRecognizer(object):
|
||||
e.g., tiny, tiny.en, base, base.en, etc.
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
Path to the encoder model, e.g., tiny-encoder.onnx,
|
||||
tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
|
||||
decoder_model:
|
||||
encoder:
|
||||
Path to the encoder model, e.g., tiny-encoder.onnx,
|
||||
tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
|
||||
decoder:
|
||||
Path to the decoder model, e.g., tiny-decoder.onnx,
|
||||
tiny-decoder.int8.onnx, tiny-decoder.ort, etc.
|
||||
tokens:
|
||||
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
|
||||
columns::
|
||||
@@ -570,6 +571,87 @@ class OfflineRecognizer(object):
|
||||
self.config = recognizer_config
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_moonshine(
|
||||
cls,
|
||||
preprocessor: str,
|
||||
encoder: str,
|
||||
uncached_decoder: str,
|
||||
cached_decoder: str,
|
||||
tokens: str,
|
||||
num_threads: int = 1,
|
||||
decoding_method: str = "greedy_search",
|
||||
debug: bool = False,
|
||||
provider: str = "cpu",
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
`<https://k2-fsa.github.io/sherpa/onnx/moonshine/index.html>`_
|
||||
to download pre-trained models for different kinds of moonshine models,
|
||||
e.g., tiny, base, etc.
|
||||
|
||||
Args:
|
||||
preprocessor:
|
||||
Path to the preprocessor model, e.g., preprocess.onnx
|
||||
encoder:
|
||||
Path to the encoder model, e.g., encode.int8.onnx
|
||||
uncached_decoder:
|
||||
Path to the uncached decoder model, e.g., uncached_decode.int8.onnx,
|
||||
cached_decoder:
|
||||
Path to the cached decoder model, e.g., cached_decode.int8.onnx,
|
||||
tokens:
|
||||
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
|
||||
columns::
|
||||
|
||||
symbol integer_id
|
||||
|
||||
num_threads:
|
||||
Number of threads for neural network computation.
|
||||
decoding_method:
|
||||
Valid values: greedy_search.
|
||||
debug:
|
||||
True to show debug messages.
|
||||
provider:
|
||||
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||
rule_fsts:
|
||||
If not empty, it specifies fsts for inverse text normalization.
|
||||
If there are multiple fsts, they are separated by a comma.
|
||||
rule_fars:
|
||||
If not empty, it specifies fst archives for inverse text normalization.
|
||||
If there are multiple archives, they are separated by a comma.
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
model_config = OfflineModelConfig(
|
||||
moonshine=OfflineMoonshineModelConfig(
|
||||
preprocessor=preprocessor,
|
||||
encoder=encoder,
|
||||
uncached_decoder=uncached_decoder,
|
||||
cached_decoder=cached_decoder,
|
||||
),
|
||||
tokens=tokens,
|
||||
num_threads=num_threads,
|
||||
debug=debug,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
unused_feat_config = FeatureExtractorConfig(
|
||||
sampling_rate=16000,
|
||||
feature_dim=80,
|
||||
)
|
||||
|
||||
recognizer_config = OfflineRecognizerConfig(
|
||||
model_config=model_config,
|
||||
feat_config=unused_feat_config,
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_tdnn_ctc(
|
||||
cls,
|
||||
|
||||
Reference in New Issue
Block a user