Support spoken language identification with whisper (#694)
98  .github/scripts/test-spoken-language-identification.sh  (vendored, executable file)
@@ -0,0 +1,98 @@
#!/usr/bin/env bash

set -e

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

echo "EXE is $EXE"
echo "PATH: $PATH"

which $EXE

names=(
tiny
base
small
medium
)

# all_language_codes=bo,ml,tt,fa,sl,bg,sn,sr,tl,km,ln,mr,hr,eu,ro,ba,bs,pl,as,nn,sk,ko,oc,ar,uz,pa,tg,mk,kk,hi,ha,uk,is,de,el,ja,yo,be,so,tk,id,sa,ru,yi,en,am,cs,ne,la,sv,su,pt,mi,ca,sd,hy,haw,fi,et,kn,da,lt,it,nl,he,mg,ur,tr,af,br,bn,ta,no,my,si,mt,th,gl,sw,mn,jw,ms,ps,fo,ka,hu,zh,ht,az,fr,lo,sq,gu,cy,lv,es,lb,te,vi

log "Download test waves"
waves=(
ar-arabic.wav
bg-bulgarian.wav
cs-czech.wav
da-danish.wav
de-german.wav
el-greek.wav
en-english.wav
es-spanish.wav
fa-persian.wav
fi-finnish.wav
fr-french.wav
hi-hindi.wav
hr-croatian.wav
id-indonesian.wav
it-italian.wav
ja-japanese.wav
ko-korean.wav
nl-dutch.wav
no-norwegian.wav
po-polish.wav
pt-portuguese.wav
ro-romanian.wav
ru-russian.wav
sk-slovak.wav
sv-swedish.wav
ta-tamil.wav
tl-tagalog.wav
tr-turkish.wav
uk-ukrainian.wav
zh-chinese.wav
)

for wav in ${waves[@]}; do
  echo "Downloading $wav"
  curl -SL -O https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/resolve/main/test_wavs/$wav
  ls -lh *.wav
done

for name in ${names[@]}; do
  log "------------------------------------------------------------"
  log "Run $name"
  log "------------------------------------------------------------"

  repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-whisper-$name
  log "Start testing ${repo_url}"
  repo=$(basename $repo_url)
  log "Download pretrained model and test-data from $repo_url"

  GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
  pushd $repo
  git lfs pull --include "*.onnx"
  # git lfs pull --include "*.ort"
  ls -lh *.onnx
  popd

  for wav in ${waves[@]}; do
    log "test fp32 onnx"

    time $EXE \
      --whisper-encoder=$repo/${name}-encoder.onnx \
      --whisper-decoder=$repo/${name}-decoder.onnx \
      $wav

    log "test int8 onnx"

    time $EXE \
      --whisper-encoder=$repo/${name}-encoder.int8.onnx \
      --whisper-decoder=$repo/${name}-decoder.int8.onnx \
      $wav
  done
  rm -rf $repo
done
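Note: to reproduce this smoke test locally without the CI harness, a minimal Python equivalent might look like the sketch below. It assumes sherpa-onnx-whisper-tiny has already been downloaded and extracted, and that the sherpa-onnx-offline-language-identification binary is on PATH; those two names come from this commit, everything else is illustrative.

#!/usr/bin/env python3
# Minimal local version of the CI smoke test above.
import subprocess
import urllib.request

BASE = "https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/resolve/main/test_wavs"
wav = "de-german.wav"
urllib.request.urlretrieve(f"{BASE}/{wav}", wav)  # download one test wave

subprocess.run(
    [
        "sherpa-onnx-offline-language-identification",
        "--whisper-encoder=sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx",
        "--whisper-decoder=sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx",
        wav,
    ],
    check=True,  # abort on failure, like `set -e` in the bash script
)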
1  .github/workflows/build-wheels-linux.yaml  (vendored)
@@ -82,7 +82,6 @@ jobs:
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        shell: bash
        with:
          max_attempts: 20
          timeout_seconds: 200
17  .github/workflows/build-wheels-macos-arm64.yaml  (vendored)
@@ -21,27 +21,12 @@ jobs:
      fail-fast: false
      matrix:
        os: [macos-latest]
-        python-version: ["cp37", "cp38", "cp39", "cp310", "cp311", "cp312"]
+        python-version: ["cp38", "cp39", "cp310", "cp311", "cp312"]

    steps:
      - uses: actions/checkout@v4

      # see https://cibuildwheel.readthedocs.io/en/stable/changelog/
      # for a list of versions
-      - name: Build wheels
-        if: matrix.python-version == 'cp37'
-        uses: pypa/cibuildwheel@v2.11.4
-        env:
-          CIBW_BUILD: "${{ matrix.python-version}}-* "
-          CIBW_ENVIRONMENT: SHERPA_ONNX_CMAKE_ARGS="-DCMAKE_OSX_ARCHITECTURES='arm64'"
-          CIBW_ARCHS: "arm64"
-          CIBW_BUILD_VERBOSITY: 3
-
-          # Don't repair macOS wheels
-          CIBW_REPAIR_WHEEL_COMMAND_MACOS: ""
-
      - name: Build wheels
-        if: matrix.python-version != 'cp37'
        uses: pypa/cibuildwheel@v2.15.0
        env:
          CIBW_BUILD: "${{ matrix.python-version}}-* "
9  .github/workflows/linux-gpu.yaml  (vendored)
@@ -92,6 +92,14 @@ jobs:
          file build/bin/sherpa-onnx
          readelf -d build/bin/sherpa-onnx

+      - name: Test spoken language identification
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx-offline-language-identification
+
+          .github/scripts/test-spoken-language-identification.sh
+
      - name: Test online CTC
        shell: bash
        run: |
@@ -116,6 +124,7 @@ jobs:

          .github/scripts/test-online-paraformer.sh

+
      - name: Test offline Whisper
        shell: bash
        run: |
10  .github/workflows/linux.yaml  (vendored)
@@ -123,6 +123,15 @@ jobs:
          name: release-${{ matrix.build_type }}-${{ matrix.shared_lib }}
          path: build/bin/*

+      - name: Test spoken language identification
+        if: matrix.build_type != 'Debug'
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx-offline-language-identification
+
+          .github/scripts/test-spoken-language-identification.sh
+
      - name: Test transducer kws
        shell: bash
        run: |
@@ -140,6 +149,7 @@ jobs:
          .github/scripts/test-online-ctc.sh

      - name: Test offline Whisper
+        if: matrix.build_type != 'Debug'
        shell: bash
        run: |
          export PATH=$PWD/build/bin:$PATH
10  .github/workflows/macos.yaml  (vendored)
@@ -102,6 +102,15 @@ jobs:
          otool -L build/bin/sherpa-onnx
          otool -l build/bin/sherpa-onnx

+      - name: Test spoken language identification
+        if: matrix.build_type != 'Debug'
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx-offline-language-identification
+
+          .github/scripts/test-spoken-language-identification.sh
+
      - name: Test transducer kws
        shell: bash
        run: |
@@ -135,6 +144,7 @@ jobs:
          .github/scripts/test-online-paraformer.sh

      - name: Test offline Whisper
+        if: matrix.build_type != 'Debug'
        shell: bash
        run: |
          export PATH=$PWD/build/bin:$PATH
8  .github/workflows/windows-x64-cuda.yaml  (vendored)
@@ -68,6 +68,14 @@ jobs:

          ls -lh ./bin/Release/sherpa-onnx.exe

+      - name: Test spoken language identification
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin/Release:$PATH
+          export EXE=sherpa-onnx-offline-language-identification.exe
+
+          .github/scripts/test-spoken-language-identification.sh
+
      - name: Test online CTC
        shell: bash
        run: |
8  .github/workflows/windows-x64.yaml  (vendored)
@@ -68,6 +68,14 @@ jobs:

          ls -lh ./bin/Release/sherpa-onnx.exe

+      - name: Test spoken language identification
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin/Release:$PATH
+          export EXE=sherpa-onnx-offline-language-identification.exe
+
+          .github/scripts/test-spoken-language-identification.sh
+
      - name: Test online CTC
        shell: bash
        run: |
8  .github/workflows/windows-x86.yaml  (vendored)
@@ -69,6 +69,14 @@ jobs:

          ls -lh ./bin/Release/sherpa-onnx.exe

+      # - name: Test spoken language identification
+      #   shell: bash
+      #   run: |
+      #     export PATH=$PWD/build/bin/Release:$PATH
+      #     export EXE=sherpa-onnx-offline-language-identification.exe
+      #
+      #     .github/scripts/test-spoken-language-identification.sh
+
      - name: Test online CTC
        shell: bash
        run: |
CMakeLists.txt
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)

-set(SHERPA_ONNX_VERSION "1.9.13")
+set(SHERPA_ONNX_VERSION "1.9.14")

# Disable warning about
#
cmake/cmake_extension.py
@@ -43,6 +43,50 @@ def enable_alsa():
    return build_alsa and is_linux() and (is_arm64() or is_x86())


+def get_binaries():
+    binaries = [
+        "sherpa-onnx",
+        "sherpa-onnx-keyword-spotter",
+        "sherpa-onnx-microphone",
+        "sherpa-onnx-microphone-offline",
+        "sherpa-onnx-microphone-offline-speaker-identification",
+        "sherpa-onnx-offline",
+        "sherpa-onnx-offline-language-identification",
+        "sherpa-onnx-offline-tts",
+        "sherpa-onnx-offline-tts-play",
+        "sherpa-onnx-offline-websocket-server",
+        "sherpa-onnx-online-websocket-client",
+        "sherpa-onnx-online-websocket-server",
+        "sherpa-onnx-vad-microphone",
+        "sherpa-onnx-vad-microphone-offline-asr",
+    ]
+
+    if enable_alsa():
+        binaries += [
+            "sherpa-onnx-alsa",
+            "sherpa-onnx-alsa-offline",
+            "sherpa-onnx-alsa-offline-speaker-identification",
+            "sherpa-onnx-offline-tts-play-alsa",
+        ]
+
+    if is_windows():
+        binaries += [
+            "espeak-ng.dll",
+            "kaldi-decoder-core.dll",
+            "kaldi-native-fbank-core.dll",
+            "onnxruntime.dll",
+            "piper_phonemize.dll",
+            "sherpa-onnx-c-api.dll",
+            "sherpa-onnx-core.dll",
+            "sherpa-onnx-fst.lib",
+            "sherpa-onnx-kaldifst-core.lib",
+            "sherpa-onnx-portaudio.dll",
+            "ucd.dll",
+        ]
+
+    return binaries
+
+
try:
    from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
@@ -150,38 +194,7 @@ class BuildExtension(build_ext):
        suffix = ".exe" if is_windows() else ""
        # Remember to also change setup.py

-        binaries = ["sherpa-onnx"]
-        binaries += ["sherpa-onnx-keyword-spotter"]
-        binaries += ["sherpa-onnx-offline"]
-        binaries += ["sherpa-onnx-microphone"]
-        binaries += ["sherpa-onnx-microphone-offline"]
-        binaries += ["sherpa-onnx-microphone-offline-speaker-identification"]
-        binaries += ["sherpa-onnx-online-websocket-server"]
-        binaries += ["sherpa-onnx-offline-websocket-server"]
-        binaries += ["sherpa-onnx-online-websocket-client"]
-        binaries += ["sherpa-onnx-vad-microphone"]
-        binaries += ["sherpa-onnx-vad-microphone-offline-asr"]
-        binaries += ["sherpa-onnx-offline-tts"]
-        binaries += ["sherpa-onnx-offline-tts-play"]
-
-        if enable_alsa():
-            binaries += ["sherpa-onnx-alsa"]
-            binaries += ["sherpa-onnx-alsa-offline"]
-            binaries += ["sherpa-onnx-offline-tts-play-alsa"]
-            binaries += ["sherpa-onnx-alsa-offline-speaker-identification"]
-
-        if is_windows():
-            binaries += ["kaldi-native-fbank-core.dll"]
-            binaries += ["sherpa-onnx-c-api.dll"]
-            binaries += ["sherpa-onnx-core.dll"]
-            binaries += ["sherpa-onnx-portaudio.dll"]
-            binaries += ["onnxruntime.dll"]
-            binaries += ["piper_phonemize.dll"]
-            binaries += ["espeak-ng.dll"]
-            binaries += ["ucd.dll"]
-            binaries += ["kaldi-decoder-core.dll"]
-            binaries += ["sherpa-onnx-fst.lib"]
-            binaries += ["sherpa-onnx-kaldifst-core.lib"]
+        binaries = get_binaries()

        for f in binaries:
            suffix = "" if (".dll" in f or ".lib" in f) else suffix
172  python-api-examples/spoken-language-identification.py  (executable file)
@@ -0,0 +1,172 @@
#!/usr/bin/env python3

"""
This script shows how to use Python APIs for spoken language identification.
It detects the language spoken in the given wave file.

Usage:

1. Download a whisper multilingual model. We use a tiny model below.
   Please refer to https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
   to download more models.

   wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
   tar xvf sherpa-onnx-whisper-tiny.tar.bz2
   rm sherpa-onnx-whisper-tiny.tar.bz2

   We only use the int8.onnx models below.

2. Download a test wave.

   You can find many wave files for different languages at
   https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/tree/main/test_wavs

   wget https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/resolve/main/test_wavs/de-german.wav

python3 ./python-api-examples/spoken-language-identification.py \
  --whisper-encoder=sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx \
  --whisper-decoder=sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx \
  --num-threads=1 \
  ./de-german.wav
"""

import argparse
import logging
import time
import wave
from pathlib import Path
from typing import Tuple

import numpy as np
import sherpa_onnx


def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--whisper-encoder",
        required=True,
        type=str,
        help="Path to a multilingual whisper encoder model",
    )

    parser.add_argument(
        "--whisper-decoder",
        required=True,
        type=str,
        help="Path to a multilingual whisper decoder model",
    )

    parser.add_argument(
        "--num-threads",
        type=int,
        default=1,
        help="Number of threads for neural network computation",
    )

    parser.add_argument(
        "--debug",
        type=bool,
        default=False,
        help="True to show debug messages",
    )

    parser.add_argument(
        "--provider",
        type=str,
        default="cpu",
        help="Valid values: cpu, cuda, coreml",
    )

    parser.add_argument(
        "sound_file",
        type=str,
        help="The input sound file to identify. It must be of WAVE "
        "format with a single channel, and each sample should be 16-bit, "
        "i.e., int16_t. "
        "The sample rate of the file can be arbitrary and does not need to "
        "be 16 kHz",
    )

    return parser.parse_args()


def assert_file_exists(filename: str):
    assert Path(filename).is_file(), (
        f"{filename} does not exist!\n"
        "Please refer to "
        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html to download it"
    )


def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """
    Args:
      wave_filename:
        Path to a wave file. It should be single channel and each sample should
        be 16-bit. Its sample rate does not need to be 16kHz.
    Returns:
      Return a tuple containing:
       - A 1-D array of dtype np.float32 containing the samples, which are
         normalized to the range [-1, 1].
       - sample rate of the wave file
    """

    with wave.open(wave_filename) as f:
        assert f.getnchannels() == 1, f.getnchannels()
        assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
        num_samples = f.getnframes()
        samples = f.readframes(num_samples)
        samples_int16 = np.frombuffer(samples, dtype=np.int16)
        samples_float32 = samples_int16.astype(np.float32)

        samples_float32 = samples_float32 / 32768
        return samples_float32, f.getframerate()


def main():
    args = get_args()
    assert_file_exists(args.whisper_encoder)
    assert_file_exists(args.whisper_decoder)
    assert args.num_threads > 0, args.num_threads
    config = sherpa_onnx.SpokenLanguageIdentificationConfig(
        whisper=sherpa_onnx.SpokenLanguageIdentificationWhisperConfig(
            encoder=args.whisper_encoder,
            decoder=args.whisper_decoder,
        ),
        num_threads=args.num_threads,
        debug=args.debug,
        provider=args.provider,
    )
    slid = sherpa_onnx.SpokenLanguageIdentification(config)

    samples, sample_rate = read_wave(args.sound_file)

    start_time = time.time()
    stream = slid.create_stream()
    stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
    lang = slid.compute(stream)
    end_time = time.time()

    elapsed_seconds = end_time - start_time
    audio_duration = len(samples) / sample_rate
    real_time_factor = elapsed_seconds / audio_duration

    logging.info(f"File: {args.sound_file}")
    logging.info(f"Detected language: {lang}")
    logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}")
    logging.info(f"Audio duration in seconds: {audio_duration:.3f}")
    logging.info(
        f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
    )


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
38  setup.py
@@ -1,8 +1,6 @@
#!/usr/bin/env python3

import os
import re
import sys
from pathlib import Path

import setuptools
@@ -11,7 +9,7 @@ from cmake.cmake_extension import (
    BuildExtension,
    bdist_wheel,
    cmake_extension,
    enable_alsa,
+    get_binaries,
    is_windows,
)
@@ -42,39 +40,7 @@ def get_binaries_to_install():
    bin_dir.mkdir(parents=True, exist_ok=True)
    suffix = ".exe" if is_windows() else ""

-    # Remember to also change cmake/cmake_extension.py
-    binaries = ["sherpa-onnx"]
-    binaries += ["sherpa-onnx-keyword-spotter"]
-    binaries += ["sherpa-onnx-offline"]
-    binaries += ["sherpa-onnx-microphone"]
-    binaries += ["sherpa-onnx-microphone-offline"]
-    binaries += ["sherpa-onnx-microphone-offline-speaker-identification"]
-    binaries += ["sherpa-onnx-online-websocket-server"]
-    binaries += ["sherpa-onnx-offline-websocket-server"]
-    binaries += ["sherpa-onnx-online-websocket-client"]
-    binaries += ["sherpa-onnx-vad-microphone"]
-    binaries += ["sherpa-onnx-vad-microphone-offline-asr"]
-    binaries += ["sherpa-onnx-offline-tts"]
-    binaries += ["sherpa-onnx-offline-tts-play"]
-
-    if enable_alsa():
-        binaries += ["sherpa-onnx-alsa"]
-        binaries += ["sherpa-onnx-alsa-offline"]
-        binaries += ["sherpa-onnx-offline-tts-play-alsa"]
-        binaries += ["sherpa-onnx-alsa-offline-speaker-identification"]
-
-    if is_windows():
-        binaries += ["kaldi-native-fbank-core.dll"]
-        binaries += ["sherpa-onnx-c-api.dll"]
-        binaries += ["sherpa-onnx-core.dll"]
-        binaries += ["sherpa-onnx-portaudio.dll"]
-        binaries += ["onnxruntime.dll"]
-        binaries += ["piper_phonemize.dll"]
-        binaries += ["espeak-ng.dll"]
-        binaries += ["ucd.dll"]
-        binaries += ["kaldi-decoder-core.dll"]
-        binaries += ["sherpa-onnx-fst.lib"]
-        binaries += ["sherpa-onnx-kaldifst-core.lib"]
+    binaries = get_binaries()

    exe = []
    for f in binaries:
sherpa-onnx/csrc/CMakeLists.txt
@@ -86,6 +86,8 @@ set(sources
  silero-vad-model-config.cc
  silero-vad-model.cc
  slice.cc
+  spoken-language-identification-impl.cc
+  spoken-language-identification.cc
  stack.cc
  symbol-table.cc
  text-utils.cc
@@ -184,6 +186,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
  add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
  add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
  add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
+  add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)

  set(main_exes
    sherpa-onnx
@@ -191,6 +194,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
    sherpa-onnx-offline
    sherpa-onnx-offline-parallel
    sherpa-onnx-offline-tts
+    sherpa-onnx-offline-language-identification
  )

  foreach(exe IN LISTS main_exes)
sherpa-onnx/csrc/offline-ctc-model.cc
@@ -23,7 +23,7 @@ enum class ModelType {
  kTdnn,
  kZipformerCtc,
  kWenetCtc,
-  kUnkown,
+  kUnknown,
};

}  // namespace
@@ -59,7 +59,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
        "run.sh\n"
        "\n"
        "for how to add metadata to model.onnx\n");
-    return ModelType::kUnkown;
+    return ModelType::kUnknown;
  }

  if (model_type.get() == std::string("EncDecCTCModelBPE")) {
@@ -72,13 +72,13 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
    return ModelType::kWenetCtc;
  } else {
    SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
-    return ModelType::kUnkown;
+    return ModelType::kUnknown;
  }
}

std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
    const OfflineModelConfig &config) {
-  ModelType model_type = ModelType::kUnkown;
+  ModelType model_type = ModelType::kUnknown;

  std::string filename;
  if (!config.nemo_ctc.model.empty()) {
@@ -113,7 +113,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
    case ModelType::kWenetCtc:
      return std::make_unique<OfflineWenetCtcModel>(config);
      break;
-    case ModelType::kUnkown:
+    case ModelType::kUnknown:
      SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
      return nullptr;
  }
@@ -125,7 +125,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(

std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
    AAssetManager *mgr, const OfflineModelConfig &config) {
-  ModelType model_type = ModelType::kUnkown;
+  ModelType model_type = ModelType::kUnknown;

  std::string filename;
  if (!config.nemo_ctc.model.empty()) {
@@ -160,7 +160,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
    case ModelType::kWenetCtc:
      return std::make_unique<OfflineWenetCtcModel>(mgr, config);
      break;
-    case ModelType::kUnkown:
+    case ModelType::kUnknown:
      SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
      return nullptr;
  }
sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
@@ -114,7 +114,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
      num_frames = max_num_frames - 50;
    }

-    NormalizeFeatures(f.data(), num_frames, feat_dim);
+    model_->NormalizeFeatures(f.data(), num_frames, feat_dim);

    // note that 1000 is an experience-value.
    // You can replace 1000 by other values, say, 100.
@@ -162,38 +162,6 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
    }
  }

 private:
-  static void NormalizeFeatures(float *features, int32_t num_frames,
-                                int32_t feat_dim) {
-    // log_spec = torch.clamp(features, min=1e-10).log10()
-    // log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
-    // mel = (log_spec + 4.0) / 4.0
-
-    int32_t n = num_frames * feat_dim;
-    float max_v = -1e20;
-    for (int32_t i = 0; i != n; ++i) {
-      float f = features[i];
-
-      f = std::max<float>(f, 1e-10);
-      f = std::log10(f);
-
-      max_v = std::max(f, max_v);
-
-      features[i] = f;
-    }
-
-    max_v -= 8;
-
-    for (int32_t i = 0; i != n; ++i) {
-      float f = features[i];
-      f = std::max(f, max_v);
-
-      f = (f + 4) / 4;
-
-      features[i] = f;
-    }
-  }
-
- private:
  OfflineRecognizerConfig config_;
  SymbolTable symbol_table_;
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
@@ -12,56 +12,6 @@

namespace sherpa_onnx {

-int32_t OfflineWhisperGreedySearchDecoder::DetectLanguage(
-    Ort::Value &cross_k, Ort::Value &cross_v) const {  // NOLINT
-  int64_t token_val = model_->SOT();
-  std::array<int64_t, 2> token_shape{1, 1};
-
-  auto memory_info =
-      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
-
-  Ort::Value tokens = Ort::Value::CreateTensor(
-      memory_info, &token_val, 1, token_shape.data(), token_shape.size());
-
-  auto self_kv_cache = model_->GetInitialSelfKVCache();
-
-  std::array<int64_t, 1> offset_shape{1};
-  Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
-      model_->Allocator(), offset_shape.data(), offset_shape.size());
-  *(offset.GetTensorMutableData<int64_t>()) = 0;
-
-  auto decoder_out = model_->ForwardDecoder(
-      std::move(tokens), std::move(self_kv_cache.first),
-      std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
-      std::move(offset));
-
-  cross_k = std::move(std::get<3>(decoder_out));
-  cross_v = std::move(std::get<4>(decoder_out));
-
-  const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
-  int32_t vocab_size = model_->VocabSize();
-  const auto &all_language_ids = model_->GetAllLanguageIDs();
-
-  int32_t lang_id = all_language_ids[0];
-  float this_logit = p_logits[lang_id];
-
-  for (int32_t i = 1; i != all_language_ids.size(); ++i) {
-    int32_t id = all_language_ids[i];
-    float p = p_logits[id];
-
-    if (p > this_logit) {
-      this_logit = p;
-      lang_id = id;
-    }
-  }
-#if 1
-  SHERPA_ONNX_LOGE("Detected language: %s",
-                   model_->GetID2Lang().at(lang_id).c_str());
-#endif
-
-  return lang_id;
-}
-
std::vector<OfflineWhisperDecoderResult>
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
                                          Ort::Value cross_v) {
@@ -89,7 +39,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
    // 0: sot, 1: lang_id, 2: task, 3: no_timestamps
    initial_tokens[1] = lang_id;
  } else {
-    int32_t lang_id = DetectLanguage(cross_k, cross_v);
+    int32_t lang_id = model_->DetectLanguage(cross_k, cross_v);

    // 0: sot, 1: lang_id, 2: task, 3: no_timestamps
    initial_tokens[1] = lang_id;
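Note: the DetectLanguage body removed above (and moved into OfflineWhisperModel, see below) is the heart of language identification: run one decoder step from the start-of-transcript (SOT) token and take the argmax of the logits restricted to Whisper's language tokens. A rough NumPy sketch of that idea, where decoder_step, sot_id, and language_token_ids are hypothetical stand-ins:

import numpy as np

def detect_language(decoder_step, sot_id, language_token_ids):
    # decoder_step: hypothetical callable mapping a (1, 1) token tensor to
    # logits of shape (vocab_size,) -- one forward pass of the decoder.
    logits = decoder_step(np.array([[sot_id]], dtype=np.int64))

    # Only the language tokens (<|en|>, <|de|>, ...) compete; every other
    # vocabulary entry is ignored, exactly like the C++ loop above.
    ids = np.asarray(language_token_ids)
    return int(ids[np.argmax(logits[ids])])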
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
@@ -22,9 +22,6 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
  std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
                                                  Ort::Value cross_v) override;

-  int32_t DetectLanguage(Ort::Value &cross_k,        // NOLINT
-                         Ort::Value &cross_v) const;  // NOLINT
-
 private:
  OfflineWhisperModelConfig config_;
  OfflineWhisperModel *model_;  // not owned
sherpa-onnx/csrc/offline-whisper-model-config.cc
@@ -35,19 +35,28 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) {

  po->Register(
      "whisper-tail-paddings", &tail_paddings,
-      "Suggest value: 50 for English models. 300 for multilingual models. "
+      "Suggested value: 50 for English models. 300 for multilingual models. "
      "Since we have removed the 30-second constraint, we need to add some "
      "tail padding frames "
-      "so that whisper can detect the eot token. Leave it to -1 to use 50 for "
-      "English models and 300 for multilingual models.");
+      "so that whisper can detect the eot token. Leave it to -1 to use 1000.");
}

bool OfflineWhisperModelConfig::Validate() const {
  if (encoder.empty()) {
    SHERPA_ONNX_LOGE("Please provide --whisper-encoder");
    return false;
  }

  if (!FileExists(encoder)) {
    SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
    return false;
  }

  if (decoder.empty()) {
    SHERPA_ONNX_LOGE("Please provide --whisper-decoder");
    return false;
  }

  if (!FileExists(decoder)) {
    SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
    return false;
sherpa-onnx/csrc/offline-whisper-model.cc
@@ -24,6 +24,24 @@ class OfflineWhisperModel::Impl {
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
+    debug_ = config_.debug;
    {
      auto buf = ReadFile(config.whisper.encoder);
      InitEncoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(config.whisper.decoder);
      InitDecoder(buf.data(), buf.size());
    }
  }

+  explicit Impl(const SpokenLanguageIdentificationConfig &config)
+      : lid_config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    debug_ = config_.debug;
+    {
+      auto buf = ReadFile(config.whisper.encoder);
+      InitEncoder(buf.data(), buf.size());
@@ -41,6 +59,7 @@ class OfflineWhisperModel::Impl {
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
+    debug_ = config_.debug;
    {
      auto buf = ReadFile(mgr, config.whisper.encoder);
      InitEncoder(buf.data(), buf.size());
@@ -85,6 +104,57 @@ class OfflineWhisperModel::Impl {
            std::move(decoder_input[4]), std::move(decoder_input[5])};
  }

+  int32_t DetectLanguage(Ort::Value &cross_k,    // NOLINT
+                         Ort::Value &cross_v) {  // NOLINT
+    int64_t token_val = SOT();
+    std::array<int64_t, 2> token_shape{1, 1};
+
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    Ort::Value tokens = Ort::Value::CreateTensor(
+        memory_info, &token_val, 1, token_shape.data(), token_shape.size());
+
+    auto self_kv_cache = GetInitialSelfKVCache();
+
+    std::array<int64_t, 1> offset_shape{1};
+    Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
+        Allocator(), offset_shape.data(), offset_shape.size());
+    *(offset.GetTensorMutableData<int64_t>()) = 0;
+
+    auto decoder_out =
+        ForwardDecoder(std::move(tokens), std::move(self_kv_cache.first),
+                       std::move(self_kv_cache.second), std::move(cross_k),
+                       std::move(cross_v), std::move(offset));
+
+    cross_k = std::move(std::get<3>(decoder_out));
+    cross_v = std::move(std::get<4>(decoder_out));
+
+    const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
+    int32_t vocab_size = VocabSize();
+    const auto &all_language_ids = GetAllLanguageIDs();
+
+    int32_t lang_id = all_language_ids[0];
+    float this_logit = p_logits[lang_id];
+
+    for (int32_t i = 1; i != all_language_ids.size(); ++i) {
+      int32_t id = all_language_ids[i];
+      float p = p_logits[id];
+
+      if (p > this_logit) {
+        this_logit = p;
+        lang_id = id;
+      }
+    }
+
+    if (debug_) {
+      SHERPA_ONNX_LOGE("Detected language: %s",
+                       GetID2Lang().at(lang_id).c_str());
+    }
+
+    return lang_id;
+  }
+
  std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() {
    std::array<int64_t, 4> shape{n_text_layer_, 1, n_text_ctx_, n_text_state_};

@@ -148,7 +218,7 @@ class OfflineWhisperModel::Impl {

    // get meta data
    Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
-    if (config_.debug) {
+    if (debug_) {
      std::ostringstream os;
      os << "---encoder---\n";
      PrintModelMetadata(os, meta_data);
@@ -203,6 +273,8 @@ class OfflineWhisperModel::Impl {

 private:
  OfflineModelConfig config_;
+  SpokenLanguageIdentificationConfig lid_config_;
+  bool debug_ = false;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;
@@ -246,6 +318,10 @@ class OfflineWhisperModel::Impl {
OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

+OfflineWhisperModel::OfflineWhisperModel(
+    const SpokenLanguageIdentificationConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
#if __ANDROID_API__ >= 9
OfflineWhisperModel::OfflineWhisperModel(AAssetManager *mgr,
                                         const OfflineModelConfig &config)
@@ -273,6 +349,11 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
                                std::move(n_layer_cross_v), std::move(offset));
}

+int32_t OfflineWhisperModel::DetectLanguage(Ort::Value &cross_k,    // NOLINT
+                                            Ort::Value &cross_v) {  // NOLINT
+  return impl_->DetectLanguage(cross_k, cross_v);
+}
+
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache()
    const {
  return impl_->GetInitialSelfKVCache();
@@ -318,4 +399,35 @@ bool OfflineWhisperModel::IsMultiLingual() const {
  return impl_->IsMultiLingual();
}

+void OfflineWhisperModel::NormalizeFeatures(float *features, int32_t num_frames,
+                                            int32_t feat_dim) {
+  // log_spec = torch.clamp(features, min=1e-10).log10()
+  // log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+  // mel = (log_spec + 4.0) / 4.0
+
+  int32_t n = num_frames * feat_dim;
+  float max_v = -1e20;
+  for (int32_t i = 0; i != n; ++i) {
+    float f = features[i];
+
+    f = std::max<float>(f, 1e-10);
+    f = std::log10(f);
+
+    max_v = std::max(f, max_v);
+
+    features[i] = f;
+  }
+
+  max_v -= 8;
+
+  for (int32_t i = 0; i != n; ++i) {
+    float f = features[i];
+    f = std::max(f, max_v);
+
+    f = (f + 4) / 4;
+
+    features[i] = f;
+  }
+}
+
}  // namespace sherpa_onnx
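Note: NormalizeFeatures mirrors Whisper's reference log-mel post-processing, quoted in the comments above. The same three lines in NumPy, as a rough equivalent (not part of the sherpa-onnx API):

import numpy as np

def normalize_features(features: np.ndarray) -> np.ndarray:
    # features: (num_frames, feat_dim) power mel spectrogram
    log_spec = np.log10(np.clip(features, 1e-10, None))    # clamp, then log10
    log_spec = np.maximum(log_spec, log_spec.max() - 8.0)  # 80 dB dynamic-range floor
    return (log_spec + 4.0) / 4.0                          # shift/scale roughly into [-1, 1]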
sherpa-onnx/csrc/offline-whisper-model.h
@@ -18,6 +18,7 @@

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"
+#include "sherpa-onnx/csrc/spoken-language-identification.h"

namespace sherpa_onnx {

@@ -25,6 +26,9 @@ class OfflineWhisperModel {
 public:
  explicit OfflineWhisperModel(const OfflineModelConfig &config);

+  explicit OfflineWhisperModel(
+      const SpokenLanguageIdentificationConfig &config);
+
#if __ANDROID_API__ >= 9
  OfflineWhisperModel(AAssetManager *mgr, const OfflineModelConfig &config);
#endif
@@ -72,7 +76,8 @@ class OfflineWhisperModel {
      Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
      Ort::Value n_layer_cross_v, Ort::Value offset) const;

-  int32_t DetectLanguage() const;
+  int32_t DetectLanguage(Ort::Value &cross_k,    // NOLINT
+                         Ort::Value &cross_v);   // NOLINT

  /** Return the initial self kv cache in a pair
   *  - n_layer_self_k_cache A 4-D tensor of shape
@@ -98,6 +103,9 @@ class OfflineWhisperModel {
  int32_t Translate() const;
  bool IsMultiLingual() const;

+  static void NormalizeFeatures(float *features, int32_t num_frames,
+                                int32_t feat_dim);
+
 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
sherpa-onnx/csrc/online-transducer-model.cc
@@ -28,7 +28,7 @@ enum class ModelType {
  kLstm,
  kZipformer,
  kZipformer2,
-  kUnkown,
+  kUnknown,
};

}  // namespace
@@ -58,7 +58,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
        "No model_type in the metadata!\n"
        "Please make sure you are using the latest export-onnx.py from icefall "
        "to export your transducer models");
-    return ModelType::kUnkown;
+    return ModelType::kUnknown;
  }

  if (model_type.get() == std::string("conformer")) {
@@ -71,7 +71,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
    return ModelType::kZipformer2;
  } else {
    SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
-    return ModelType::kUnkown;
+    return ModelType::kUnknown;
  }
}
@@ -93,7 +93,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
                       model_type.c_str());
    }
  }
-  ModelType model_type = ModelType::kUnkown;
+  ModelType model_type = ModelType::kUnknown;

  {
    auto buffer = ReadFile(config.transducer.encoder);
@@ -110,7 +110,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
      return std::make_unique<OnlineZipformerTransducerModel>(config);
    case ModelType::kZipformer2:
      return std::make_unique<OnlineZipformer2TransducerModel>(config);
-    case ModelType::kUnkown:
+    case ModelType::kUnknown:
      SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
      return nullptr;
  }
@@ -185,7 +185,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
      return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
    case ModelType::kZipformer2:
      return std::make_unique<OnlineZipformer2TransducerModel>(mgr, config);
-    case ModelType::kUnkown:
+    case ModelType::kUnknown:
      SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
      return nullptr;
  }
sherpa-onnx/csrc/session.cc
@@ -149,4 +149,9 @@ Ort::SessionOptions GetSessionOptions(
  return GetSessionOptionsImpl(config.num_threads, config.provider);
}

+Ort::SessionOptions GetSessionOptions(
+    const SpokenLanguageIdentificationConfig &config) {
+  return GetSessionOptionsImpl(config.num_threads, config.provider);
+}
+
}  // namespace sherpa_onnx
sherpa-onnx/csrc/session.h
@@ -12,6 +12,7 @@
#include "sherpa-onnx/csrc/online-lm-config.h"
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
+#include "sherpa-onnx/csrc/spoken-language-identification.h"
#include "sherpa-onnx/csrc/vad-model-config.h"

namespace sherpa_onnx {

@@ -30,6 +31,10 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);

Ort::SessionOptions GetSessionOptions(
    const SpeakerEmbeddingExtractorConfig &config);

+Ort::SessionOptions GetSessionOptions(
+    const SpokenLanguageIdentificationConfig &config);
+
}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_SESSION_H_
107  sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc  (new file)
@@ -0,0 +1,107 @@
// sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc
//
// Copyright (c)  2022-2024  Xiaomi Corporation

#include <stdio.h>

#include <chrono>  // NOLINT
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/spoken-language-identification.h"
#include "sherpa-onnx/csrc/wave-reader.h"

int main(int32_t argc, char *argv[]) {
  const char *kUsageMessage = R"usage(
Spoken language identification with sherpa-onnx.

Usage:

(1) Use a whisper multilingual model

  wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
  tar xvf sherpa-onnx-whisper-tiny.tar.bz2
  rm sherpa-onnx-whisper-tiny.tar.bz2

  We only use the int8.onnx models below.

  ./bin/sherpa-onnx-offline-language-identification \
    --whisper-encoder=sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx \
    --whisper-decoder=sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx \
    --num-threads=1 \
    /path/to/foo.wav

foo.wav should be a single-channel, 16-bit PCM encoded wave file; its
sampling rate can be arbitrary and does not need to be 16kHz.
You can find test waves for different languages at
https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/tree/main/test_wavs

Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html
for a list of pre-trained models to download.
Note that only whisper multilingual models are supported. For instance,
"tiny" is supported but "tiny.en" is not.
)usage";

  sherpa_onnx::ParseOptions po(kUsageMessage);
  sherpa_onnx::SpokenLanguageIdentificationConfig config;
  config.Register(&po);

  po.Read(argc, argv);
  if (po.NumArgs() != 1) {
    fprintf(stderr, "Error: Please provide 1 wave file.\n\n");
    po.PrintUsage();
    exit(EXIT_FAILURE);
  }

  fprintf(stderr, "%s\n", config.ToString().c_str());

  if (!config.Validate()) {
    fprintf(stderr, "Errors in config!\n");
    return -1;
  }

  fprintf(stderr, "Creating spoken language identifier ...\n");
  sherpa_onnx::SpokenLanguageIdentification slid(config);

  fprintf(stderr, "Started\n");
  const std::string wav_filename = po.GetArg(1);

  int32_t sampling_rate = -1;
  bool is_ok = false;
  const std::vector<float> samples =
      sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
  if (!is_ok) {
    fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
    return -1;
  }
  float duration = samples.size() / static_cast<float>(sampling_rate);

  const auto begin = std::chrono::steady_clock::now();

  auto s = slid.CreateStream();
  s->AcceptWaveform(sampling_rate, samples.data(), samples.size());

  auto language = slid.Compute(s.get());

  const auto end = std::chrono::steady_clock::now();

  fprintf(stderr, "Done!\n\n");
  fprintf(stderr, "%s\nDetected language: %s\n", wav_filename.c_str(),
          language.c_str());

  float elapsed_seconds =
      std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
          .count() /
      1000.;

  fprintf(stderr, "num threads: %d\n", config.num_threads);

  fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
  float rtf = elapsed_seconds / duration;
  fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
          elapsed_seconds, duration, rtf);

  return 0;
}
sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
@@ -16,7 +16,7 @@ enum class ModelType {
  kWeSpeaker,
  k3dSpeaker,
  kNeMo,
-  kUnkown,
+  kUnknown,
};

}  // namespace
@@ -47,7 +47,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
      "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wespeaker/"
      "add_meta_data.py"
      "to add metadata to models from WeSpeaker\n");
-    return ModelType::kUnkown;
+    return ModelType::kUnknown;
  }

  if (model_type.get() == std::string("wespeaker")) {
@@ -58,14 +58,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
    return ModelType::kNeMo;
  } else {
    SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
-    return ModelType::kUnkown;
+    return ModelType::kUnknown;
  }
}

std::unique_ptr<SpeakerEmbeddingExtractorImpl>
SpeakerEmbeddingExtractorImpl::Create(
    const SpeakerEmbeddingExtractorConfig &config) {
-  ModelType model_type = ModelType::kUnkown;
+  ModelType model_type = ModelType::kUnknown;

  {
    auto buffer = ReadFile(config.model);
@@ -80,9 +80,8 @@ SpeakerEmbeddingExtractorImpl::Create(
      return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config);
    case ModelType::kNeMo:
      return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(config);
-    case ModelType::kUnkown:
-      SHERPA_ONNX_LOGE(
-          "Unknown model type in for speaker embedding extractor!");
+    case ModelType::kUnknown:
+      SHERPA_ONNX_LOGE("Unknown model type for speaker embedding extractor!");
      return nullptr;
  }
@@ -94,7 +93,7 @@ SpeakerEmbeddingExtractorImpl::Create(
std::unique_ptr<SpeakerEmbeddingExtractorImpl>
SpeakerEmbeddingExtractorImpl::Create(
    AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config) {
-  ModelType model_type = ModelType::kUnkown;
+  ModelType model_type = ModelType::kUnknown;

  {
    auto buffer = ReadFile(mgr, config.model);
@@ -110,7 +109,7 @@ SpeakerEmbeddingExtractorImpl::Create(
                                                                config);
    case ModelType::kNeMo:
      return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(mgr, config);
-    case ModelType::kUnkown:
+    case ModelType::kUnknown:
      SHERPA_ONNX_LOGE(
          "Unknown model type in for speaker embedding extractor!");
      return nullptr;
88  sherpa-onnx/csrc/spoken-language-identification-impl.cc  (new file)
@@ -0,0 +1,88 @@
// sherpa-onnx/csrc/spoken-language-identification-impl.cc
//
// Copyright (c)  2024  Xiaomi Corporation
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"

#include <memory>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h"

namespace sherpa_onnx {

namespace {

enum class ModelType {
  kWhisper,
  kUnknown,
};

}  // namespace

static ModelType GetModelType(char *model_data, size_t model_data_length,
                              bool debug) {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
  Ort::SessionOptions sess_opts;

  auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
                                             sess_opts);

  Ort::ModelMetadata meta_data = sess->GetModelMetadata();
  if (debug) {
    std::ostringstream os;
    PrintModelMetadata(os, meta_data);
    SHERPA_ONNX_LOGE("%s", os.str().c_str());
  }

  Ort::AllocatorWithDefaultOptions allocator;
  auto model_type =
      meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
  if (!model_type) {
    SHERPA_ONNX_LOGE(
        "No model_type in the metadata!\n"
        "Please make sure you have added metadata to the model.\n\n"
        "For instance, you can use\n"
        "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/whisper/"
        "export-onnx.py "
        "to add metadata to models from whisper\n");
    return ModelType::kUnknown;
  }

  auto model_type_str = std::string(model_type.get());
  if (model_type_str.find("whisper") == 0) {
    return ModelType::kWhisper;
  } else {
    SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
    return ModelType::kUnknown;
  }
}

std::unique_ptr<SpokenLanguageIdentificationImpl>
SpokenLanguageIdentificationImpl::Create(
    const SpokenLanguageIdentificationConfig &config) {
  ModelType model_type = ModelType::kUnknown;
  {
    if (config.whisper.encoder.empty()) {
      SHERPA_ONNX_LOGE("Only whisper models are supported at present");
      exit(-1);
    }
    auto buffer = ReadFile(config.whisper.encoder);

    model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
  }

  switch (model_type) {
    case ModelType::kWhisper:
      return std::make_unique<SpokenLanguageIdentificationWhisperImpl>(config);
    case ModelType::kUnknown:
      SHERPA_ONNX_LOGE(
          "Unknown model type for spoken language identification!");
      return nullptr;
  }

  // unreachable code
  return nullptr;
}

}  // namespace sherpa_onnx
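Note: GetModelType above dispatches on the model_type key stored in the ONNX file's custom metadata. When a model fails this check, the same lookup can be reproduced from Python with onnxruntime (a debugging sketch, not part of the sherpa-onnx API; the file name is an example):

import onnxruntime as ort

sess = ort.InferenceSession("tiny-encoder.int8.onnx")
meta = sess.get_modelmeta()

# sherpa-onnx expects a "model_type" entry whose value starts with "whisper".
print("model_type:", meta.custom_metadata_map.get("model_type"))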
28  sherpa-onnx/csrc/spoken-language-identification-impl.h  (new file)
@@ -0,0 +1,28 @@
// sherpa-onnx/csrc/spoken-language-identification-impl.h
//
// Copyright (c)  2024  Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_

#include <memory>
#include <string>

#include "sherpa-onnx/csrc/spoken-language-identification.h"

namespace sherpa_onnx {

class SpokenLanguageIdentificationImpl {
 public:
  virtual ~SpokenLanguageIdentificationImpl() = default;

  static std::unique_ptr<SpokenLanguageIdentificationImpl> Create(
      const SpokenLanguageIdentificationConfig &config);

  virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;

  virtual std::string Compute(OfflineStream *s) const = 0;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_
119  sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h  (new file)
@@ -0,0 +1,119 @@
// sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h
//
// Copyright (c)  2024  Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/offline-whisper-model.h"
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
#include "sherpa-onnx/csrc/transpose.h"

namespace sherpa_onnx {

class SpokenLanguageIdentificationWhisperImpl
    : public SpokenLanguageIdentificationImpl {
 public:
  explicit SpokenLanguageIdentificationWhisperImpl(
      const SpokenLanguageIdentificationConfig &config)
      : config_(config), model_(std::make_unique<OfflineWhisperModel>(config)) {
    Check();
  }

  std::unique_ptr<OfflineStream> CreateStream() const override {
    return std::make_unique<OfflineStream>(WhisperTag{});
  }

  std::string Compute(OfflineStream *s) const override {
    int32_t max_num_frames = 3000;
    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    int32_t feat_dim = s->FeatureDim();
    std::vector<float> f = s->GetFrames();
    int32_t num_frames = f.size() / feat_dim;

    // we use 50 here so that there will be some zero tail paddings
    if (num_frames >= max_num_frames - 50) {
      SHERPA_ONNX_LOGE(
          "Only waves less than 30 seconds are supported. We process only the "
          "first 30 seconds and discard the remaining data");
      num_frames = max_num_frames - 50;
    }

    model_->NormalizeFeatures(f.data(), num_frames, feat_dim);

    // note that 1000 is an experience-value.
    // You can replace 1000 by other values, say, 100.
    //
    // Since we have removed the 30 seconds constraint, we need
    // tail_padding_frames so that whisper is able to detect the eot token.
    int32_t tail_padding_frames = 1000;

    if (config_.whisper.tail_paddings > 0) {
      tail_padding_frames = config_.whisper.tail_paddings;
    }

    int32_t actual_frames =
        std::min(num_frames + tail_padding_frames, max_num_frames);

    std::array<int64_t, 3> shape{1, actual_frames, feat_dim};

    Ort::Value mel = Ort::Value::CreateTensor<float>(
        model_->Allocator(), shape.data(), shape.size());

    float *p_mel = mel.GetTensorMutableData<float>();
    std::copy(f.data(), f.data() + num_frames * feat_dim, p_mel);

    std::fill_n(p_mel + num_frames * feat_dim,
                (actual_frames - num_frames) * feat_dim, 0);

    mel = Transpose12(model_->Allocator(), &mel);

    try {
      auto cross_kv = model_->ForwardEncoder(std::move(mel));
      int32_t lang_id = model_->DetectLanguage(cross_kv.first, cross_kv.second);
      const auto &id2lang = model_->GetID2Lang();
      if (id2lang.count(lang_id)) {
        return id2lang.at(lang_id);
      } else {
        SHERPA_ONNX_LOGE("Unknown language ID: %d. Return an empty string.",
                         lang_id);
        return "";
      }
    } catch (const Ort::Exception &ex) {
      SHERPA_ONNX_LOGE(
          "\n\nCaught exception:\n\n%s\n\nReturn an empty result. Number of "
          "input frames: %d, Current tail "
          "paddings: %d. If you see a lot of such exceptions, please consider "
          "using a larger --whisper-tail-paddings",
          ex.what(), num_frames, tail_padding_frames);
      return "";
    }
  }

 private:
  void Check() const {
    if (!model_->IsMultiLingual()) {
      SHERPA_ONNX_LOGE(
          "Only whisper multilingual models can be used for spoken language "
          "identification. Given: %s,%s",
          config_.whisper.encoder.c_str(), config_.whisper.decoder.c_str());
      exit(-1);
    }
  }

 private:
  SpokenLanguageIdentificationConfig config_;
  std::unique_ptr<OfflineWhisperModel> model_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
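Note: Compute() above shapes the encoder input as follows: keep at most max_num_frames - 50 normalized frames, append zero tail padding (1000 frames by default, or --whisper-tail-paddings), cap the total at Whisper's 3000-frame window, and transpose to (1, feat_dim, frames). A NumPy sketch of that logic, with prepare_mel as a hypothetical helper name:

import numpy as np

def prepare_mel(frames: np.ndarray, tail_padding: int = 1000,
                max_num_frames: int = 3000) -> np.ndarray:
    # frames: (num_frames, feat_dim) normalized log-mel features
    num_frames = min(len(frames), max_num_frames - 50)  # leave room for padding
    actual = min(num_frames + tail_padding, max_num_frames)

    mel = np.zeros((1, actual, frames.shape[1]), dtype=np.float32)
    mel[0, :num_frames] = frames[:num_frames]  # zeros after num_frames are the tail padding
    return mel.transpose(0, 2, 1)  # (1, feat_dim, actual), as the encoder expects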
117
sherpa-onnx/csrc/spoken-language-identification.cc
Normal file
117
sherpa-onnx/csrc/spoken-language-identification.cc
Normal file
@@ -0,0 +1,117 @@
|
||||
// sherpa-onnx/csrc/spoken-language-identification.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/spoken-language-identification.h"

#include <sstream>
#include <string>

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"

namespace sherpa_onnx {

void SpokenLanguageIdentificationWhisperConfig::Register(ParseOptions *po) {
  po->Register(
      "whisper-encoder", &encoder,
      "Path to the encoder of a whisper multilingual model. Support only "
      "tiny, base, small, medium, large.");

  po->Register(
      "whisper-decoder", &decoder,
      "Path to the decoder of a whisper multilingual model. Support only "
      "tiny, base, small, medium, large.");

  po->Register(
      "whisper-tail-paddings", &tail_paddings,
      "Suggested value: 300 for multilingual models. "
      "Since we have removed the 30-second constraint, we need to add some "
      "tail padding frames "
      "so that whisper can detect the eot token. Leave it at -1 to use 1000.");
}

bool SpokenLanguageIdentificationWhisperConfig::Validate() const {
  if (encoder.empty()) {
    SHERPA_ONNX_LOGE("Please provide --whisper-encoder");
    return false;
  }

  if (!FileExists(encoder)) {
    SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
    return false;
  }

  if (decoder.empty()) {
    SHERPA_ONNX_LOGE("Please provide --whisper-decoder");
    return false;
  }

  if (!FileExists(decoder)) {
    SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
    return false;
  }

  return true;
}

std::string SpokenLanguageIdentificationWhisperConfig::ToString() const {
  std::ostringstream os;

  os << "SpokenLanguageIdentificationWhisperConfig(";
  os << "encoder=\"" << encoder << "\", ";
  os << "decoder=\"" << decoder << "\", ";
  os << "tail_paddings=" << tail_paddings << ")";

  return os.str();
}

void SpokenLanguageIdentificationConfig::Register(ParseOptions *po) {
  whisper.Register(po);

  po->Register("num-threads", &num_threads,
               "Number of threads to run the neural network");

  po->Register("debug", &debug,
               "true to print model information while loading it.");

  po->Register("provider", &provider,
               "Specify a provider to use: cpu, cuda, coreml");
}

bool SpokenLanguageIdentificationConfig::Validate() const {
  if (!whisper.Validate()) {
    return false;
  }

  return true;
}

std::string SpokenLanguageIdentificationConfig::ToString() const {
  std::ostringstream os;

  os << "SpokenLanguageIdentificationConfig(";
  os << "whisper=\"" << whisper.ToString() << "\", ";
  os << "num_threads=" << num_threads << ", ";
  os << "debug=" << (debug ? "True" : "False") << ", ";
  os << "provider=\"" << provider << "\")";

  return os.str();
}

SpokenLanguageIdentification::SpokenLanguageIdentification(
    const SpokenLanguageIdentificationConfig &config)
    : impl_(SpokenLanguageIdentificationImpl::Create(config)) {}

SpokenLanguageIdentification::~SpokenLanguageIdentification() = default;

std::unique_ptr<OfflineStream> SpokenLanguageIdentification::CreateStream()
    const {
  return impl_->CreateStream();
}

std::string SpokenLanguageIdentification::Compute(OfflineStream *s) const {
  return impl_->Compute(s);
}

}  // namespace sherpa_onnx
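For orientation, the Register()/Validate() pair above follows sherpa-onnx's kaldi-style ParseOptions flow. Below is a minimal driver sketch; the usage string is illustrative, and ParseOptions::Read() and PrintUsage() are assumed from parse-options.h rather than shown in this commit.

// Minimal driver sketch (assumes sherpa-onnx's kaldi-style ParseOptions;
// everything outside this commit is an illustrative assumption).
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/spoken-language-identification.h"

int main(int argc, char *argv[]) {
  const char *kUsage = "Spoken language identification with whisper models.";
  sherpa_onnx::ParseOptions po(kUsage);

  sherpa_onnx::SpokenLanguageIdentificationConfig config;
  config.Register(&po);  // registers --whisper-encoder, --whisper-decoder, ...

  po.Read(argc, argv);  // fills `config` from the command line
  if (!config.Validate()) {
    po.PrintUsage();
    return -1;
  }
  return 0;
}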
89
sherpa-onnx/csrc/spoken-language-identification.h
Normal file
@@ -0,0 +1,89 @@
// sherpa-onnx/csrc/spoken-language-identification.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_

#include <cstdint>
#include <memory>
#include <string>

#include "sherpa-onnx/csrc/offline-stream.h"
#include "sherpa-onnx/csrc/parse-options.h"

namespace sherpa_onnx {

struct SpokenLanguageIdentificationWhisperConfig {
  // Requires a multilingual whisper model.
  // That is, it supports only tiny, base, small, medium, large.
  // Note: It does NOT support tiny.en, base.en, small.en, medium.en.
  std::string encoder;
  std::string decoder;

  // Number of tail padding frames.
  //
  // Since we have removed the 30-second constraint, we need to add some
  // paddings at the end.
  //
  // Recommended values:
  //   - 50 for English models
  //   - 300 for multilingual models
  int32_t tail_paddings = -1;

  SpokenLanguageIdentificationWhisperConfig() = default;

  SpokenLanguageIdentificationWhisperConfig(const std::string &encoder,
                                            const std::string &decoder,
                                            int32_t tail_paddings)
      : encoder(encoder), decoder(decoder), tail_paddings(tail_paddings) {}

  void Register(ParseOptions *po);
  bool Validate() const;
  std::string ToString() const;
};

struct SpokenLanguageIdentificationConfig {
  SpokenLanguageIdentificationWhisperConfig whisper;

  int32_t num_threads = 1;
  bool debug = false;
  std::string provider = "cpu";

  SpokenLanguageIdentificationConfig() = default;

  SpokenLanguageIdentificationConfig(
      const SpokenLanguageIdentificationWhisperConfig &whisper,
      int32_t num_threads, bool debug, const std::string &provider)
      : whisper(whisper),
        num_threads(num_threads),
        debug(debug),
        provider(provider) {}

  void Register(ParseOptions *po);
  bool Validate() const;
  std::string ToString() const;
};

class SpokenLanguageIdentificationImpl;

class SpokenLanguageIdentification {
 public:
  explicit SpokenLanguageIdentification(
      const SpokenLanguageIdentificationConfig &config);

  ~SpokenLanguageIdentification();

  // Create a stream to accept audio samples and compute features
  std::unique_ptr<OfflineStream> CreateStream() const;

  // Return a string containing the language code, e.g., en, zh, de.
  // Note: en is for English, zh is for Chinese, de is for German, etc.
  std::string Compute(OfflineStream *s) const;

 private:
  std::unique_ptr<SpokenLanguageIdentificationImpl> impl_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
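The public API above is small: build a config, create a stream, feed audio, call Compute(). Below is a minimal end-to-end sketch; the ReadWave() helper from wave-reader.h and the OfflineStream::AcceptWaveform() interface are assumed from elsewhere in sherpa-onnx (they are not part of this commit), and the model and wave file names are placeholders.

// Minimal usage sketch for the API declared above (assumptions noted in the
// lead-in; file names are placeholders).
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/spoken-language-identification.h"
#include "sherpa-onnx/csrc/wave-reader.h"

int main() {
  sherpa_onnx::SpokenLanguageIdentificationConfig config;
  config.whisper.encoder = "./tiny-encoder.onnx";
  config.whisper.decoder = "./tiny-decoder.onnx";
  if (!config.Validate()) {
    return -1;
  }

  sherpa_onnx::SpokenLanguageIdentification slid(config);

  int32_t sampling_rate = -1;
  bool is_ok = false;
  std::vector<float> samples =
      sherpa_onnx::ReadWave("./zh-chinese.wav", &sampling_rate, &is_ok);
  if (!is_ok) {
    return -1;
  }

  auto stream = slid.CreateStream();
  stream->AcceptWaveform(sampling_rate, samples.data(),
                         static_cast<int32_t>(samples.size()));

  // Expected to return a two-letter code such as "zh"; an empty string
  // signals failure (see Compute() above).
  std::string lang = slid.Compute(stream.get());
  return lang.empty() ? 1 : 0;
}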
@@ -33,6 +33,7 @@ set(srcs
  silero-vad-model-config.cc
  speaker-embedding-extractor.cc
  speaker-embedding-manager.cc
  spoken-language-identification.cc
  vad-model-config.cc
  vad-model.cc
  voice-activity-detector.cc
@@ -22,6 +22,7 @@
#include "sherpa-onnx/python/csrc/online-stream.h"
#include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/python/csrc/speaker-embedding-manager.h"
#include "sherpa-onnx/python/csrc/spoken-language-identification.h"
#include "sherpa-onnx/python/csrc/vad-model-config.h"
#include "sherpa-onnx/python/csrc/vad-model.h"
#include "sherpa-onnx/python/csrc/voice-activity-detector.h"

@@ -55,6 +56,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
  PybindOfflineTts(&m);
  PybindSpeakerEmbeddingExtractor(&m);
  PybindSpeakerEmbeddingManager(&m);
  PybindSpokenLanguageIdentification(&m);

  PybindAlsa(&m);
}
60
sherpa-onnx/python/csrc/spoken-language-identification.cc
Normal file
@@ -0,0 +1,60 @@
// sherpa-onnx/python/csrc/spoken-language-identification.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/python/csrc/spoken-language-identification.h"

#include <string>

#include "sherpa-onnx/csrc/spoken-language-identification.h"

namespace sherpa_onnx {

static void PybindSpokenLanguageIdentificationWhisperConfig(py::module *m) {
  using PyClass = SpokenLanguageIdentificationWhisperConfig;

  py::class_<PyClass>(*m, "SpokenLanguageIdentificationWhisperConfig")
      .def(py::init<>())
      .def(py::init<const std::string &, const std::string &, int32_t>(),
           py::arg("encoder"), py::arg("decoder"),
           py::arg("tail_paddings") = -1)
      .def_readwrite("encoder", &PyClass::encoder)
      .def_readwrite("decoder", &PyClass::decoder)
      .def_readwrite("tail_paddings", &PyClass::tail_paddings)
      .def("validate", &PyClass::Validate)
      .def("__str__", &PyClass::ToString);
}

static void PybindSpokenLanguageIdentificationConfig(py::module *m) {
  PybindSpokenLanguageIdentificationWhisperConfig(m);

  using PyClass = SpokenLanguageIdentificationConfig;

  py::class_<PyClass>(*m, "SpokenLanguageIdentificationConfig")
      .def(py::init<>())
      .def(py::init<const SpokenLanguageIdentificationWhisperConfig &, int32_t,
                    bool, const std::string &>(),
           py::arg("whisper"), py::arg("num_threads") = 1,
           py::arg("debug") = false, py::arg("provider") = "cpu")
      .def_readwrite("whisper", &PyClass::whisper)
      .def_readwrite("num_threads", &PyClass::num_threads)
      .def_readwrite("debug", &PyClass::debug)
      .def_readwrite("provider", &PyClass::provider)
      .def("validate", &PyClass::Validate)
      .def("__str__", &PyClass::ToString);
}

void PybindSpokenLanguageIdentification(py::module *m) {
  PybindSpokenLanguageIdentificationConfig(m);

  using PyClass = SpokenLanguageIdentification;
  py::class_<PyClass>(*m, "SpokenLanguageIdentification")
      .def(py::init<const SpokenLanguageIdentificationConfig &>(),
           py::arg("config"), py::call_guard<py::gil_scoped_release>())
      .def("create_stream", &PyClass::CreateStream,
           py::call_guard<py::gil_scoped_release>())
      .def("compute", &PyClass::Compute,
           py::call_guard<py::gil_scoped_release>());
}

}  // namespace sherpa_onnx
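Each binding above is wrapped in py::call_guard<py::gil_scoped_release>, so a long-running compute() releases the GIL instead of stalling other Python threads. Here is a generic pybind11 sketch of that pattern; the module name "demo" and the function slow_compute() are made-up placeholders, not part of this commit.

// Generic pybind11 pattern: release the GIL while a slow C++ function runs.
// "demo" and slow_compute() are illustrative placeholders.
#include <pybind11/pybind11.h>

#include <string>

namespace py = pybind11;

static std::string slow_compute() {
  // Imagine an expensive ONNX inference call here; it touches no Python
  // objects, so it is safe to run without holding the GIL.
  return "en";
}

PYBIND11_MODULE(demo, m) {
  m.def("compute", &slow_compute, py::call_guard<py::gil_scoped_release>());
}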
16
sherpa-onnx/python/csrc/spoken-language-identification.h
Normal file
@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/spoken-language-identification.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
#define SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

void PybindSpokenLanguageIdentification(py::module *m);

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
@@ -13,6 +13,9 @@ from _sherpa_onnx import (
    SpeakerEmbeddingExtractorConfig,
    SpeakerEmbeddingManager,
    SpeechSegment,
    SpokenLanguageIdentification,
    SpokenLanguageIdentificationConfig,
    SpokenLanguageIdentificationWhisperConfig,
    VadModel,
    VadModelConfig,
    VoiceActivityDetector,