Add C++ runtime and Python APIs for Moonshine models (#1473)
This commit is contained in:
50
.github/scripts/test-offline-moonshine.sh
vendored
Executable file
50
.github/scripts/test-offline-moonshine.sh
vendored
Executable file
@@ -0,0 +1,50 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||||
|
|
||||||
|
echo "EXE is $EXE"
|
||||||
|
echo "PATH: $PATH"
|
||||||
|
|
||||||
|
which $EXE
|
||||||
|
|
||||||
|
names=(
|
||||||
|
tiny
|
||||||
|
base
|
||||||
|
)
|
||||||
|
|
||||||
|
for name in ${names[@]}; do
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
log "Run $name"
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
|
||||||
|
repo_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-$name.tar.bz2
|
||||||
|
repo_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-$name-en-int8.tar.bz2
|
||||||
|
curl -SL -O $repo_url
|
||||||
|
tar xvf sherpa-onnx-moonshine-$name-en-int8.tar.bz2
|
||||||
|
rm sherpa-onnx-moonshine-$name-en-int8.tar.bz2
|
||||||
|
repo=sherpa-onnx-moonshine-$name-en-int8
|
||||||
|
log "Start testing ${repo_url}"
|
||||||
|
|
||||||
|
log "test int8 onnx"
|
||||||
|
|
||||||
|
time $EXE \
|
||||||
|
--moonshine-preprocessor=$repo/preprocess.onnx \
|
||||||
|
--moonshine-encoder=$repo/encode.int8.onnx \
|
||||||
|
--moonshine-uncached-decoder=$repo/uncached_decode.int8.onnx \
|
||||||
|
--moonshine-cached-decoder=$repo/cached_decode.int8.onnx \
|
||||||
|
--tokens=$repo/tokens.txt \
|
||||||
|
--num-threads=2 \
|
||||||
|
$repo/test_wavs/0.wav \
|
||||||
|
$repo/test_wavs/1.wav \
|
||||||
|
$repo/test_wavs/8k.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
done
|
||||||
10
.github/scripts/test-python.sh
vendored
10
.github/scripts/test-python.sh
vendored
@@ -8,6 +8,16 @@ log() {
|
|||||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log "test offline Moonshine"
|
||||||
|
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
|
||||||
|
tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
|
||||||
|
rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
|
||||||
|
|
||||||
|
python3 ./python-api-examples/offline-moonshine-decode-files.py
|
||||||
|
|
||||||
|
rm -rf sherpa-onnx-moonshine-tiny-en-int8
|
||||||
|
|
||||||
log "test offline speaker diarization"
|
log "test offline speaker diarization"
|
||||||
|
|
||||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
||||||
|
|||||||
13
.github/workflows/linux.yaml
vendored
13
.github/workflows/linux.yaml
vendored
@@ -149,6 +149,19 @@ jobs:
|
|||||||
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
|
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
|
||||||
path: install/*
|
path: install/*
|
||||||
|
|
||||||
|
- name: Test offline Moonshine
|
||||||
|
if: matrix.build_type != 'Debug'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
du -h -d1 .
|
||||||
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
export EXE=sherpa-onnx-offline
|
||||||
|
|
||||||
|
readelf -d build/bin/sherpa-onnx-offline
|
||||||
|
|
||||||
|
.github/scripts/test-offline-moonshine.sh
|
||||||
|
du -h -d1 .
|
||||||
|
|
||||||
- name: Test offline CTC
|
- name: Test offline CTC
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
11
.github/workflows/macos.yaml
vendored
11
.github/workflows/macos.yaml
vendored
@@ -121,6 +121,15 @@ jobs:
|
|||||||
otool -L build/bin/sherpa-onnx
|
otool -L build/bin/sherpa-onnx
|
||||||
otool -l build/bin/sherpa-onnx
|
otool -l build/bin/sherpa-onnx
|
||||||
|
|
||||||
|
- name: Test offline Moonshine
|
||||||
|
if: matrix.build_type != 'Debug'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
export EXE=sherpa-onnx-offline
|
||||||
|
|
||||||
|
.github/scripts/test-offline-moonshine.sh
|
||||||
|
|
||||||
- name: Test C++ API
|
- name: Test C++ API
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
@@ -243,8 +252,6 @@ jobs:
|
|||||||
|
|
||||||
.github/scripts/test-offline-whisper.sh
|
.github/scripts/test-offline-whisper.sh
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
- name: Test online transducer
|
- name: Test online transducer
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
8
.github/workflows/windows-x64.yaml
vendored
8
.github/workflows/windows-x64.yaml
vendored
@@ -93,6 +93,14 @@ jobs:
|
|||||||
name: release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
|
name: release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
|
||||||
path: build/install/*
|
path: build/install/*
|
||||||
|
|
||||||
|
- name: Test offline Moonshine for windows x64
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin/Release:$PATH
|
||||||
|
export EXE=sherpa-onnx-offline.exe
|
||||||
|
|
||||||
|
.github/scripts/test-offline-moonshine.sh
|
||||||
|
|
||||||
- name: Test C++ API
|
- name: Test C++ API
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
8
.github/workflows/windows-x86.yaml
vendored
8
.github/workflows/windows-x86.yaml
vendored
@@ -93,6 +93,14 @@ jobs:
|
|||||||
name: release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
|
name: release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
|
||||||
path: build/install/*
|
path: build/install/*
|
||||||
|
|
||||||
|
- name: Test offline Moonshine for windows x86
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin/Release:$PATH
|
||||||
|
export EXE=sherpa-onnx-offline.exe
|
||||||
|
|
||||||
|
.github/scripts/test-offline-moonshine.sh
|
||||||
|
|
||||||
- name: Test C++ API
|
- name: Test C++ API
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -47,7 +47,19 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
|
|||||||
--feature-dim=80 \
|
--feature-dim=80 \
|
||||||
/path/to/test.mp4
|
/path/to/test.mp4
|
||||||
|
|
||||||
(3) For Whisper models
|
(3) For Moonshine models
|
||||||
|
|
||||||
|
./python-api-examples/generate-subtitles.py \
|
||||||
|
--silero-vad-model=/path/to/silero_vad.onnx \
|
||||||
|
--moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
|
||||||
|
--moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
|
||||||
|
--moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
|
||||||
|
--moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
|
||||||
|
--tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt \
|
||||||
|
--num-threads=2 \
|
||||||
|
/path/to/test.mp4
|
||||||
|
|
||||||
|
(4) For Whisper models
|
||||||
|
|
||||||
./python-api-examples/generate-subtitles.py \
|
./python-api-examples/generate-subtitles.py \
|
||||||
--silero-vad-model=/path/to/silero_vad.onnx \
|
--silero-vad-model=/path/to/silero_vad.onnx \
|
||||||
@@ -58,7 +70,7 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
|
|||||||
--num-threads=2 \
|
--num-threads=2 \
|
||||||
/path/to/test.mp4
|
/path/to/test.mp4
|
||||||
|
|
||||||
(4) For SenseVoice CTC models
|
(5) For SenseVoice CTC models
|
||||||
|
|
||||||
./python-api-examples/generate-subtitles.py \
|
./python-api-examples/generate-subtitles.py \
|
||||||
--silero-vad-model=/path/to/silero_vad.onnx \
|
--silero-vad-model=/path/to/silero_vad.onnx \
|
||||||
@@ -68,7 +80,7 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
|
|||||||
/path/to/test.mp4
|
/path/to/test.mp4
|
||||||
|
|
||||||
|
|
||||||
(5) For WeNet CTC models
|
(6) For WeNet CTC models
|
||||||
|
|
||||||
./python-api-examples/generate-subtitles.py \
|
./python-api-examples/generate-subtitles.py \
|
||||||
--silero-vad-model=/path/to/silero_vad.onnx \
|
--silero-vad-model=/path/to/silero_vad.onnx \
|
||||||
@@ -83,6 +95,7 @@ to install sherpa-onnx and to download non-streaming pre-trained models
|
|||||||
used in this file.
|
used in this file.
|
||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
|
import datetime as dt
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
@@ -157,7 +170,7 @@ def get_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-threads",
|
"--num-threads",
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=2,
|
||||||
help="Number of threads for neural network computation",
|
help="Number of threads for neural network computation",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -208,6 +221,34 @@ def get_args():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--moonshine-preprocessor",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to moonshine preprocessor model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--moonshine-encoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to moonshine encoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--moonshine-uncached-decoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to moonshine uncached decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--moonshine-cached-decoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to moonshine cached decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decoding-method",
|
"--decoding-method",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -263,6 +304,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
|
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||||
|
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||||
|
assert (
|
||||||
|
len(args.moonshine_uncached_decoder) == 0
|
||||||
|
), args.moonshine_uncached_decoder
|
||||||
|
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||||
|
|
||||||
assert_file_exists(args.encoder)
|
assert_file_exists(args.encoder)
|
||||||
assert_file_exists(args.decoder)
|
assert_file_exists(args.decoder)
|
||||||
@@ -284,6 +331,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
|
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||||
|
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||||
|
assert (
|
||||||
|
len(args.moonshine_uncached_decoder) == 0
|
||||||
|
), args.moonshine_uncached_decoder
|
||||||
|
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||||
|
|
||||||
assert_file_exists(args.paraformer)
|
assert_file_exists(args.paraformer)
|
||||||
|
|
||||||
@@ -300,6 +353,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
|
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||||
|
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||||
|
assert (
|
||||||
|
len(args.moonshine_uncached_decoder) == 0
|
||||||
|
), args.moonshine_uncached_decoder
|
||||||
|
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||||
|
|
||||||
assert_file_exists(args.sense_voice)
|
assert_file_exists(args.sense_voice)
|
||||||
recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
|
recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
|
||||||
@@ -312,6 +371,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
elif args.wenet_ctc:
|
elif args.wenet_ctc:
|
||||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
|
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||||
|
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||||
|
assert (
|
||||||
|
len(args.moonshine_uncached_decoder) == 0
|
||||||
|
), args.moonshine_uncached_decoder
|
||||||
|
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||||
|
|
||||||
assert_file_exists(args.wenet_ctc)
|
assert_file_exists(args.wenet_ctc)
|
||||||
|
|
||||||
@@ -327,6 +392,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
elif args.whisper_encoder:
|
elif args.whisper_encoder:
|
||||||
assert_file_exists(args.whisper_encoder)
|
assert_file_exists(args.whisper_encoder)
|
||||||
assert_file_exists(args.whisper_decoder)
|
assert_file_exists(args.whisper_decoder)
|
||||||
|
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||||
|
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||||
|
assert (
|
||||||
|
len(args.moonshine_uncached_decoder) == 0
|
||||||
|
), args.moonshine_uncached_decoder
|
||||||
|
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||||
|
|
||||||
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
|
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
|
||||||
encoder=args.whisper_encoder,
|
encoder=args.whisper_encoder,
|
||||||
@@ -339,6 +410,22 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
task=args.whisper_task,
|
task=args.whisper_task,
|
||||||
tail_paddings=args.whisper_tail_paddings,
|
tail_paddings=args.whisper_tail_paddings,
|
||||||
)
|
)
|
||||||
|
elif args.moonshine_preprocessor:
|
||||||
|
assert_file_exists(args.moonshine_preprocessor)
|
||||||
|
assert_file_exists(args.moonshine_encoder)
|
||||||
|
assert_file_exists(args.moonshine_uncached_decoder)
|
||||||
|
assert_file_exists(args.moonshine_cached_decoder)
|
||||||
|
|
||||||
|
recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
|
||||||
|
preprocessor=args.moonshine_preprocessor,
|
||||||
|
encoder=args.moonshine_encoder,
|
||||||
|
uncached_decoder=args.moonshine_uncached_decoder,
|
||||||
|
cached_decoder=args.moonshine_cached_decoder,
|
||||||
|
tokens=args.tokens,
|
||||||
|
num_threads=args.num_threads,
|
||||||
|
decoding_method=args.decoding_method,
|
||||||
|
debug=args.debug,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Please specify at least one model")
|
raise ValueError("Please specify at least one model")
|
||||||
|
|
||||||
@@ -424,28 +511,32 @@ def main():
|
|||||||
segment_list = []
|
segment_list = []
|
||||||
|
|
||||||
print("Started!")
|
print("Started!")
|
||||||
|
start_t = dt.datetime.now()
|
||||||
|
num_processed_samples = 0
|
||||||
|
|
||||||
is_silence = False
|
is_eof = False
|
||||||
# TODO(fangjun): Support multithreads
|
# TODO(fangjun): Support multithreads
|
||||||
while True:
|
while True:
|
||||||
# *2 because int16_t has two bytes
|
# *2 because int16_t has two bytes
|
||||||
data = process.stdout.read(frames_per_read * 2)
|
data = process.stdout.read(frames_per_read * 2)
|
||||||
if not data:
|
if not data:
|
||||||
if is_silence:
|
if is_eof:
|
||||||
break
|
break
|
||||||
is_silence = True
|
is_eof = True
|
||||||
# The converted audio file does not have a mute data of 1 second or more at the end, which will result in the loss of the last segment data
|
# pad 1 second at the end of the file for the VAD
|
||||||
data = np.zeros(1 * args.sample_rate, dtype=np.int16)
|
data = np.zeros(1 * args.sample_rate, dtype=np.int16)
|
||||||
|
|
||||||
samples = np.frombuffer(data, dtype=np.int16)
|
samples = np.frombuffer(data, dtype=np.int16)
|
||||||
samples = samples.astype(np.float32) / 32768
|
samples = samples.astype(np.float32) / 32768
|
||||||
|
|
||||||
|
num_processed_samples += samples.shape[0]
|
||||||
|
|
||||||
buffer = np.concatenate([buffer, samples])
|
buffer = np.concatenate([buffer, samples])
|
||||||
while len(buffer) > window_size:
|
while len(buffer) > window_size:
|
||||||
vad.accept_waveform(buffer[:window_size])
|
vad.accept_waveform(buffer[:window_size])
|
||||||
buffer = buffer[window_size:]
|
buffer = buffer[window_size:]
|
||||||
|
|
||||||
if is_silence:
|
if is_eof:
|
||||||
vad.flush()
|
vad.flush()
|
||||||
|
|
||||||
streams = []
|
streams = []
|
||||||
@@ -471,6 +562,11 @@ def main():
|
|||||||
seg.text = stream.result.text
|
seg.text = stream.result.text
|
||||||
segment_list.append(seg)
|
segment_list.append(seg)
|
||||||
|
|
||||||
|
end_t = dt.datetime.now()
|
||||||
|
elapsed_seconds = (end_t - start_t).total_seconds()
|
||||||
|
duration = num_processed_samples / 16000
|
||||||
|
rtf = elapsed_seconds / duration
|
||||||
|
|
||||||
srt_filename = Path(args.sound_file).with_suffix(".srt")
|
srt_filename = Path(args.sound_file).with_suffix(".srt")
|
||||||
with open(srt_filename, "w", encoding="utf-8") as f:
|
with open(srt_filename, "w", encoding="utf-8") as f:
|
||||||
for i, seg in enumerate(segment_list):
|
for i, seg in enumerate(segment_list):
|
||||||
@@ -479,6 +575,9 @@ def main():
|
|||||||
print("", file=f)
|
print("", file=f)
|
||||||
|
|
||||||
print(f"Saved to {srt_filename}")
|
print(f"Saved to {srt_filename}")
|
||||||
|
print(f"Audio duration:\t{duration:.3f} s")
|
||||||
|
print(f"Elapsed:\t{elapsed_seconds:.3f} s")
|
||||||
|
print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
|
||||||
print("Done!")
|
print("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -66,7 +66,21 @@ python3 ./python-api-examples/non_streaming_server.py \
|
|||||||
--wenet-ctc ./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
|
--wenet-ctc ./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
|
||||||
--tokens ./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt
|
--tokens ./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt
|
||||||
|
|
||||||
(5) Use a Whisper model
|
(5) Use a Moonshine model
|
||||||
|
|
||||||
|
cd /path/to/sherpa-onnx
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
|
||||||
|
tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
|
||||||
|
rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
|
||||||
|
|
||||||
|
python3 ./python-api-examples/non_streaming_server.py \
|
||||||
|
--moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
|
||||||
|
--moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
|
||||||
|
--moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
|
||||||
|
--moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
|
||||||
|
--tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt
|
||||||
|
|
||||||
|
(6) Use a Whisper model
|
||||||
|
|
||||||
cd /path/to/sherpa-onnx
|
cd /path/to/sherpa-onnx
|
||||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
|
||||||
@@ -78,7 +92,7 @@ python3 ./python-api-examples/non_streaming_server.py \
|
|||||||
--whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
|
--whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
|
||||||
--tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt
|
--tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt
|
||||||
|
|
||||||
(5) Use a tdnn model of the yesno recipe from icefall
|
(7) Use a tdnn model of the yesno recipe from icefall
|
||||||
|
|
||||||
cd /path/to/sherpa-onnx
|
cd /path/to/sherpa-onnx
|
||||||
|
|
||||||
@@ -92,7 +106,7 @@ python3 ./python-api-examples/non_streaming_server.py \
|
|||||||
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
|
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
|
||||||
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt
|
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt
|
||||||
|
|
||||||
(6) Use a Non-streaming SenseVoice model
|
(8) Use a Non-streaming SenseVoice model
|
||||||
|
|
||||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
|
||||||
tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
|
tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
|
||||||
@@ -254,6 +268,36 @@ def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def add_moonshine_model_args(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument(
|
||||||
|
"--moonshine-preprocessor",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to moonshine preprocessor model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--moonshine-encoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to moonshine encoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--moonshine-uncached-decoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to moonshine uncached decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--moonshine-cached-decoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to moonshine cached decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_whisper_model_args(parser: argparse.ArgumentParser):
|
def add_whisper_model_args(parser: argparse.ArgumentParser):
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--whisper-encoder",
|
"--whisper-encoder",
|
||||||
@@ -311,6 +355,7 @@ def add_model_args(parser: argparse.ArgumentParser):
|
|||||||
add_wenet_ctc_model_args(parser)
|
add_wenet_ctc_model_args(parser)
|
||||||
add_tdnn_ctc_model_args(parser)
|
add_tdnn_ctc_model_args(parser)
|
||||||
add_whisper_model_args(parser)
|
add_whisper_model_args(parser)
|
||||||
|
add_moonshine_model_args(parser)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tokens",
|
"--tokens",
|
||||||
@@ -876,6 +921,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||||
|
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||||
|
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||||
|
assert (
|
||||||
|
len(args.moonshine_uncached_decoder) == 0
|
||||||
|
), args.moonshine_uncached_decoder
|
||||||
|
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||||
|
|
||||||
assert_file_exists(args.encoder)
|
assert_file_exists(args.encoder)
|
||||||
assert_file_exists(args.decoder)
|
assert_file_exists(args.decoder)
|
||||||
@@ -903,6 +954,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||||
|
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||||
|
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||||
|
assert (
|
||||||
|
len(args.moonshine_uncached_decoder) == 0
|
||||||
|
), args.moonshine_uncached_decoder
|
||||||
|
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||||
|
|
||||||
assert_file_exists(args.paraformer)
|
assert_file_exists(args.paraformer)
|
||||||
|
|
||||||
@@ -921,6 +978,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||||
|
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||||
|
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||||
|
assert (
|
||||||
|
len(args.moonshine_uncached_decoder) == 0
|
||||||
|
), args.moonshine_uncached_decoder
|
||||||
|
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||||
|
|
||||||
assert_file_exists(args.sense_voice)
|
assert_file_exists(args.sense_voice)
|
||||||
recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
|
recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
|
||||||
@@ -934,6 +997,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||||
|
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||||
|
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||||
|
assert (
|
||||||
|
len(args.moonshine_uncached_decoder) == 0
|
||||||
|
), args.moonshine_uncached_decoder
|
||||||
|
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||||
|
|
||||||
assert_file_exists(args.nemo_ctc)
|
assert_file_exists(args.nemo_ctc)
|
||||||
|
|
||||||
@@ -950,6 +1019,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||||
|
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||||
|
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||||
|
assert (
|
||||||
|
len(args.moonshine_uncached_decoder) == 0
|
||||||
|
), args.moonshine_uncached_decoder
|
||||||
|
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||||
|
|
||||||
assert_file_exists(args.wenet_ctc)
|
assert_file_exists(args.wenet_ctc)
|
||||||
|
|
||||||
@@ -966,6 +1041,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||||
assert_file_exists(args.whisper_encoder)
|
assert_file_exists(args.whisper_encoder)
|
||||||
assert_file_exists(args.whisper_decoder)
|
assert_file_exists(args.whisper_decoder)
|
||||||
|
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||||
|
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||||
|
assert (
|
||||||
|
len(args.moonshine_uncached_decoder) == 0
|
||||||
|
), args.moonshine_uncached_decoder
|
||||||
|
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||||
|
|
||||||
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
|
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
|
||||||
encoder=args.whisper_encoder,
|
encoder=args.whisper_encoder,
|
||||||
@@ -980,6 +1061,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
)
|
)
|
||||||
elif args.tdnn_model:
|
elif args.tdnn_model:
|
||||||
assert_file_exists(args.tdnn_model)
|
assert_file_exists(args.tdnn_model)
|
||||||
|
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||||
|
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||||
|
assert (
|
||||||
|
len(args.moonshine_uncached_decoder) == 0
|
||||||
|
), args.moonshine_uncached_decoder
|
||||||
|
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||||
|
|
||||||
recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
|
recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
|
||||||
model=args.tdnn_model,
|
model=args.tdnn_model,
|
||||||
@@ -990,6 +1077,21 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
decoding_method=args.decoding_method,
|
decoding_method=args.decoding_method,
|
||||||
provider=args.provider,
|
provider=args.provider,
|
||||||
)
|
)
|
||||||
|
elif args.moonshine_preprocessor:
|
||||||
|
assert_file_exists(args.moonshine_preprocessor)
|
||||||
|
assert_file_exists(args.moonshine_encoder)
|
||||||
|
assert_file_exists(args.moonshine_uncached_decoder)
|
||||||
|
assert_file_exists(args.moonshine_cached_decoder)
|
||||||
|
|
||||||
|
recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
|
||||||
|
preprocessor=args.moonshine_preprocessor,
|
||||||
|
encoder=args.moonshine_encoder,
|
||||||
|
uncached_decoder=args.moonshine_uncached_decoder,
|
||||||
|
cached_decoder=args.moonshine_cached_decoder,
|
||||||
|
tokens=args.tokens,
|
||||||
|
num_threads=args.num_threads,
|
||||||
|
decoding_method=args.decoding_method,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Please specify at least one model")
|
raise ValueError("Please specify at least one model")
|
||||||
|
|
||||||
|
|||||||
82
python-api-examples/offline-moonshine-decode-files.py
Normal file
82
python-api-examples/offline-moonshine-decode-files.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file shows how to use a non-streaming Moonshine model from
|
||||||
|
https://github.com/usefulsensors/moonshine
|
||||||
|
to decode files.
|
||||||
|
|
||||||
|
Please download model files from
|
||||||
|
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||||
|
|
||||||
|
For instance,
|
||||||
|
|
||||||
|
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
|
||||||
|
tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
|
||||||
|
rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
|
||||||
|
"""
|
||||||
|
|
||||||
|
import datetime as dt
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import sherpa_onnx
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
|
||||||
|
def create_recognizer():
|
||||||
|
preprocessor = "./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx"
|
||||||
|
encoder = "./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx"
|
||||||
|
uncached_decoder = "./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx"
|
||||||
|
cached_decoder = "./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx"
|
||||||
|
|
||||||
|
tokens = "./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt"
|
||||||
|
test_wav = "./sherpa-onnx-moonshine-tiny-en-int8/test_wavs/0.wav"
|
||||||
|
|
||||||
|
if not Path(preprocessor).is_file() or not Path(test_wav).is_file():
|
||||||
|
raise ValueError(
|
||||||
|
"""Please download model files from
|
||||||
|
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
sherpa_onnx.OfflineRecognizer.from_moonshine(
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
encoder=encoder,
|
||||||
|
uncached_decoder=uncached_decoder,
|
||||||
|
cached_decoder=cached_decoder,
|
||||||
|
tokens=tokens,
|
||||||
|
debug=True,
|
||||||
|
),
|
||||||
|
test_wav,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
recognizer, wave_filename = create_recognizer()
|
||||||
|
|
||||||
|
audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
|
||||||
|
audio = audio[:, 0] # only use the first channel
|
||||||
|
|
||||||
|
# audio is a 1-D float32 numpy array normalized to the range [-1, 1]
|
||||||
|
# sample_rate does not need to be 16000 Hz
|
||||||
|
|
||||||
|
start_t = dt.datetime.now()
|
||||||
|
|
||||||
|
stream = recognizer.create_stream()
|
||||||
|
stream.accept_waveform(sample_rate, audio)
|
||||||
|
recognizer.decode_stream(stream)
|
||||||
|
|
||||||
|
end_t = dt.datetime.now()
|
||||||
|
elapsed_seconds = (end_t - start_t).total_seconds()
|
||||||
|
duration = audio.shape[-1] / sample_rate
|
||||||
|
rtf = elapsed_seconds / duration
|
||||||
|
|
||||||
|
print(stream.result)
|
||||||
|
print(wave_filename)
|
||||||
|
print("Text:", stream.result.text)
|
||||||
|
print(f"Audio duration:\t{duration:.3f} s")
|
||||||
|
print(f"Elapsed:\t{elapsed_seconds:.3f} s")
|
||||||
|
print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
77
python-api-examples/offline-whisper-decode-files.py
Normal file
77
python-api-examples/offline-whisper-decode-files.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file shows how to use a non-streaming whisper model from
|
||||||
|
https://github.com/openai/whisper
|
||||||
|
to decode files.
|
||||||
|
|
||||||
|
Please download model files from
|
||||||
|
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||||
|
|
||||||
|
For instance,
|
||||||
|
|
||||||
|
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
|
||||||
|
tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2
|
||||||
|
rm sherpa-onnx-whisper-tiny.en.tar.bz2
|
||||||
|
"""
|
||||||
|
|
||||||
|
import datetime as dt
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import sherpa_onnx
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
|
||||||
|
def create_recognizer():
|
||||||
|
encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"
|
||||||
|
decoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx"
|
||||||
|
tokens = "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt"
|
||||||
|
test_wav = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav"
|
||||||
|
|
||||||
|
if not Path(encoder).is_file() or not Path(test_wav).is_file():
|
||||||
|
raise ValueError(
|
||||||
|
"""Please download model files from
|
||||||
|
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
sherpa_onnx.OfflineRecognizer.from_whisper(
|
||||||
|
encoder=encoder,
|
||||||
|
decoder=decoder,
|
||||||
|
tokens=tokens,
|
||||||
|
debug=True,
|
||||||
|
),
|
||||||
|
test_wav,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
recognizer, wave_filename = create_recognizer()
|
||||||
|
|
||||||
|
audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
|
||||||
|
audio = audio[:, 0] # only use the first channel
|
||||||
|
|
||||||
|
# audio is a 1-D float32 numpy array normalized to the range [-1, 1]
|
||||||
|
# sample_rate does not need to be 16000 Hz
|
||||||
|
|
||||||
|
start_t = dt.datetime.now()
|
||||||
|
|
||||||
|
stream = recognizer.create_stream()
|
||||||
|
stream.accept_waveform(sample_rate, audio)
|
||||||
|
recognizer.decode_stream(stream)
|
||||||
|
|
||||||
|
end_t = dt.datetime.now()
|
||||||
|
elapsed_seconds = (end_t - start_t).total_seconds()
|
||||||
|
duration = audio.shape[-1] / sample_rate
|
||||||
|
rtf = elapsed_seconds / duration
|
||||||
|
|
||||||
|
print(stream.result)
|
||||||
|
print(wave_filename)
|
||||||
|
print("Text:", stream.result.text)
|
||||||
|
print(f"Audio duration:\t{duration:.3f} s")
|
||||||
|
print(f"Elapsed:\t{elapsed_seconds:.3f} s")
|
||||||
|
print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -35,7 +35,18 @@ Note that you need a non-streaming model for this script.
|
|||||||
--sample-rate=16000 \
|
--sample-rate=16000 \
|
||||||
--feature-dim=80
|
--feature-dim=80
|
||||||
|
|
||||||
(3) For Whisper models
|
(3) For Moonshine models
|
||||||
|
|
||||||
|
./python-api-examples/vad-with-non-streaming-asr.py \
|
||||||
|
--silero-vad-model=/path/to/silero_vad.onnx \
|
||||||
|
--moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
|
||||||
|
--moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
|
||||||
|
--moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
|
||||||
|
--moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
|
||||||
|
--tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt \
|
||||||
|
--num-threads=2
|
||||||
|
|
||||||
|
(4) For Whisper models
|
||||||
|
|
||||||
./python-api-examples/vad-with-non-streaming-asr.py \
|
./python-api-examples/vad-with-non-streaming-asr.py \
|
||||||
--silero-vad-model=/path/to/silero_vad.onnx \
|
--silero-vad-model=/path/to/silero_vad.onnx \
|
||||||
@@ -45,7 +56,7 @@ Note that you need a non-streaming model for this script.
|
|||||||
--whisper-task=transcribe \
|
--whisper-task=transcribe \
|
||||||
--num-threads=2
|
--num-threads=2
|
||||||
|
|
||||||
(4) For SenseVoice CTC models
|
(5) For SenseVoice CTC models
|
||||||
|
|
||||||
./python-api-examples/vad-with-non-streaming-asr.py \
|
./python-api-examples/vad-with-non-streaming-asr.py \
|
||||||
--silero-vad-model=/path/to/silero_vad.onnx \
|
--silero-vad-model=/path/to/silero_vad.onnx \
|
||||||
@@ -192,6 +203,34 @@ def get_args():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--moonshine-preprocessor",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to moonshine preprocessor model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--moonshine-encoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to moonshine encoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--moonshine-uncached-decoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to moonshine uncached decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--moonshine-cached-decoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to moonshine cached decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--blank-penalty",
|
"--blank-penalty",
|
||||||
type=float,
|
type=float,
|
||||||
@@ -251,6 +290,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
assert len(args.sense_voice) == 0, args.sense_voice
|
assert len(args.sense_voice) == 0, args.sense_voice
|
||||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
|
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||||
|
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||||
|
assert (
|
||||||
|
len(args.moonshine_uncached_decoder) == 0
|
||||||
|
), args.moonshine_uncached_decoder
|
||||||
|
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||||
|
|
||||||
assert_file_exists(args.encoder)
|
assert_file_exists(args.encoder)
|
||||||
assert_file_exists(args.decoder)
|
assert_file_exists(args.decoder)
|
||||||
@@ -272,6 +317,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
assert len(args.sense_voice) == 0, args.sense_voice
|
assert len(args.sense_voice) == 0, args.sense_voice
|
||||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
|
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||||
|
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||||
|
assert (
|
||||||
|
len(args.moonshine_uncached_decoder) == 0
|
||||||
|
), args.moonshine_uncached_decoder
|
||||||
|
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||||
|
|
||||||
assert_file_exists(args.paraformer)
|
assert_file_exists(args.paraformer)
|
||||||
|
|
||||||
@@ -287,6 +338,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
elif args.sense_voice:
|
elif args.sense_voice:
|
||||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
|
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||||
|
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||||
|
assert (
|
||||||
|
len(args.moonshine_uncached_decoder) == 0
|
||||||
|
), args.moonshine_uncached_decoder
|
||||||
|
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||||
|
|
||||||
assert_file_exists(args.sense_voice)
|
assert_file_exists(args.sense_voice)
|
||||||
recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
|
recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
|
||||||
@@ -299,6 +356,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
elif args.whisper_encoder:
|
elif args.whisper_encoder:
|
||||||
assert_file_exists(args.whisper_encoder)
|
assert_file_exists(args.whisper_encoder)
|
||||||
assert_file_exists(args.whisper_decoder)
|
assert_file_exists(args.whisper_decoder)
|
||||||
|
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
|
||||||
|
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
|
||||||
|
assert (
|
||||||
|
len(args.moonshine_uncached_decoder) == 0
|
||||||
|
), args.moonshine_uncached_decoder
|
||||||
|
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
|
||||||
|
|
||||||
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
|
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
|
||||||
encoder=args.whisper_encoder,
|
encoder=args.whisper_encoder,
|
||||||
@@ -311,6 +374,22 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
task=args.whisper_task,
|
task=args.whisper_task,
|
||||||
tail_paddings=args.whisper_tail_paddings,
|
tail_paddings=args.whisper_tail_paddings,
|
||||||
)
|
)
|
||||||
|
elif args.moonshine_preprocessor:
|
||||||
|
assert_file_exists(args.moonshine_preprocessor)
|
||||||
|
assert_file_exists(args.moonshine_encoder)
|
||||||
|
assert_file_exists(args.moonshine_uncached_decoder)
|
||||||
|
assert_file_exists(args.moonshine_cached_decoder)
|
||||||
|
|
||||||
|
recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
|
||||||
|
preprocessor=args.moonshine_preprocessor,
|
||||||
|
encoder=args.moonshine_encoder,
|
||||||
|
uncached_decoder=args.moonshine_uncached_decoder,
|
||||||
|
cached_decoder=args.moonshine_cached_decoder,
|
||||||
|
tokens=args.tokens,
|
||||||
|
num_threads=args.num_threads,
|
||||||
|
decoding_method=args.decoding_method,
|
||||||
|
debug=args.debug,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Please specify at least one model")
|
raise ValueError("Please specify at least one model")
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,9 @@ set(sources
|
|||||||
offline-lm-config.cc
|
offline-lm-config.cc
|
||||||
offline-lm.cc
|
offline-lm.cc
|
||||||
offline-model-config.cc
|
offline-model-config.cc
|
||||||
|
offline-moonshine-greedy-search-decoder.cc
|
||||||
|
offline-moonshine-model-config.cc
|
||||||
|
offline-moonshine-model.cc
|
||||||
offline-nemo-enc-dec-ctc-model-config.cc
|
offline-nemo-enc-dec-ctc-model-config.cc
|
||||||
offline-nemo-enc-dec-ctc-model.cc
|
offline-nemo-enc-dec-ctc-model.cc
|
||||||
offline-paraformer-greedy-search-decoder.cc
|
offline-paraformer-greedy-search-decoder.cc
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
|||||||
zipformer_ctc.Register(po);
|
zipformer_ctc.Register(po);
|
||||||
wenet_ctc.Register(po);
|
wenet_ctc.Register(po);
|
||||||
sense_voice.Register(po);
|
sense_voice.Register(po);
|
||||||
|
moonshine.Register(po);
|
||||||
|
|
||||||
po->Register("telespeech-ctc", &telespeech_ctc,
|
po->Register("telespeech-ctc", &telespeech_ctc,
|
||||||
"Path to model.onnx for telespeech ctc");
|
"Path to model.onnx for telespeech ctc");
|
||||||
@@ -99,6 +100,10 @@ bool OfflineModelConfig::Validate() const {
|
|||||||
return sense_voice.Validate();
|
return sense_voice.Validate();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!moonshine.preprocessor.empty()) {
|
||||||
|
return moonshine.Validate();
|
||||||
|
}
|
||||||
|
|
||||||
if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
|
if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
|
||||||
SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
|
SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
|
||||||
telespeech_ctc.c_str());
|
telespeech_ctc.c_str());
|
||||||
@@ -124,6 +129,7 @@ std::string OfflineModelConfig::ToString() const {
|
|||||||
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
|
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
|
||||||
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
|
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
|
||||||
os << "sense_voice=" << sense_voice.ToString() << ", ";
|
os << "sense_voice=" << sense_voice.ToString() << ", ";
|
||||||
|
os << "moonshine=" << moonshine.ToString() << ", ";
|
||||||
os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
|
os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
|
||||||
os << "tokens=\"" << tokens << "\", ";
|
os << "tokens=\"" << tokens << "\", ";
|
||||||
os << "num_threads=" << num_threads << ", ";
|
os << "num_threads=" << num_threads << ", ";
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
|
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h"
|
#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h"
|
||||||
@@ -26,6 +27,7 @@ struct OfflineModelConfig {
|
|||||||
OfflineZipformerCtcModelConfig zipformer_ctc;
|
OfflineZipformerCtcModelConfig zipformer_ctc;
|
||||||
OfflineWenetCtcModelConfig wenet_ctc;
|
OfflineWenetCtcModelConfig wenet_ctc;
|
||||||
OfflineSenseVoiceModelConfig sense_voice;
|
OfflineSenseVoiceModelConfig sense_voice;
|
||||||
|
OfflineMoonshineModelConfig moonshine;
|
||||||
std::string telespeech_ctc;
|
std::string telespeech_ctc;
|
||||||
|
|
||||||
std::string tokens;
|
std::string tokens;
|
||||||
@@ -56,6 +58,7 @@ struct OfflineModelConfig {
|
|||||||
const OfflineZipformerCtcModelConfig &zipformer_ctc,
|
const OfflineZipformerCtcModelConfig &zipformer_ctc,
|
||||||
const OfflineWenetCtcModelConfig &wenet_ctc,
|
const OfflineWenetCtcModelConfig &wenet_ctc,
|
||||||
const OfflineSenseVoiceModelConfig &sense_voice,
|
const OfflineSenseVoiceModelConfig &sense_voice,
|
||||||
|
const OfflineMoonshineModelConfig &moonshine,
|
||||||
const std::string &telespeech_ctc,
|
const std::string &telespeech_ctc,
|
||||||
const std::string &tokens, int32_t num_threads, bool debug,
|
const std::string &tokens, int32_t num_threads, bool debug,
|
||||||
const std::string &provider, const std::string &model_type,
|
const std::string &provider, const std::string &model_type,
|
||||||
@@ -69,6 +72,7 @@ struct OfflineModelConfig {
|
|||||||
zipformer_ctc(zipformer_ctc),
|
zipformer_ctc(zipformer_ctc),
|
||||||
wenet_ctc(wenet_ctc),
|
wenet_ctc(wenet_ctc),
|
||||||
sense_voice(sense_voice),
|
sense_voice(sense_voice),
|
||||||
|
moonshine(moonshine),
|
||||||
telespeech_ctc(telespeech_ctc),
|
telespeech_ctc(telespeech_ctc),
|
||||||
tokens(tokens),
|
tokens(tokens),
|
||||||
num_threads(num_threads),
|
num_threads(num_threads),
|
||||||
|
|||||||
34
sherpa-onnx/csrc/offline-moonshine-decoder.h
Normal file
34
sherpa-onnx/csrc/offline-moonshine-decoder.h
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-moonshine-decoder.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct OfflineMoonshineDecoderResult {
|
||||||
|
/// The decoded token IDs
|
||||||
|
std::vector<int32_t> tokens;
|
||||||
|
};
|
||||||
|
|
||||||
|
class OfflineMoonshineDecoder {
|
||||||
|
public:
|
||||||
|
virtual ~OfflineMoonshineDecoder() = default;
|
||||||
|
|
||||||
|
/** Run beam search given the output from the moonshine encoder model.
|
||||||
|
*
|
||||||
|
* @param encoder_out A 3-D tensor of shape (batch_size, T, dim)
|
||||||
|
* @return Return a vector of size `N` containing the decoded results.
|
||||||
|
*/
|
||||||
|
virtual std::vector<OfflineMoonshineDecoderResult> Decode(
|
||||||
|
Ort::Value encoder_out) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_
|
||||||
87
sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.cc
Normal file
87
sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.cc
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
std::vector<OfflineMoonshineDecoderResult>
|
||||||
|
OfflineMoonshineGreedySearchDecoder::Decode(Ort::Value encoder_out) {
|
||||||
|
auto encoder_out_shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
if (encoder_out_shape[0] != 1) {
|
||||||
|
SHERPA_ONNX_LOGE("Support only batch size == 1. Given: %d\n",
|
||||||
|
static_cast<int32_t>(encoder_out_shape[0]));
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
// encoder_out_shape[1] * 384 is the number of audio samples
|
||||||
|
// 16000 is the sample rate
|
||||||
|
//
|
||||||
|
//
|
||||||
|
// 384 is from the moonshine paper
|
||||||
|
int32_t max_len =
|
||||||
|
static_cast<int32_t>(encoder_out_shape[1] * 384 / 16000.0 * 6);
|
||||||
|
|
||||||
|
int32_t sos = 1;
|
||||||
|
int32_t eos = 2;
|
||||||
|
int32_t seq_len = 1;
|
||||||
|
|
||||||
|
std::vector<int32_t> tokens;
|
||||||
|
|
||||||
|
std::array<int64_t, 2> token_shape = {1, 1};
|
||||||
|
int64_t seq_len_shape = 1;
|
||||||
|
|
||||||
|
Ort::Value token_tensor = Ort::Value::CreateTensor(
|
||||||
|
memory_info, &sos, 1, token_shape.data(), token_shape.size());
|
||||||
|
|
||||||
|
Ort::Value seq_len_tensor =
|
||||||
|
Ort::Value::CreateTensor(memory_info, &seq_len, 1, &seq_len_shape, 1);
|
||||||
|
|
||||||
|
Ort::Value logits{nullptr};
|
||||||
|
std::vector<Ort::Value> states;
|
||||||
|
|
||||||
|
std::tie(logits, states) = model_->ForwardUnCachedDecoder(
|
||||||
|
std::move(token_tensor), std::move(seq_len_tensor), View(&encoder_out));
|
||||||
|
|
||||||
|
int32_t vocab_size = logits.GetTensorTypeAndShapeInfo().GetShape()[2];
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != max_len; ++i) {
|
||||||
|
const float *p = logits.GetTensorData<float>();
|
||||||
|
|
||||||
|
int32_t max_token_id = static_cast<int32_t>(
|
||||||
|
std::distance(p, std::max_element(p, p + vocab_size)));
|
||||||
|
if (max_token_id == eos) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
tokens.push_back(max_token_id);
|
||||||
|
|
||||||
|
seq_len += 1;
|
||||||
|
|
||||||
|
token_tensor = Ort::Value::CreateTensor(
|
||||||
|
memory_info, &tokens.back(), 1, token_shape.data(), token_shape.size());
|
||||||
|
|
||||||
|
seq_len_tensor =
|
||||||
|
Ort::Value::CreateTensor(memory_info, &seq_len, 1, &seq_len_shape, 1);
|
||||||
|
|
||||||
|
std::tie(logits, states) = model_->ForwardCachedDecoder(
|
||||||
|
std::move(token_tensor), std::move(seq_len_tensor), View(&encoder_out),
|
||||||
|
std::move(states));
|
||||||
|
}
|
||||||
|
|
||||||
|
OfflineMoonshineDecoderResult ans;
|
||||||
|
ans.tokens = std::move(tokens);
|
||||||
|
|
||||||
|
return {ans};
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
29
sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h
Normal file
29
sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-moonshine-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-moonshine-model.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OfflineMoonshineGreedySearchDecoder : public OfflineMoonshineDecoder {
|
||||||
|
public:
|
||||||
|
explicit OfflineMoonshineGreedySearchDecoder(OfflineMoonshineModel *model)
|
||||||
|
: model_(model) {}
|
||||||
|
|
||||||
|
std::vector<OfflineMoonshineDecoderResult> Decode(
|
||||||
|
Ort::Value encoder_out) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
OfflineMoonshineModel *model_; // not owned
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_
|
||||||
88
sherpa-onnx/csrc/offline-moonshine-model-config.cc
Normal file
88
sherpa-onnx/csrc/offline-moonshine-model-config.cc
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-moonshine-model-config.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void OfflineMoonshineModelConfig::Register(ParseOptions *po) {
|
||||||
|
po->Register("moonshine-preprocessor", &preprocessor,
|
||||||
|
"Path to onnx preprocessor of moonshine, e.g., preprocess.onnx");
|
||||||
|
|
||||||
|
po->Register("moonshine-encoder", &encoder,
|
||||||
|
"Path to onnx encoder of moonshine, e.g., encode.onnx");
|
||||||
|
|
||||||
|
po->Register(
|
||||||
|
"moonshine-uncached-decoder", &uncached_decoder,
|
||||||
|
"Path to onnx uncached_decoder of moonshine, e.g., uncached_decode.onnx");
|
||||||
|
|
||||||
|
po->Register(
|
||||||
|
"moonshine-cached-decoder", &cached_decoder,
|
||||||
|
"Path to onnx cached_decoder of moonshine, e.g., cached_decode.onnx");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool OfflineMoonshineModelConfig::Validate() const {
|
||||||
|
if (preprocessor.empty()) {
|
||||||
|
SHERPA_ONNX_LOGE("Please provide --moonshine-preprocessor");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!FileExists(preprocessor)) {
|
||||||
|
SHERPA_ONNX_LOGE("moonshine preprocessor file '%s' does not exist",
|
||||||
|
preprocessor.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (encoder.empty()) {
|
||||||
|
SHERPA_ONNX_LOGE("Please provide --moonshine-encoder");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!FileExists(encoder)) {
|
||||||
|
SHERPA_ONNX_LOGE("moonshine encoder file '%s' does not exist",
|
||||||
|
encoder.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (uncached_decoder.empty()) {
|
||||||
|
SHERPA_ONNX_LOGE("Please provide --moonshine-uncached-decoder");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!FileExists(uncached_decoder)) {
|
||||||
|
SHERPA_ONNX_LOGE("moonshine uncached decoder file '%s' does not exist",
|
||||||
|
uncached_decoder.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cached_decoder.empty()) {
|
||||||
|
SHERPA_ONNX_LOGE("Please provide --moonshine-cached-decoder");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!FileExists(cached_decoder)) {
|
||||||
|
SHERPA_ONNX_LOGE("moonshine cached decoder file '%s' does not exist",
|
||||||
|
cached_decoder.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string OfflineMoonshineModelConfig::ToString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
|
||||||
|
os << "OfflineMoonshineModelConfig(";
|
||||||
|
os << "preprocessor=\"" << preprocessor << "\", ";
|
||||||
|
os << "encoder=\"" << encoder << "\", ";
|
||||||
|
os << "uncached_decoder=\"" << uncached_decoder << "\", ";
|
||||||
|
os << "cached_decoder=\"" << cached_decoder << "\")";
|
||||||
|
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
37
sherpa-onnx/csrc/offline-moonshine-model-config.h
Normal file
37
sherpa-onnx/csrc/offline-moonshine-model-config.h
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-moonshine-model-config.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct OfflineMoonshineModelConfig {
|
||||||
|
std::string preprocessor;
|
||||||
|
std::string encoder;
|
||||||
|
std::string uncached_decoder;
|
||||||
|
std::string cached_decoder;
|
||||||
|
|
||||||
|
OfflineMoonshineModelConfig() = default;
|
||||||
|
OfflineMoonshineModelConfig(const std::string &preprocessor,
|
||||||
|
const std::string &encoder,
|
||||||
|
const std::string &uncached_decoder,
|
||||||
|
const std::string &cached_decoder)
|
||||||
|
: preprocessor(preprocessor),
|
||||||
|
encoder(encoder),
|
||||||
|
uncached_decoder(uncached_decoder),
|
||||||
|
cached_decoder(cached_decoder) {}
|
||||||
|
|
||||||
|
void Register(ParseOptions *po);
|
||||||
|
bool Validate() const;
|
||||||
|
|
||||||
|
std::string ToString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
|
||||||
282
sherpa-onnx/csrc/offline-moonshine-model.cc
Normal file
282
sherpa-onnx/csrc/offline-moonshine-model.cc
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-moonshine-model.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-moonshine-model.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/session.h"
|
||||||
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Pimpl for OfflineMoonshineModel. Owns the four onnxruntime sessions that
// make up a Moonshine model (preprocessor, encoder, uncached decoder,
// cached decoder) together with their cached input/output tensor names.
class OfflineMoonshineModel::Impl {
 public:
  // Load all four model files from disk paths given in config.moonshine.
  explicit Impl(const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(config.moonshine.preprocessor);
      InitPreprocessor(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(config.moonshine.encoder);
      InitEncoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(config.moonshine.uncached_decoder);
      InitUnCachedDecoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(config.moonshine.cached_decoder);
      InitCachedDecoder(buf.data(), buf.size());
    }
  }

#if __ANDROID_API__ >= 9
  // Same as above, but reads the model files from the Android asset manager.
  Impl(AAssetManager *mgr, const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(mgr, config.moonshine.preprocessor);
      InitPreprocessor(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(mgr, config.moonshine.encoder);
      InitEncoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(mgr, config.moonshine.uncached_decoder);
      InitUnCachedDecoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(mgr, config.moonshine.cached_decoder);
      InitCachedDecoder(buf.data(), buf.size());
    }
  }
#endif

  // Run the preprocessor session on raw audio; returns its first output
  // (the feature tensor fed to the encoder).
  Ort::Value ForwardPreprocessor(Ort::Value audio) {
    auto features = preprocessor_sess_->Run(
        {}, preprocessor_input_names_ptr_.data(), &audio, 1,
        preprocessor_output_names_ptr_.data(),
        preprocessor_output_names_ptr_.size());

    return std::move(features[0]);
  }

  // Run the encoder session; returns its first output (encoder_out).
  Ort::Value ForwardEncoder(Ort::Value features, Ort::Value features_len) {
    std::array<Ort::Value, 2> encoder_inputs{std::move(features),
                                             std::move(features_len)};
    auto encoder_out = encoder_sess_->Run(
        {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
        encoder_inputs.size(), encoder_output_names_ptr_.data(),
        encoder_output_names_ptr_.size());

    return std::move(encoder_out[0]);
  }

  // Run the uncached (first-step) decoder. Note the model's input order is
  // (tokens, encoder_out, seq_len) even though the C++ parameter order
  // differs. Returns {logits, states}, where states are all session outputs
  // after the first.
  std::pair<Ort::Value, std::vector<Ort::Value>> ForwardUnCachedDecoder(
      Ort::Value tokens, Ort::Value seq_len, Ort::Value encoder_out) {
    std::array<Ort::Value, 3> uncached_decoder_input = {
        std::move(tokens),
        std::move(encoder_out),
        std::move(seq_len),
    };

    auto uncached_decoder_out = uncached_decoder_sess_->Run(
        {}, uncached_decoder_input_names_ptr_.data(),
        uncached_decoder_input.data(), uncached_decoder_input.size(),
        uncached_decoder_output_names_ptr_.data(),
        uncached_decoder_output_names_ptr_.size());

    std::vector<Ort::Value> states;
    states.reserve(uncached_decoder_out.size() - 1);

    // Skip output 0 (logits); collect the remaining outputs as states.
    int32_t i = -1;
    for (auto &s : uncached_decoder_out) {
      ++i;
      if (i == 0) {
        continue;
      }

      states.push_back(std::move(s));
    }

    return {std::move(uncached_decoder_out[0]), std::move(states)};
  }

  // Run the cached (subsequent-step) decoder with the previous states
  // appended after (tokens, encoder_out, seq_len). Returns
  // {logits, next_states}.
  std::pair<Ort::Value, std::vector<Ort::Value>> ForwardCachedDecoder(
      Ort::Value tokens, Ort::Value seq_len, Ort::Value encoder_out,
      std::vector<Ort::Value> states) {
    std::vector<Ort::Value> cached_decoder_input;
    cached_decoder_input.reserve(3 + states.size());
    cached_decoder_input.push_back(std::move(tokens));
    cached_decoder_input.push_back(std::move(encoder_out));
    cached_decoder_input.push_back(std::move(seq_len));

    for (auto &s : states) {
      cached_decoder_input.push_back(std::move(s));
    }

    auto cached_decoder_out = cached_decoder_sess_->Run(
        {}, cached_decoder_input_names_ptr_.data(), cached_decoder_input.data(),
        cached_decoder_input.size(), cached_decoder_output_names_ptr_.data(),
        cached_decoder_output_names_ptr_.size());

    std::vector<Ort::Value> next_states;
    next_states.reserve(cached_decoder_out.size() - 1);

    // Skip output 0 (logits); collect the remaining outputs as next states.
    int32_t i = -1;
    for (auto &s : cached_decoder_out) {
      ++i;
      if (i == 0) {
        continue;
      }

      next_states.push_back(std::move(s));
    }

    return {std::move(cached_decoder_out[0]), std::move(next_states)};
  }

  OrtAllocator *Allocator() const { return allocator_; }

 private:
  // The Init* helpers below create one session each and cache its
  // input/output names (both as std::string storage and as const char*
  // views, since Ort::Session::Run() takes const char* const*).
  void InitPreprocessor(void *model_data, size_t model_data_length) {
    preprocessor_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(preprocessor_sess_.get(), &preprocessor_input_names_,
                  &preprocessor_input_names_ptr_);

    GetOutputNames(preprocessor_sess_.get(), &preprocessor_output_names_,
                   &preprocessor_output_names_ptr_);
  }

  void InitEncoder(void *model_data, size_t model_data_length) {
    encoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(encoder_sess_.get(), &encoder_input_names_,
                  &encoder_input_names_ptr_);

    GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
                   &encoder_output_names_ptr_);
  }

  void InitUnCachedDecoder(void *model_data, size_t model_data_length) {
    uncached_decoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(uncached_decoder_sess_.get(), &uncached_decoder_input_names_,
                  &uncached_decoder_input_names_ptr_);

    GetOutputNames(uncached_decoder_sess_.get(),
                   &uncached_decoder_output_names_,
                   &uncached_decoder_output_names_ptr_);
  }

  void InitCachedDecoder(void *model_data, size_t model_data_length) {
    cached_decoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(cached_decoder_sess_.get(), &cached_decoder_input_names_,
                  &cached_decoder_input_names_ptr_);

    GetOutputNames(cached_decoder_sess_.get(), &cached_decoder_output_names_,
                   &cached_decoder_output_names_ptr_);
  }

 private:
  OfflineModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> preprocessor_sess_;
  std::unique_ptr<Ort::Session> encoder_sess_;
  std::unique_ptr<Ort::Session> uncached_decoder_sess_;
  std::unique_ptr<Ort::Session> cached_decoder_sess_;

  std::vector<std::string> preprocessor_input_names_;
  std::vector<const char *> preprocessor_input_names_ptr_;

  std::vector<std::string> preprocessor_output_names_;
  std::vector<const char *> preprocessor_output_names_ptr_;

  std::vector<std::string> encoder_input_names_;
  std::vector<const char *> encoder_input_names_ptr_;

  std::vector<std::string> encoder_output_names_;
  std::vector<const char *> encoder_output_names_ptr_;

  std::vector<std::string> uncached_decoder_input_names_;
  std::vector<const char *> uncached_decoder_input_names_ptr_;

  std::vector<std::string> uncached_decoder_output_names_;
  std::vector<const char *> uncached_decoder_output_names_ptr_;

  std::vector<std::string> cached_decoder_input_names_;
  std::vector<const char *> cached_decoder_input_names_ptr_;

  std::vector<std::string> cached_decoder_output_names_;
  std::vector<const char *> cached_decoder_output_names_ptr_;
};
|
||||||
|
|
||||||
|
// Construct from model files on disk.
OfflineMoonshineModel::OfflineMoonshineModel(const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

#if __ANDROID_API__ >= 9
// Construct from model files packaged as Android assets.
OfflineMoonshineModel::OfflineMoonshineModel(AAssetManager *mgr,
                                             const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(mgr, config)) {}
#endif

// Out-of-line so that Impl is a complete type at the point of destruction.
OfflineMoonshineModel::~OfflineMoonshineModel() = default;
|
||||||
|
|
||||||
|
// The public methods below simply forward to the pimpl; see
// offline-moonshine-model.h for the documented contract of each one.

Ort::Value OfflineMoonshineModel::ForwardPreprocessor(Ort::Value audio) const {
  return impl_->ForwardPreprocessor(std::move(audio));
}

Ort::Value OfflineMoonshineModel::ForwardEncoder(
    Ort::Value features, Ort::Value features_len) const {
  return impl_->ForwardEncoder(std::move(features), std::move(features_len));
}

std::pair<Ort::Value, std::vector<Ort::Value>>
OfflineMoonshineModel::ForwardUnCachedDecoder(Ort::Value token,
                                              Ort::Value seq_len,
                                              Ort::Value encoder_out) const {
  return impl_->ForwardUnCachedDecoder(std::move(token), std::move(seq_len),
                                       std::move(encoder_out));
}

std::pair<Ort::Value, std::vector<Ort::Value>>
OfflineMoonshineModel::ForwardCachedDecoder(
    Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out,
    std::vector<Ort::Value> states) const {
  return impl_->ForwardCachedDecoder(std::move(token), std::move(seq_len),
                                     std::move(encoder_out), std::move(states));
}

OrtAllocator *OfflineMoonshineModel::Allocator() const {
  return impl_->Allocator();
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
93
sherpa-onnx/csrc/offline-moonshine-model.h
Normal file
93
sherpa-onnx/csrc/offline-moonshine-model.h
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-moonshine-model.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// please see
|
||||||
|
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/moonshine/test.py
|
||||||
|
// Wrapper around the four ONNX sessions of a Moonshine ASR model
// (preprocessor, encoder, uncached decoder, cached decoder).
class OfflineMoonshineModel {
 public:
  explicit OfflineMoonshineModel(const OfflineModelConfig &config);

#if __ANDROID_API__ >= 9
  OfflineMoonshineModel(AAssetManager *mgr, const OfflineModelConfig &config);
#endif

  ~OfflineMoonshineModel();

  /** Run the preprocessor model.
   *
   * @param audio A float32 tensor of shape (batch_size, num_samples)
   *
   * @return Return a float32 tensor of shape (batch_size, T, dim) that
   *         can be used as the input of ForwardEncoder()
   */
  Ort::Value ForwardPreprocessor(Ort::Value audio) const;

  /** Run the encoder model.
   *
   * @param features A float32 tensor of shape (batch_size, T, dim)
   * @param features_len A int32 tensor of shape (batch_size,)
   * @returns A float32 tensor of shape (batch_size, T, dim).
   */
  Ort::Value ForwardEncoder(Ort::Value features, Ort::Value features_len) const;

  /** Run the uncached decoder (used for the first decoding step).
   *
   * @param token A int32 tensor of shape (batch_size, num_tokens)
   * @param seq_len A int32 tensor of shape (batch_size,) containing number
   *                of predicted tokens so far
   * @param encoder_out A float32 tensor of shape (batch_size, T, dim)
   *
   * @returns Return a pair:
   *
   *          - logits, a float32 tensor of shape (batch_size, 1, dim)
   *          - states, a list of states
   */
  std::pair<Ort::Value, std::vector<Ort::Value>> ForwardUnCachedDecoder(
      Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out) const;

  /** Run the cached decoder (used for subsequent decoding steps).
   *
   * @param token A int32 tensor of shape (batch_size, num_tokens)
   * @param seq_len A int32 tensor of shape (batch_size,) containing number
   *                of predicted tokens so far
   * @param encoder_out A float32 tensor of shape (batch_size, T, dim)
   * @param states A list of previous states
   *
   * @returns Return a pair:
   *          - logits, a float32 tensor of shape (batch_size, 1, dim)
   *          - states, a list of new states
   */
  std::pair<Ort::Value, std::vector<Ort::Value>> ForwardCachedDecoder(
      Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out,
      std::vector<Ort::Value> states) const;

  /** Return an allocator for allocating memory
   */
  OrtAllocator *Allocator() const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_
|
||||||
@@ -20,6 +20,7 @@
|
|||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
|
||||||
@@ -51,6 +52,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!config.model_config.moonshine.preprocessor.empty()) {
|
||||||
|
return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(fangjun): Refactor it. We only need to use model type for the
|
// TODO(fangjun): Refactor it. We only need to use model type for the
|
||||||
// following models:
|
// following models:
|
||||||
// 1. transducer and nemo_transducer
|
// 1. transducer and nemo_transducer
|
||||||
@@ -67,7 +72,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
model_type == "telespeech_ctc") {
|
model_type == "telespeech_ctc") {
|
||||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||||
} else if (model_type == "whisper") {
|
} else if (model_type == "whisper") {
|
||||||
|
// unreachable
|
||||||
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
||||||
|
} else if (model_type == "moonshine") {
|
||||||
|
// unreachable
|
||||||
|
return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"Invalid model_type: %s. Trying to load the model to get its type",
|
"Invalid model_type: %s. Trying to load the model to get its type",
|
||||||
@@ -225,6 +234,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
|
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!config.model_config.moonshine.preprocessor.empty()) {
|
||||||
|
return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(fangjun): Refactor it. We only need to use model type for the
|
// TODO(fangjun): Refactor it. We only need to use model type for the
|
||||||
// following models:
|
// following models:
|
||||||
// 1. transducer and nemo_transducer
|
// 1. transducer and nemo_transducer
|
||||||
@@ -242,6 +255,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
|
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
|
||||||
} else if (model_type == "whisper") {
|
} else if (model_type == "whisper") {
|
||||||
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
|
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
|
||||||
|
} else if (model_type == "moonshine") {
|
||||||
|
return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"Invalid model_type: %s. Trying to load the model to get its type",
|
"Invalid model_type: %s. Trying to load the model to get its type",
|
||||||
|
|||||||
150
sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h
Normal file
150
sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-moonshine-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-moonshine-model.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||||
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
|
#include "sherpa-onnx/csrc/transpose.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Map a Moonshine decoder result (a sequence of token IDs) to an
// OfflineRecognitionResult: token IDs are looked up in the symbol table,
// unknown IDs are dropped, and the recognized pieces are concatenated into
// the final text.
static OfflineRecognitionResult Convert(
    const OfflineMoonshineDecoderResult &src, const SymbolTable &sym_table) {
  OfflineRecognitionResult ans;
  ans.tokens.reserve(src.tokens.size());

  for (auto token_id : src.tokens) {
    if (sym_table.Contains(token_id)) {
      const auto &piece = sym_table[token_id];
      ans.text += piece;
      ans.tokens.push_back(piece);
    }
  }

  return ans;
}
|
||||||
|
|
||||||
|
// Offline recognizer backed by a Moonshine model. Only greedy search is
// supported; streams carry raw audio samples (MoonshineTag) instead of
// precomputed fbank features, because the model has its own preprocessor.
class OfflineRecognizerMoonshineImpl : public OfflineRecognizerImpl {
 public:
  explicit OfflineRecognizerMoonshineImpl(const OfflineRecognizerConfig &config)
      : OfflineRecognizerImpl(config),
        config_(config),
        symbol_table_(config_.model_config.tokens),
        model_(std::make_unique<OfflineMoonshineModel>(config.model_config)) {
    Init();
  }

#if __ANDROID_API__ >= 9
  OfflineRecognizerMoonshineImpl(AAssetManager *mgr,
                                 const OfflineRecognizerConfig &config)
      : OfflineRecognizerImpl(mgr, config),
        config_(config),
        symbol_table_(mgr, config_.model_config.tokens),
        model_(
            std::make_unique<OfflineMoonshineModel>(mgr, config.model_config)) {
    Init();
  }

#endif

  // Validate the decoding method and create the decoder. Exits the process
  // for any method other than greedy_search.
  void Init() {
    if (config_.decoding_method == "greedy_search") {
      decoder_ =
          std::make_unique<OfflineMoonshineGreedySearchDecoder>(model_.get());
    } else {
      SHERPA_ONNX_LOGE(
          "Only greedy_search is supported at present for moonshine. Given %s",
          config_.decoding_method.c_str());
      exit(-1);
    }
  }

  std::unique_ptr<OfflineStream> CreateStream() const override {
    MoonshineTag tag;
    return std::make_unique<OfflineStream>(tag);
  }

  void DecodeStreams(OfflineStream **ss, int32_t n) const override {
    // batch decoding is not implemented yet
    for (int32_t i = 0; i != n; ++i) {
      DecodeStream(ss[i]);
    }
  }

  OfflineRecognizerConfig GetConfig() const override { return config_; }

 private:
  // Decode a single stream: raw samples -> preprocessor -> encoder ->
  // greedy-search decoder -> text. On an onnxruntime exception the stream's
  // result is left untouched (an empty result).
  void DecodeStream(OfflineStream *s) const {
    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    // For Moonshine streams GetFrames() returns raw audio samples,
    // not feature frames.
    std::vector<float> audio = s->GetFrames();

    try {
      std::array<int64_t, 2> shape{1, static_cast<int64_t>(audio.size())};

      Ort::Value audio_tensor = Ort::Value::CreateTensor(
          memory_info, audio.data(), audio.size(), shape.data(), shape.size());

      Ort::Value features =
          model_->ForwardPreprocessor(std::move(audio_tensor));

      // Number of feature frames produced by the preprocessor (axis 1).
      int32_t features_len = features.GetTensorTypeAndShapeInfo().GetShape()[1];

      int64_t features_shape = 1;

      Ort::Value features_len_tensor = Ort::Value::CreateTensor(
          memory_info, &features_len, 1, &features_shape, 1);

      Ort::Value encoder_out = model_->ForwardEncoder(
          std::move(features), std::move(features_len_tensor));

      auto results = decoder_->Decode(std::move(encoder_out));

      auto r = Convert(results[0], symbol_table_);
      r.text = ApplyInverseTextNormalization(std::move(r.text));
      s->SetResult(r);
    } catch (const Ort::Exception &ex) {
      SHERPA_ONNX_LOGE(
          "\n\nCaught exception:\n\n%s\n\nReturn an empty result. Number of "
          "audio samples: %d",
          ex.what(), static_cast<int32_t>(audio.size()));
      return;
    }
  }

 private:
  OfflineRecognizerConfig config_;
  SymbolTable symbol_table_;
  std::unique_ptr<OfflineMoonshineModel> model_;
  std::unique_ptr<OfflineMoonshineDecoder> decoder_;
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_
|
||||||
@@ -133,6 +133,10 @@ class OfflineStream::Impl {
|
|||||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
explicit Impl(MoonshineTag /*tag*/) : is_moonshine_(true) {
|
||||||
|
config_.sampling_rate = 16000;
|
||||||
|
}
|
||||||
|
|
||||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
||||||
if (config_.normalize_samples) {
|
if (config_.normalize_samples) {
|
||||||
AcceptWaveformImpl(sampling_rate, waveform, n);
|
AcceptWaveformImpl(sampling_rate, waveform, n);
|
||||||
@@ -164,7 +168,9 @@ class OfflineStream::Impl {
|
|||||||
std::vector<float> samples;
|
std::vector<float> samples;
|
||||||
resampler->Resample(waveform, n, true, &samples);
|
resampler->Resample(waveform, n, true, &samples);
|
||||||
|
|
||||||
if (fbank_) {
|
if (is_moonshine_) {
|
||||||
|
samples_.insert(samples_.end(), samples.begin(), samples.end());
|
||||||
|
} else if (fbank_) {
|
||||||
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
|
||||||
samples.size());
|
samples.size());
|
||||||
fbank_->InputFinished();
|
fbank_->InputFinished();
|
||||||
@@ -181,7 +187,9 @@ class OfflineStream::Impl {
|
|||||||
return;
|
return;
|
||||||
} // if (sampling_rate != config_.sampling_rate)
|
} // if (sampling_rate != config_.sampling_rate)
|
||||||
|
|
||||||
if (fbank_) {
|
if (is_moonshine_) {
|
||||||
|
samples_.insert(samples_.end(), waveform, waveform + n);
|
||||||
|
} else if (fbank_) {
|
||||||
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||||
fbank_->InputFinished();
|
fbank_->InputFinished();
|
||||||
} else if (mfcc_) {
|
} else if (mfcc_) {
|
||||||
@@ -194,10 +202,18 @@ class OfflineStream::Impl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int32_t FeatureDim() const {
|
int32_t FeatureDim() const {
|
||||||
|
if (is_moonshine_) {
|
||||||
|
return samples_.size();
|
||||||
|
}
|
||||||
|
|
||||||
return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins;
|
return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<float> GetFrames() const {
|
std::vector<float> GetFrames() const {
|
||||||
|
if (is_moonshine_) {
|
||||||
|
return samples_;
|
||||||
|
}
|
||||||
|
|
||||||
int32_t n = fbank_ ? fbank_->NumFramesReady()
|
int32_t n = fbank_ ? fbank_->NumFramesReady()
|
||||||
: mfcc_ ? mfcc_->NumFramesReady()
|
: mfcc_ ? mfcc_->NumFramesReady()
|
||||||
: whisper_fbank_->NumFramesReady();
|
: whisper_fbank_->NumFramesReady();
|
||||||
@@ -300,6 +316,10 @@ class OfflineStream::Impl {
|
|||||||
OfflineRecognitionResult r_;
|
OfflineRecognitionResult r_;
|
||||||
ContextGraphPtr context_graph_;
|
ContextGraphPtr context_graph_;
|
||||||
bool is_ced_ = false;
|
bool is_ced_ = false;
|
||||||
|
bool is_moonshine_ = false;
|
||||||
|
|
||||||
|
// used only when is_moonshine_== true
|
||||||
|
std::vector<float> samples_;
|
||||||
};
|
};
|
||||||
|
|
||||||
OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/,
|
OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/,
|
||||||
@@ -311,6 +331,9 @@ OfflineStream::OfflineStream(WhisperTag tag)
|
|||||||
|
|
||||||
OfflineStream::OfflineStream(CEDTag tag) : impl_(std::make_unique<Impl>(tag)) {}
|
OfflineStream::OfflineStream(CEDTag tag) : impl_(std::make_unique<Impl>(tag)) {}
|
||||||
|
|
||||||
|
OfflineStream::OfflineStream(MoonshineTag tag)
|
||||||
|
: impl_(std::make_unique<Impl>(tag)) {}
|
||||||
|
|
||||||
OfflineStream::~OfflineStream() = default;
|
OfflineStream::~OfflineStream() = default;
|
||||||
|
|
||||||
void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
|
void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ struct OfflineRecognitionResult {
|
|||||||
// event target of the audio.
|
// event target of the audio.
|
||||||
std::string event;
|
std::string event;
|
||||||
|
|
||||||
/// timestamps.size() == tokens.size()
|
/// timestamps.size() == tokens.size()
|
||||||
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
|
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
|
||||||
std::vector<float> timestamps;
|
std::vector<float> timestamps;
|
||||||
|
|
||||||
@@ -49,6 +49,10 @@ struct WhisperTag {
|
|||||||
|
|
||||||
struct CEDTag {};
|
struct CEDTag {};
|
||||||
|
|
||||||
|
// It uses a neural network model, a preprocessor, to convert
|
||||||
|
// audio samples to features
|
||||||
|
struct MoonshineTag {};
|
||||||
|
|
||||||
class OfflineStream {
|
class OfflineStream {
|
||||||
public:
|
public:
|
||||||
explicit OfflineStream(const FeatureExtractorConfig &config = {},
|
explicit OfflineStream(const FeatureExtractorConfig &config = {},
|
||||||
@@ -56,6 +60,7 @@ class OfflineStream {
|
|||||||
|
|
||||||
explicit OfflineStream(WhisperTag tag);
|
explicit OfflineStream(WhisperTag tag);
|
||||||
explicit OfflineStream(CEDTag tag);
|
explicit OfflineStream(CEDTag tag);
|
||||||
|
explicit OfflineStream(MoonshineTag tag);
|
||||||
~OfflineStream();
|
~OfflineStream();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -72,7 +77,10 @@ class OfflineStream {
|
|||||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform,
|
void AcceptWaveform(int32_t sampling_rate, const float *waveform,
|
||||||
int32_t n) const;
|
int32_t n) const;
|
||||||
|
|
||||||
/// Return feature dim of this extractor
|
/// Return feature dim of this extractor.
|
||||||
|
///
|
||||||
|
/// Note: if it is Moonshine, then it returns the number of audio samples
|
||||||
|
/// currently received.
|
||||||
int32_t FeatureDim() const;
|
int32_t FeatureDim() const;
|
||||||
|
|
||||||
// Get all the feature frames of this stream in a 1-D array, which is
|
// Get all the feature frames of this stream in a 1-D array, which is
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ class OfflineWhisperModel::Impl {
|
|||||||
explicit Impl(const OfflineModelConfig &config)
|
explicit Impl(const OfflineModelConfig &config)
|
||||||
: config_(config),
|
: config_(config),
|
||||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||||
debug_(config.debug),
|
|
||||||
sess_opts_(GetSessionOptions(config)),
|
sess_opts_(GetSessionOptions(config)),
|
||||||
allocator_{} {
|
allocator_{} {
|
||||||
{
|
{
|
||||||
@@ -40,7 +39,6 @@ class OfflineWhisperModel::Impl {
|
|||||||
explicit Impl(const SpokenLanguageIdentificationConfig &config)
|
explicit Impl(const SpokenLanguageIdentificationConfig &config)
|
||||||
: lid_config_(config),
|
: lid_config_(config),
|
||||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||||
debug_(config_.debug),
|
|
||||||
sess_opts_(GetSessionOptions(config)),
|
sess_opts_(GetSessionOptions(config)),
|
||||||
allocator_{} {
|
allocator_{} {
|
||||||
{
|
{
|
||||||
@@ -60,7 +58,6 @@ class OfflineWhisperModel::Impl {
|
|||||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||||
sess_opts_(GetSessionOptions(config)),
|
sess_opts_(GetSessionOptions(config)),
|
||||||
allocator_{} {
|
allocator_{} {
|
||||||
debug_ = config_.debug;
|
|
||||||
{
|
{
|
||||||
auto buf = ReadFile(mgr, config.whisper.encoder);
|
auto buf = ReadFile(mgr, config.whisper.encoder);
|
||||||
InitEncoder(buf.data(), buf.size());
|
InitEncoder(buf.data(), buf.size());
|
||||||
@@ -77,7 +74,6 @@ class OfflineWhisperModel::Impl {
|
|||||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||||
sess_opts_(GetSessionOptions(config)),
|
sess_opts_(GetSessionOptions(config)),
|
||||||
allocator_{} {
|
allocator_{} {
|
||||||
debug_ = config_.debug;
|
|
||||||
{
|
{
|
||||||
auto buf = ReadFile(mgr, config.whisper.encoder);
|
auto buf = ReadFile(mgr, config.whisper.encoder);
|
||||||
InitEncoder(buf.data(), buf.size());
|
InitEncoder(buf.data(), buf.size());
|
||||||
@@ -164,7 +160,7 @@ class OfflineWhisperModel::Impl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (debug_) {
|
if (config_.debug) {
|
||||||
SHERPA_ONNX_LOGE("Detected language: %s",
|
SHERPA_ONNX_LOGE("Detected language: %s",
|
||||||
GetID2Lang().at(lang_id).c_str());
|
GetID2Lang().at(lang_id).c_str());
|
||||||
}
|
}
|
||||||
@@ -237,7 +233,7 @@ class OfflineWhisperModel::Impl {
|
|||||||
|
|
||||||
// get meta data
|
// get meta data
|
||||||
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
|
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
|
||||||
if (debug_) {
|
if (config_.debug) {
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
os << "---encoder---\n";
|
os << "---encoder---\n";
|
||||||
PrintModelMetadata(os, meta_data);
|
PrintModelMetadata(os, meta_data);
|
||||||
@@ -294,7 +290,6 @@ class OfflineWhisperModel::Impl {
|
|||||||
private:
|
private:
|
||||||
OfflineModelConfig config_;
|
OfflineModelConfig config_;
|
||||||
SpokenLanguageIdentificationConfig lid_config_;
|
SpokenLanguageIdentificationConfig lid_config_;
|
||||||
bool debug_ = false;
|
|
||||||
Ort::Env env_;
|
Ort::Env env_;
|
||||||
Ort::SessionOptions sess_opts_;
|
Ort::SessionOptions sess_opts_;
|
||||||
Ort::AllocatorWithDefaultOptions allocator_;
|
Ort::AllocatorWithDefaultOptions allocator_;
|
||||||
|
|||||||
@@ -43,7 +43,20 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/in
|
|||||||
--decoding-method=greedy_search \
|
--decoding-method=greedy_search \
|
||||||
/path/to/foo.wav [bar.wav foobar.wav ...]
|
/path/to/foo.wav [bar.wav foobar.wav ...]
|
||||||
|
|
||||||
(3) Whisper models
|
(3) Moonshine models
|
||||||
|
|
||||||
|
See https://k2-fsa.github.io/sherpa/onnx/moonshine/index.html
|
||||||
|
|
||||||
|
./bin/sherpa-onnx-offline \
|
||||||
|
--moonshine-preprocessor=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/preprocess.onnx \
|
||||||
|
--moonshine-encoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/encode.int8.onnx \
|
||||||
|
--moonshine-uncached-decoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/uncached_decode.int8.onnx \
|
||||||
|
--moonshine-cached-decoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/cached_decode.int8.onnx \
|
||||||
|
--tokens=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/tokens.txt \
|
||||||
|
--num-threads=1 \
|
||||||
|
/path/to/foo.wav [bar.wav foobar.wav ...]
|
||||||
|
|
||||||
|
(4) Whisper models
|
||||||
|
|
||||||
See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html
|
See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html
|
||||||
|
|
||||||
@@ -54,7 +67,7 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html
|
|||||||
--num-threads=1 \
|
--num-threads=1 \
|
||||||
/path/to/foo.wav [bar.wav foobar.wav ...]
|
/path/to/foo.wav [bar.wav foobar.wav ...]
|
||||||
|
|
||||||
(4) NeMo CTC models
|
(5) NeMo CTC models
|
||||||
|
|
||||||
See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html
|
See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html
|
||||||
|
|
||||||
@@ -68,7 +81,7 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.htm
|
|||||||
./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \
|
./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \
|
||||||
./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav
|
./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav
|
||||||
|
|
||||||
(5) TDNN CTC model for the yesno recipe from icefall
|
(6) TDNN CTC model for the yesno recipe from icefall
|
||||||
|
|
||||||
See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html
|
See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html
|
||||||
//
|
//
|
||||||
|
|||||||
@@ -109,6 +109,8 @@ const std::string SymbolTable::operator[](int32_t id) const {
|
|||||||
|
|
||||||
// for byte-level BPE
|
// for byte-level BPE
|
||||||
// id 0 is blank, id 1 is sos/eos, id 2 is unk
|
// id 0 is blank, id 1 is sos/eos, id 2 is unk
|
||||||
|
//
|
||||||
|
// Note: For moonshine models, 0 is <unk>, 1, is <s>, 2 is</s>
|
||||||
if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' &&
|
if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' &&
|
||||||
sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') {
|
sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') {
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ set(srcs
|
|||||||
offline-ctc-fst-decoder-config.cc
|
offline-ctc-fst-decoder-config.cc
|
||||||
offline-lm-config.cc
|
offline-lm-config.cc
|
||||||
offline-model-config.cc
|
offline-model-config.cc
|
||||||
|
offline-moonshine-model-config.cc
|
||||||
offline-nemo-enc-dec-ctc-model-config.cc
|
offline-nemo-enc-dec-ctc-model-config.cc
|
||||||
offline-paraformer-model-config.cc
|
offline-paraformer-model-config.cc
|
||||||
offline-punctuation.cc
|
offline-punctuation.cc
|
||||||
|
|||||||
@@ -8,6 +8,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h"
|
||||||
@@ -28,6 +29,7 @@ void PybindOfflineModelConfig(py::module *m) {
|
|||||||
PybindOfflineZipformerCtcModelConfig(m);
|
PybindOfflineZipformerCtcModelConfig(m);
|
||||||
PybindOfflineWenetCtcModelConfig(m);
|
PybindOfflineWenetCtcModelConfig(m);
|
||||||
PybindOfflineSenseVoiceModelConfig(m);
|
PybindOfflineSenseVoiceModelConfig(m);
|
||||||
|
PybindOfflineMoonshineModelConfig(m);
|
||||||
|
|
||||||
using PyClass = OfflineModelConfig;
|
using PyClass = OfflineModelConfig;
|
||||||
py::class_<PyClass>(*m, "OfflineModelConfig")
|
py::class_<PyClass>(*m, "OfflineModelConfig")
|
||||||
@@ -39,7 +41,8 @@ void PybindOfflineModelConfig(py::module *m) {
|
|||||||
const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &,
|
const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &,
|
||||||
const OfflineZipformerCtcModelConfig &,
|
const OfflineZipformerCtcModelConfig &,
|
||||||
const OfflineWenetCtcModelConfig &,
|
const OfflineWenetCtcModelConfig &,
|
||||||
const OfflineSenseVoiceModelConfig &, const std::string &,
|
const OfflineSenseVoiceModelConfig &,
|
||||||
|
const OfflineMoonshineModelConfig &, const std::string &,
|
||||||
const std::string &, int32_t, bool, const std::string &,
|
const std::string &, int32_t, bool, const std::string &,
|
||||||
const std::string &, const std::string &, const std::string &>(),
|
const std::string &, const std::string &, const std::string &>(),
|
||||||
py::arg("transducer") = OfflineTransducerModelConfig(),
|
py::arg("transducer") = OfflineTransducerModelConfig(),
|
||||||
@@ -50,6 +53,7 @@ void PybindOfflineModelConfig(py::module *m) {
|
|||||||
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
|
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
|
||||||
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
|
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
|
||||||
py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
|
py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
|
||||||
|
py::arg("moonshine") = OfflineMoonshineModelConfig(),
|
||||||
py::arg("telespeech_ctc") = "", py::arg("tokens"),
|
py::arg("telespeech_ctc") = "", py::arg("tokens"),
|
||||||
py::arg("num_threads"), py::arg("debug") = false,
|
py::arg("num_threads"), py::arg("debug") = false,
|
||||||
py::arg("provider") = "cpu", py::arg("model_type") = "",
|
py::arg("provider") = "cpu", py::arg("model_type") = "",
|
||||||
@@ -62,6 +66,7 @@ void PybindOfflineModelConfig(py::module *m) {
|
|||||||
.def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
|
.def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
|
||||||
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
|
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
|
||||||
.def_readwrite("sense_voice", &PyClass::sense_voice)
|
.def_readwrite("sense_voice", &PyClass::sense_voice)
|
||||||
|
.def_readwrite("moonshine", &PyClass::moonshine)
|
||||||
.def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc)
|
.def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc)
|
||||||
.def_readwrite("tokens", &PyClass::tokens)
|
.def_readwrite("tokens", &PyClass::tokens)
|
||||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||||
|
|||||||
28
sherpa-onnx/python/csrc/offline-moonshine-model-config.cc
Normal file
28
sherpa-onnx/python/csrc/offline-moonshine-model-config.cc
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
// sherpa-onnx/python/csrc/offline-moonshine-model-config.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindOfflineMoonshineModelConfig(py::module *m) {
|
||||||
|
using PyClass = OfflineMoonshineModelConfig;
|
||||||
|
py::class_<PyClass>(*m, "OfflineMoonshineModelConfig")
|
||||||
|
.def(py::init<const std::string &, const std::string &,
|
||||||
|
const std::string &, const std::string &>(),
|
||||||
|
py::arg("preprocessor"), py::arg("encoder"),
|
||||||
|
py::arg("uncached_decoder"), py::arg("cached_decoder"))
|
||||||
|
.def_readwrite("preprocessor", &PyClass::preprocessor)
|
||||||
|
.def_readwrite("encoder", &PyClass::encoder)
|
||||||
|
.def_readwrite("uncached_decoder", &PyClass::uncached_decoder)
|
||||||
|
.def_readwrite("cached_decoder", &PyClass::cached_decoder)
|
||||||
|
.def("__str__", &PyClass::ToString);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
16
sherpa-onnx/python/csrc/offline-moonshine-model-config.h
Normal file
16
sherpa-onnx/python/csrc/offline-moonshine-model-config.h
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
// sherpa-onnx/python/csrc/offline-moonshine-model-config.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
|
||||||
|
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindOfflineMoonshineModelConfig(py::module *m);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
|
||||||
@@ -8,13 +8,14 @@ from _sherpa_onnx import (
|
|||||||
OfflineCtcFstDecoderConfig,
|
OfflineCtcFstDecoderConfig,
|
||||||
OfflineLMConfig,
|
OfflineLMConfig,
|
||||||
OfflineModelConfig,
|
OfflineModelConfig,
|
||||||
|
OfflineMoonshineModelConfig,
|
||||||
OfflineNemoEncDecCtcModelConfig,
|
OfflineNemoEncDecCtcModelConfig,
|
||||||
OfflineParaformerModelConfig,
|
OfflineParaformerModelConfig,
|
||||||
OfflineSenseVoiceModelConfig,
|
|
||||||
)
|
)
|
||||||
from _sherpa_onnx import OfflineRecognizer as _Recognizer
|
from _sherpa_onnx import OfflineRecognizer as _Recognizer
|
||||||
from _sherpa_onnx import (
|
from _sherpa_onnx import (
|
||||||
OfflineRecognizerConfig,
|
OfflineRecognizerConfig,
|
||||||
|
OfflineSenseVoiceModelConfig,
|
||||||
OfflineStream,
|
OfflineStream,
|
||||||
OfflineTdnnModelConfig,
|
OfflineTdnnModelConfig,
|
||||||
OfflineTransducerModelConfig,
|
OfflineTransducerModelConfig,
|
||||||
@@ -503,12 +504,12 @@ class OfflineRecognizer(object):
|
|||||||
e.g., tiny, tiny.en, base, base.en, etc.
|
e.g., tiny, tiny.en, base, base.en, etc.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
encoder_model:
|
encoder:
|
||||||
Path to the encoder model, e.g., tiny-encoder.onnx,
|
|
||||||
tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
|
|
||||||
decoder_model:
|
|
||||||
Path to the encoder model, e.g., tiny-encoder.onnx,
|
Path to the encoder model, e.g., tiny-encoder.onnx,
|
||||||
tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
|
tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
|
||||||
|
decoder:
|
||||||
|
Path to the decoder model, e.g., tiny-decoder.onnx,
|
||||||
|
tiny-decoder.int8.onnx, tiny-decoder.ort, etc.
|
||||||
tokens:
|
tokens:
|
||||||
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
|
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
|
||||||
columns::
|
columns::
|
||||||
@@ -570,6 +571,87 @@ class OfflineRecognizer(object):
|
|||||||
self.config = recognizer_config
|
self.config = recognizer_config
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_moonshine(
|
||||||
|
cls,
|
||||||
|
preprocessor: str,
|
||||||
|
encoder: str,
|
||||||
|
uncached_decoder: str,
|
||||||
|
cached_decoder: str,
|
||||||
|
tokens: str,
|
||||||
|
num_threads: int = 1,
|
||||||
|
decoding_method: str = "greedy_search",
|
||||||
|
debug: bool = False,
|
||||||
|
provider: str = "cpu",
|
||||||
|
rule_fsts: str = "",
|
||||||
|
rule_fars: str = "",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Please refer to
|
||||||
|
`<https://k2-fsa.github.io/sherpa/onnx/moonshine/index.html>`_
|
||||||
|
to download pre-trained models for different kinds of moonshine models,
|
||||||
|
e.g., tiny, base, etc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
preprocessor:
|
||||||
|
Path to the preprocessor model, e.g., preprocess.onnx
|
||||||
|
encoder:
|
||||||
|
Path to the encoder model, e.g., encode.int8.onnx
|
||||||
|
uncached_decoder:
|
||||||
|
Path to the uncached decoder model, e.g., uncached_decode.int8.onnx,
|
||||||
|
cached_decoder:
|
||||||
|
Path to the cached decoder model, e.g., cached_decode.int8.onnx,
|
||||||
|
tokens:
|
||||||
|
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
|
||||||
|
columns::
|
||||||
|
|
||||||
|
symbol integer_id
|
||||||
|
|
||||||
|
num_threads:
|
||||||
|
Number of threads for neural network computation.
|
||||||
|
decoding_method:
|
||||||
|
Valid values: greedy_search.
|
||||||
|
debug:
|
||||||
|
True to show debug messages.
|
||||||
|
provider:
|
||||||
|
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||||
|
rule_fsts:
|
||||||
|
If not empty, it specifies fsts for inverse text normalization.
|
||||||
|
If there are multiple fsts, they are separated by a comma.
|
||||||
|
rule_fars:
|
||||||
|
If not empty, it specifies fst archives for inverse text normalization.
|
||||||
|
If there are multiple archives, they are separated by a comma.
|
||||||
|
"""
|
||||||
|
self = cls.__new__(cls)
|
||||||
|
model_config = OfflineModelConfig(
|
||||||
|
moonshine=OfflineMoonshineModelConfig(
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
encoder=encoder,
|
||||||
|
uncached_decoder=uncached_decoder,
|
||||||
|
cached_decoder=cached_decoder,
|
||||||
|
),
|
||||||
|
tokens=tokens,
|
||||||
|
num_threads=num_threads,
|
||||||
|
debug=debug,
|
||||||
|
provider=provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
unused_feat_config = FeatureExtractorConfig(
|
||||||
|
sampling_rate=16000,
|
||||||
|
feature_dim=80,
|
||||||
|
)
|
||||||
|
|
||||||
|
recognizer_config = OfflineRecognizerConfig(
|
||||||
|
model_config=model_config,
|
||||||
|
feat_config=unused_feat_config,
|
||||||
|
decoding_method=decoding_method,
|
||||||
|
rule_fsts=rule_fsts,
|
||||||
|
rule_fars=rule_fars,
|
||||||
|
)
|
||||||
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
|
self.config = recognizer_config
|
||||||
|
return self
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_tdnn_ctc(
|
def from_tdnn_ctc(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
Reference in New Issue
Block a user