Support spoken language identification with whisper (#694)
This commit is contained in:
98
.github/scripts/test-spoken-language-identification.sh
vendored
Executable file
98
.github/scripts/test-spoken-language-identification.sh
vendored
Executable file
@@ -0,0 +1,98 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "EXE is $EXE"
|
||||||
|
echo "PATH: $PATH"
|
||||||
|
|
||||||
|
which $EXE
|
||||||
|
|
||||||
|
names=(
|
||||||
|
tiny
|
||||||
|
base
|
||||||
|
small
|
||||||
|
medium
|
||||||
|
)
|
||||||
|
|
||||||
|
# all_language_codes=bo,ml,tt,fa,sl,bg,sn,sr,tl,km,ln,mr,hr,eu,ro,ba,bs,pl,as,nn,sk,ko,oc,ar,uz,pa,tg,mk,kk,hi,ha,uk,is,de,el,ja,yo,be,so,tk,id,sa,ru,yi,en,am,cs,ne,la,sv,su,pt,mi,ca,sd,hy,haw,fi,et,kn,da,lt,it,nl,he,mg,ur,tr,af,br,bn,ta,no,my,si,mt,th,gl,sw,mn,jw,ms,ps,fo,ka,hu,zh,ht,az,fr,lo,sq,gu,cy,lv,es,lb,te,vi
|
||||||
|
|
||||||
|
log "Download test waves"
|
||||||
|
waves=(
|
||||||
|
ar-arabic.wav
|
||||||
|
bg-bulgarian.wav
|
||||||
|
cs-czech.wav
|
||||||
|
da-danish.wav
|
||||||
|
de-german.wav
|
||||||
|
el-greek.wav
|
||||||
|
en-english.wav
|
||||||
|
es-spanish.wav
|
||||||
|
fa-persian.wav
|
||||||
|
fi-finnish.wav
|
||||||
|
fr-french.wav
|
||||||
|
hi-hindi.wav
|
||||||
|
hr-croatian.wav
|
||||||
|
id-indonesian.wav
|
||||||
|
it-italian.wav
|
||||||
|
ja-japanese.wav
|
||||||
|
ko-korean.wav
|
||||||
|
nl-dutch.wav
|
||||||
|
no-norwegian.wav
|
||||||
|
po-polish.wav
|
||||||
|
pt-portuguese.wav
|
||||||
|
ro-romanian.wav
|
||||||
|
ru-russian.wav
|
||||||
|
sk-slovak.wav
|
||||||
|
sv-swedish.wav
|
||||||
|
ta-tamil.wav
|
||||||
|
tl-tagalog.wav
|
||||||
|
tr-turkish.wav
|
||||||
|
uk-ukrainian.wav
|
||||||
|
zh-chinese.wav
|
||||||
|
)
|
||||||
|
|
||||||
|
for wav in ${waves[@]}; do
|
||||||
|
echo "Downloading $wav"
|
||||||
|
curl -SL -O https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/resolve/main/test_wavs/$wav
|
||||||
|
ls -lh *.wav
|
||||||
|
done
|
||||||
|
|
||||||
|
for name in ${names[@]}; do
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
log "Run $name"
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-whisper-$name
|
||||||
|
log "Start testing ${repo_url}"
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
log "Download pretrained model and test-data from $repo_url"
|
||||||
|
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "*.onnx"
|
||||||
|
# git lfs pull --include "*.ort"
|
||||||
|
ls -lh *.onnx
|
||||||
|
popd
|
||||||
|
|
||||||
|
for wav in ${waves[@]}; do
|
||||||
|
log "test fp32 onnx"
|
||||||
|
|
||||||
|
time $EXE \
|
||||||
|
--whisper-encoder=$repo/${name}-encoder.onnx \
|
||||||
|
--whisper-decoder=$repo/${name}-decoder.onnx \
|
||||||
|
$wav
|
||||||
|
|
||||||
|
log "test int8 onnx"
|
||||||
|
|
||||||
|
time $EXE \
|
||||||
|
--whisper-encoder=$repo/${name}-encoder.int8.onnx \
|
||||||
|
--whisper-decoder=$repo/${name}-decoder.int8.onnx \
|
||||||
|
$wav
|
||||||
|
done
|
||||||
|
rm -rf $repo
|
||||||
|
done
|
||||||
1
.github/workflows/build-wheels-linux.yaml
vendored
1
.github/workflows/build-wheels-linux.yaml
vendored
@@ -82,7 +82,6 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||||
uses: nick-fields/retry@v3
|
uses: nick-fields/retry@v3
|
||||||
shell: bash
|
|
||||||
with:
|
with:
|
||||||
max_attempts: 20
|
max_attempts: 20
|
||||||
timeout_seconds: 200
|
timeout_seconds: 200
|
||||||
|
|||||||
17
.github/workflows/build-wheels-macos-arm64.yaml
vendored
17
.github/workflows/build-wheels-macos-arm64.yaml
vendored
@@ -21,27 +21,12 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
os: [macos-latest]
|
os: [macos-latest]
|
||||||
python-version: ["cp37", "cp38", "cp39", "cp310", "cp311", "cp312"]
|
python-version: ["cp38", "cp39", "cp310", "cp311", "cp312"]
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
# see https://cibuildwheel.readthedocs.io/en/stable/changelog/
|
|
||||||
# for a list of versions
|
|
||||||
- name: Build wheels
|
- name: Build wheels
|
||||||
if: matrix.python-version == 'cp37'
|
|
||||||
uses: pypa/cibuildwheel@v2.11.4
|
|
||||||
env:
|
|
||||||
CIBW_BUILD: "${{ matrix.python-version}}-* "
|
|
||||||
CIBW_ENVIRONMENT: SHERPA_ONNX_CMAKE_ARGS="-DCMAKE_OSX_ARCHITECTURES='arm64'"
|
|
||||||
CIBW_ARCHS: "arm64"
|
|
||||||
CIBW_BUILD_VERBOSITY: 3
|
|
||||||
|
|
||||||
# Don't repair macOS wheels
|
|
||||||
CIBW_REPAIR_WHEEL_COMMAND_MACOS: ""
|
|
||||||
|
|
||||||
- name: Build wheels
|
|
||||||
if: matrix.python-version != 'cp37'
|
|
||||||
uses: pypa/cibuildwheel@v2.15.0
|
uses: pypa/cibuildwheel@v2.15.0
|
||||||
env:
|
env:
|
||||||
CIBW_BUILD: "${{ matrix.python-version}}-* "
|
CIBW_BUILD: "${{ matrix.python-version}}-* "
|
||||||
|
|||||||
9
.github/workflows/linux-gpu.yaml
vendored
9
.github/workflows/linux-gpu.yaml
vendored
@@ -92,6 +92,14 @@ jobs:
|
|||||||
file build/bin/sherpa-onnx
|
file build/bin/sherpa-onnx
|
||||||
readelf -d build/bin/sherpa-onnx
|
readelf -d build/bin/sherpa-onnx
|
||||||
|
|
||||||
|
- name: Test spoken language identification
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
export EXE=sherpa-onnx-offline-language-identification
|
||||||
|
|
||||||
|
.github/scripts/test-spoken-language-identification.sh
|
||||||
|
|
||||||
- name: Test online CTC
|
- name: Test online CTC
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
@@ -116,6 +124,7 @@ jobs:
|
|||||||
|
|
||||||
.github/scripts/test-online-paraformer.sh
|
.github/scripts/test-online-paraformer.sh
|
||||||
|
|
||||||
|
|
||||||
- name: Test offline Whisper
|
- name: Test offline Whisper
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
10
.github/workflows/linux.yaml
vendored
10
.github/workflows/linux.yaml
vendored
@@ -123,6 +123,15 @@ jobs:
|
|||||||
name: release-${{ matrix.build_type }}-${{ matrix.shared_lib }}
|
name: release-${{ matrix.build_type }}-${{ matrix.shared_lib }}
|
||||||
path: build/bin/*
|
path: build/bin/*
|
||||||
|
|
||||||
|
- name: Test spoken language identification
|
||||||
|
if: matrix.build_type != 'Debug'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
export EXE=sherpa-onnx-offline-language-identification
|
||||||
|
|
||||||
|
.github/scripts/test-spoken-language-identification.sh
|
||||||
|
|
||||||
- name: Test transducer kws
|
- name: Test transducer kws
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
@@ -140,6 +149,7 @@ jobs:
|
|||||||
.github/scripts/test-online-ctc.sh
|
.github/scripts/test-online-ctc.sh
|
||||||
|
|
||||||
- name: Test offline Whisper
|
- name: Test offline Whisper
|
||||||
|
if: matrix.build_type != 'Debug'
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
export PATH=$PWD/build/bin:$PATH
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
|||||||
10
.github/workflows/macos.yaml
vendored
10
.github/workflows/macos.yaml
vendored
@@ -102,6 +102,15 @@ jobs:
|
|||||||
otool -L build/bin/sherpa-onnx
|
otool -L build/bin/sherpa-onnx
|
||||||
otool -l build/bin/sherpa-onnx
|
otool -l build/bin/sherpa-onnx
|
||||||
|
|
||||||
|
- name: Test spoken language identification
|
||||||
|
if: matrix.build_type != 'Debug'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
export EXE=sherpa-onnx-offline-language-identification
|
||||||
|
|
||||||
|
.github/scripts/test-spoken-language-identification.sh
|
||||||
|
|
||||||
- name: Test transducer kws
|
- name: Test transducer kws
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
@@ -135,6 +144,7 @@ jobs:
|
|||||||
.github/scripts/test-online-paraformer.sh
|
.github/scripts/test-online-paraformer.sh
|
||||||
|
|
||||||
- name: Test offline Whisper
|
- name: Test offline Whisper
|
||||||
|
if: matrix.build_type != 'Debug'
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
export PATH=$PWD/build/bin:$PATH
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
|||||||
8
.github/workflows/windows-x64-cuda.yaml
vendored
8
.github/workflows/windows-x64-cuda.yaml
vendored
@@ -68,6 +68,14 @@ jobs:
|
|||||||
|
|
||||||
ls -lh ./bin/Release/sherpa-onnx.exe
|
ls -lh ./bin/Release/sherpa-onnx.exe
|
||||||
|
|
||||||
|
- name: Test spoken language identification
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin/Release:$PATH
|
||||||
|
export EXE=sherpa-onnx-offline-language-identification.exe
|
||||||
|
|
||||||
|
.github/scripts/test-spoken-language-identification.sh
|
||||||
|
|
||||||
- name: Test online CTC
|
- name: Test online CTC
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
8
.github/workflows/windows-x64.yaml
vendored
8
.github/workflows/windows-x64.yaml
vendored
@@ -68,6 +68,14 @@ jobs:
|
|||||||
|
|
||||||
ls -lh ./bin/Release/sherpa-onnx.exe
|
ls -lh ./bin/Release/sherpa-onnx.exe
|
||||||
|
|
||||||
|
- name: Test spoken language identification
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin/Release:$PATH
|
||||||
|
export EXE=sherpa-onnx-offline-language-identification.exe
|
||||||
|
|
||||||
|
.github/scripts/test-spoken-language-identification.sh
|
||||||
|
|
||||||
- name: Test online CTC
|
- name: Test online CTC
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
8
.github/workflows/windows-x86.yaml
vendored
8
.github/workflows/windows-x86.yaml
vendored
@@ -69,6 +69,14 @@ jobs:
|
|||||||
|
|
||||||
ls -lh ./bin/Release/sherpa-onnx.exe
|
ls -lh ./bin/Release/sherpa-onnx.exe
|
||||||
|
|
||||||
|
# - name: Test spoken language identification
|
||||||
|
# shell: bash
|
||||||
|
# run: |
|
||||||
|
# export PATH=$PWD/build/bin/Release:$PATH
|
||||||
|
# export EXE=sherpa-onnx-offline-language-identification.exe
|
||||||
|
#
|
||||||
|
# .github/scripts/test-spoken-language-identification.sh
|
||||||
|
|
||||||
- name: Test online CTC
|
- name: Test online CTC
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
||||||
project(sherpa-onnx)
|
project(sherpa-onnx)
|
||||||
|
|
||||||
set(SHERPA_ONNX_VERSION "1.9.13")
|
set(SHERPA_ONNX_VERSION "1.9.14")
|
||||||
|
|
||||||
# Disable warning about
|
# Disable warning about
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -43,6 +43,50 @@ def enable_alsa():
|
|||||||
return build_alsa and is_linux() and (is_arm64() or is_x86())
|
return build_alsa and is_linux() and (is_arm64() or is_x86())
|
||||||
|
|
||||||
|
|
||||||
|
def get_binaries():
|
||||||
|
binaries = [
|
||||||
|
"sherpa-onnx",
|
||||||
|
"sherpa-onnx-keyword-spotter",
|
||||||
|
"sherpa-onnx-microphone",
|
||||||
|
"sherpa-onnx-microphone-offline",
|
||||||
|
"sherpa-onnx-microphone-offline-speaker-identification",
|
||||||
|
"sherpa-onnx-offline",
|
||||||
|
"sherpa-onnx-offline-language-identification",
|
||||||
|
"sherpa-onnx-offline-tts",
|
||||||
|
"sherpa-onnx-offline-tts-play",
|
||||||
|
"sherpa-onnx-offline-websocket-server",
|
||||||
|
"sherpa-onnx-online-websocket-client",
|
||||||
|
"sherpa-onnx-online-websocket-server",
|
||||||
|
"sherpa-onnx-vad-microphone",
|
||||||
|
"sherpa-onnx-vad-microphone-offline-asr",
|
||||||
|
]
|
||||||
|
|
||||||
|
if enable_alsa():
|
||||||
|
binaries += [
|
||||||
|
"sherpa-onnx-alsa",
|
||||||
|
"sherpa-onnx-alsa-offline",
|
||||||
|
"sherpa-onnx-alsa-offline-speaker-identification",
|
||||||
|
"sherpa-onnx-offline-tts-play-alsa",
|
||||||
|
]
|
||||||
|
|
||||||
|
if is_windows():
|
||||||
|
binaries += [
|
||||||
|
"espeak-ng.dll",
|
||||||
|
"kaldi-decoder-core.dll",
|
||||||
|
"kaldi-native-fbank-core.dll",
|
||||||
|
"onnxruntime.dll",
|
||||||
|
"piper_phonemize.dll",
|
||||||
|
"sherpa-onnx-c-api.dll",
|
||||||
|
"sherpa-onnx-core.dll",
|
||||||
|
"sherpa-onnx-fst.lib",
|
||||||
|
"sherpa-onnx-kaldifst-core.lib",
|
||||||
|
"sherpa-onnx-portaudio.dll",
|
||||||
|
"ucd.dll",
|
||||||
|
]
|
||||||
|
|
||||||
|
return binaries
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
|
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
|
||||||
|
|
||||||
@@ -150,38 +194,7 @@ class BuildExtension(build_ext):
|
|||||||
suffix = ".exe" if is_windows() else ""
|
suffix = ".exe" if is_windows() else ""
|
||||||
# Remember to also change setup.py
|
# Remember to also change setup.py
|
||||||
|
|
||||||
binaries = ["sherpa-onnx"]
|
binaries = get_binaries()
|
||||||
binaries += ["sherpa-onnx-keyword-spotter"]
|
|
||||||
binaries += ["sherpa-onnx-offline"]
|
|
||||||
binaries += ["sherpa-onnx-microphone"]
|
|
||||||
binaries += ["sherpa-onnx-microphone-offline"]
|
|
||||||
binaries += ["sherpa-onnx-microphone-offline-speaker-identification"]
|
|
||||||
binaries += ["sherpa-onnx-online-websocket-server"]
|
|
||||||
binaries += ["sherpa-onnx-offline-websocket-server"]
|
|
||||||
binaries += ["sherpa-onnx-online-websocket-client"]
|
|
||||||
binaries += ["sherpa-onnx-vad-microphone"]
|
|
||||||
binaries += ["sherpa-onnx-vad-microphone-offline-asr"]
|
|
||||||
binaries += ["sherpa-onnx-offline-tts"]
|
|
||||||
binaries += ["sherpa-onnx-offline-tts-play"]
|
|
||||||
|
|
||||||
if enable_alsa():
|
|
||||||
binaries += ["sherpa-onnx-alsa"]
|
|
||||||
binaries += ["sherpa-onnx-alsa-offline"]
|
|
||||||
binaries += ["sherpa-onnx-offline-tts-play-alsa"]
|
|
||||||
binaries += ["sherpa-onnx-alsa-offline-speaker-identification"]
|
|
||||||
|
|
||||||
if is_windows():
|
|
||||||
binaries += ["kaldi-native-fbank-core.dll"]
|
|
||||||
binaries += ["sherpa-onnx-c-api.dll"]
|
|
||||||
binaries += ["sherpa-onnx-core.dll"]
|
|
||||||
binaries += ["sherpa-onnx-portaudio.dll"]
|
|
||||||
binaries += ["onnxruntime.dll"]
|
|
||||||
binaries += ["piper_phonemize.dll"]
|
|
||||||
binaries += ["espeak-ng.dll"]
|
|
||||||
binaries += ["ucd.dll"]
|
|
||||||
binaries += ["kaldi-decoder-core.dll"]
|
|
||||||
binaries += ["sherpa-onnx-fst.lib"]
|
|
||||||
binaries += ["sherpa-onnx-kaldifst-core.lib"]
|
|
||||||
|
|
||||||
for f in binaries:
|
for f in binaries:
|
||||||
suffix = "" if (".dll" in f or ".lib" in f) else suffix
|
suffix = "" if (".dll" in f or ".lib" in f) else suffix
|
||||||
|
|||||||
172
python-api-examples/spoken-language-identification.py
Executable file
172
python-api-examples/spoken-language-identification.py
Executable file
@@ -0,0 +1,172 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script shows how to use Python APIs for spoken language identification.
|
||||||
|
It detects the language spoken in the given wave file.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
1. Download a whisper multilingual model. We use a tiny model below.
|
||||||
|
Please refer to https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||||
|
to download more models.
|
||||||
|
|
||||||
|
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
|
||||||
|
tar xvf sherpa-onnx-whisper-tiny.tar.bz2
|
||||||
|
rm sherpa-onnx-whisper-tiny.tar.bz2
|
||||||
|
|
||||||
|
We only use the int8.onnx models below.
|
||||||
|
|
||||||
|
2. Download a test wave.
|
||||||
|
|
||||||
|
You can find many wave files for different languages at
|
||||||
|
https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/tree/main/test_wavs
|
||||||
|
|
||||||
|
wget https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/resolve/main/test_wavs/de-german.wav
|
||||||
|
|
||||||
|
python3 ./python-api-examples/spoken-language-identification.py \
|
||||||
|
--whisper-encoder=sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx \
|
||||||
|
--whisper-decoder=sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx \
|
||||||
|
--num-threads=1 \
|
||||||
|
./de-german.wav
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import wave
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import sherpa_onnx
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-encoder",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="Path to a multilingual whisper encoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-decoder",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="Path to a multilingual whisper decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-threads",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of threads for neural network computation",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--debug",
|
||||||
|
type=bool,
|
||||||
|
default=False,
|
||||||
|
help="True to show debug messages",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--provider",
|
||||||
|
type=str,
|
||||||
|
default="cpu",
|
||||||
|
help="Valid values: cpu, cuda, coreml",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_file",
|
||||||
|
type=str,
|
||||||
|
help="The input sound file to identify. It must be of WAVE"
|
||||||
|
"format with a single channel, and each sample has 16-bit, "
|
||||||
|
"i.e., int16_t. "
|
||||||
|
"The sample rate of the file can be arbitrary and does not need to "
|
||||||
|
"be 16 kHz",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def assert_file_exists(filename: str):
|
||||||
|
assert Path(filename).is_file(), (
|
||||||
|
f"{filename} does not exist!\n"
|
||||||
|
"Please refer to "
|
||||||
|
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html to download it"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
wave_filename:
|
||||||
|
Path to a wave file. It should be single channel and each sample should
|
||||||
|
be 16-bit. Its sample rate does not need to be 16kHz.
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing:
|
||||||
|
- A 1-D array of dtype np.float32 containing the samples, which are
|
||||||
|
normalized to the range [-1, 1].
|
||||||
|
- sample rate of the wave file
|
||||||
|
"""
|
||||||
|
|
||||||
|
with wave.open(wave_filename) as f:
|
||||||
|
assert f.getnchannels() == 1, f.getnchannels()
|
||||||
|
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
|
||||||
|
num_samples = f.getnframes()
|
||||||
|
samples = f.readframes(num_samples)
|
||||||
|
samples_int16 = np.frombuffer(samples, dtype=np.int16)
|
||||||
|
samples_float32 = samples_int16.astype(np.float32)
|
||||||
|
|
||||||
|
samples_float32 = samples_float32 / 32768
|
||||||
|
return samples_float32, f.getframerate()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
assert_file_exists(args.whisper_encoder)
|
||||||
|
assert_file_exists(args.whisper_decoder)
|
||||||
|
assert args.num_threads > 0, args.num_threads
|
||||||
|
config = sherpa_onnx.SpokenLanguageIdentificationConfig(
|
||||||
|
whisper=sherpa_onnx.SpokenLanguageIdentificationWhisperConfig(
|
||||||
|
encoder=args.whisper_encoder,
|
||||||
|
decoder=args.whisper_decoder,
|
||||||
|
),
|
||||||
|
num_threads=args.num_threads,
|
||||||
|
debug=args.debug,
|
||||||
|
provider=args.provider,
|
||||||
|
)
|
||||||
|
slid = sherpa_onnx.SpokenLanguageIdentification(config)
|
||||||
|
|
||||||
|
samples, sample_rate = read_wave(args.sound_file)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
stream = slid.create_stream()
|
||||||
|
stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
|
||||||
|
lang = slid.compute(stream)
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
elapsed_seconds = end_time - start_time
|
||||||
|
audio_duration = len(samples) / sample_rate
|
||||||
|
real_time_factor = elapsed_seconds / audio_duration
|
||||||
|
|
||||||
|
logging.info(f"File: {args.sound_file}")
|
||||||
|
logging.info(f"Detected language: {lang}")
|
||||||
|
logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}")
|
||||||
|
logging.info(f"Audio duration in seconds: {audio_duration:.3f}")
|
||||||
|
logging.info(
|
||||||
|
f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
|
main()
|
||||||
38
setup.py
38
setup.py
@@ -1,8 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import setuptools
|
import setuptools
|
||||||
@@ -11,7 +9,7 @@ from cmake.cmake_extension import (
|
|||||||
BuildExtension,
|
BuildExtension,
|
||||||
bdist_wheel,
|
bdist_wheel,
|
||||||
cmake_extension,
|
cmake_extension,
|
||||||
enable_alsa,
|
get_binaries,
|
||||||
is_windows,
|
is_windows,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -42,39 +40,7 @@ def get_binaries_to_install():
|
|||||||
bin_dir.mkdir(parents=True, exist_ok=True)
|
bin_dir.mkdir(parents=True, exist_ok=True)
|
||||||
suffix = ".exe" if is_windows() else ""
|
suffix = ".exe" if is_windows() else ""
|
||||||
|
|
||||||
# Remember to also change cmake/cmake_extension.py
|
binaries = get_binaries()
|
||||||
binaries = ["sherpa-onnx"]
|
|
||||||
binaries += ["sherpa-onnx-keyword-spotter"]
|
|
||||||
binaries += ["sherpa-onnx-offline"]
|
|
||||||
binaries += ["sherpa-onnx-microphone"]
|
|
||||||
binaries += ["sherpa-onnx-microphone-offline"]
|
|
||||||
binaries += ["sherpa-onnx-microphone-offline-speaker-identification"]
|
|
||||||
binaries += ["sherpa-onnx-online-websocket-server"]
|
|
||||||
binaries += ["sherpa-onnx-offline-websocket-server"]
|
|
||||||
binaries += ["sherpa-onnx-online-websocket-client"]
|
|
||||||
binaries += ["sherpa-onnx-vad-microphone"]
|
|
||||||
binaries += ["sherpa-onnx-vad-microphone-offline-asr"]
|
|
||||||
binaries += ["sherpa-onnx-offline-tts"]
|
|
||||||
binaries += ["sherpa-onnx-offline-tts-play"]
|
|
||||||
|
|
||||||
if enable_alsa():
|
|
||||||
binaries += ["sherpa-onnx-alsa"]
|
|
||||||
binaries += ["sherpa-onnx-alsa-offline"]
|
|
||||||
binaries += ["sherpa-onnx-offline-tts-play-alsa"]
|
|
||||||
binaries += ["sherpa-onnx-alsa-offline-speaker-identification"]
|
|
||||||
|
|
||||||
if is_windows():
|
|
||||||
binaries += ["kaldi-native-fbank-core.dll"]
|
|
||||||
binaries += ["sherpa-onnx-c-api.dll"]
|
|
||||||
binaries += ["sherpa-onnx-core.dll"]
|
|
||||||
binaries += ["sherpa-onnx-portaudio.dll"]
|
|
||||||
binaries += ["onnxruntime.dll"]
|
|
||||||
binaries += ["piper_phonemize.dll"]
|
|
||||||
binaries += ["espeak-ng.dll"]
|
|
||||||
binaries += ["ucd.dll"]
|
|
||||||
binaries += ["kaldi-decoder-core.dll"]
|
|
||||||
binaries += ["sherpa-onnx-fst.lib"]
|
|
||||||
binaries += ["sherpa-onnx-kaldifst-core.lib"]
|
|
||||||
|
|
||||||
exe = []
|
exe = []
|
||||||
for f in binaries:
|
for f in binaries:
|
||||||
|
|||||||
@@ -86,6 +86,8 @@ set(sources
|
|||||||
silero-vad-model-config.cc
|
silero-vad-model-config.cc
|
||||||
silero-vad-model.cc
|
silero-vad-model.cc
|
||||||
slice.cc
|
slice.cc
|
||||||
|
spoken-language-identification-impl.cc
|
||||||
|
spoken-language-identification.cc
|
||||||
stack.cc
|
stack.cc
|
||||||
symbol-table.cc
|
symbol-table.cc
|
||||||
text-utils.cc
|
text-utils.cc
|
||||||
@@ -184,6 +186,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
|
|||||||
add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
|
add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
|
||||||
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
|
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
|
||||||
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
|
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
|
||||||
|
add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
|
||||||
|
|
||||||
set(main_exes
|
set(main_exes
|
||||||
sherpa-onnx
|
sherpa-onnx
|
||||||
@@ -191,6 +194,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
|
|||||||
sherpa-onnx-offline
|
sherpa-onnx-offline
|
||||||
sherpa-onnx-offline-parallel
|
sherpa-onnx-offline-parallel
|
||||||
sherpa-onnx-offline-tts
|
sherpa-onnx-offline-tts
|
||||||
|
sherpa-onnx-offline-language-identification
|
||||||
)
|
)
|
||||||
|
|
||||||
foreach(exe IN LISTS main_exes)
|
foreach(exe IN LISTS main_exes)
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ enum class ModelType {
|
|||||||
kTdnn,
|
kTdnn,
|
||||||
kZipformerCtc,
|
kZipformerCtc,
|
||||||
kWenetCtc,
|
kWenetCtc,
|
||||||
kUnkown,
|
kUnknown,
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@@ -59,7 +59,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
"run.sh\n"
|
"run.sh\n"
|
||||||
"\n"
|
"\n"
|
||||||
"for how to add metadta to model.onnx\n");
|
"for how to add metadta to model.onnx\n");
|
||||||
return ModelType::kUnkown;
|
return ModelType::kUnknown;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (model_type.get() == std::string("EncDecCTCModelBPE")) {
|
if (model_type.get() == std::string("EncDecCTCModelBPE")) {
|
||||||
@@ -72,13 +72,13 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
return ModelType::kWenetCtc;
|
return ModelType::kWenetCtc;
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||||
return ModelType::kUnkown;
|
return ModelType::kUnknown;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||||
const OfflineModelConfig &config) {
|
const OfflineModelConfig &config) {
|
||||||
ModelType model_type = ModelType::kUnkown;
|
ModelType model_type = ModelType::kUnknown;
|
||||||
|
|
||||||
std::string filename;
|
std::string filename;
|
||||||
if (!config.nemo_ctc.model.empty()) {
|
if (!config.nemo_ctc.model.empty()) {
|
||||||
@@ -113,7 +113,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
|||||||
case ModelType::kWenetCtc:
|
case ModelType::kWenetCtc:
|
||||||
return std::make_unique<OfflineWenetCtcModel>(config);
|
return std::make_unique<OfflineWenetCtcModel>(config);
|
||||||
break;
|
break;
|
||||||
case ModelType::kUnkown:
|
case ModelType::kUnknown:
|
||||||
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
@@ -125,7 +125,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
|||||||
|
|
||||||
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||||
AAssetManager *mgr, const OfflineModelConfig &config) {
|
AAssetManager *mgr, const OfflineModelConfig &config) {
|
||||||
ModelType model_type = ModelType::kUnkown;
|
ModelType model_type = ModelType::kUnknown;
|
||||||
|
|
||||||
std::string filename;
|
std::string filename;
|
||||||
if (!config.nemo_ctc.model.empty()) {
|
if (!config.nemo_ctc.model.empty()) {
|
||||||
@@ -160,7 +160,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
|||||||
case ModelType::kWenetCtc:
|
case ModelType::kWenetCtc:
|
||||||
return std::make_unique<OfflineWenetCtcModel>(mgr, config);
|
return std::make_unique<OfflineWenetCtcModel>(mgr, config);
|
||||||
break;
|
break;
|
||||||
case ModelType::kUnkown:
|
case ModelType::kUnknown:
|
||||||
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
|||||||
num_frames = max_num_frames - 50;
|
num_frames = max_num_frames - 50;
|
||||||
}
|
}
|
||||||
|
|
||||||
NormalizeFeatures(f.data(), num_frames, feat_dim);
|
model_->NormalizeFeatures(f.data(), num_frames, feat_dim);
|
||||||
|
|
||||||
// note that 1000 is an experience-value.
|
// note that 1000 is an experience-value.
|
||||||
// You can replace 1000 by other values, say, 100.
|
// You can replace 1000 by other values, say, 100.
|
||||||
@@ -162,38 +162,6 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
static void NormalizeFeatures(float *features, int32_t num_frames,
|
|
||||||
int32_t feat_dim) {
|
|
||||||
// log_spec = torch.clamp(features, min=1e-10).log10()
|
|
||||||
// log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
|
||||||
// mel = (log_spec + 4.0) / 4.0
|
|
||||||
|
|
||||||
int32_t n = num_frames * feat_dim;
|
|
||||||
float max_v = -1e20;
|
|
||||||
for (int32_t i = 0; i != n; ++i) {
|
|
||||||
float f = features[i];
|
|
||||||
|
|
||||||
f = std::max<float>(f, 1e-10);
|
|
||||||
f = std::log10(f);
|
|
||||||
|
|
||||||
max_v = std::max(f, max_v);
|
|
||||||
|
|
||||||
features[i] = f;
|
|
||||||
}
|
|
||||||
|
|
||||||
max_v -= 8;
|
|
||||||
|
|
||||||
for (int32_t i = 0; i != n; ++i) {
|
|
||||||
float f = features[i];
|
|
||||||
f = std::max(f, max_v);
|
|
||||||
|
|
||||||
f = (f + 4) / 4;
|
|
||||||
|
|
||||||
features[i] = f;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
OfflineRecognizerConfig config_;
|
OfflineRecognizerConfig config_;
|
||||||
SymbolTable symbol_table_;
|
SymbolTable symbol_table_;
|
||||||
|
|||||||
@@ -12,56 +12,6 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
// Runs one decoder step on the SOT token and returns the language token ID
// with the largest logit among model_->GetAllLanguageIDs().
//
// @param cross_k  Cross-attention key cache. It is moved into the decoder
//                 call and then re-assigned from the decoder output, so it is
//                 valid again when this function returns.
// @param cross_v  Cross-attention value cache; handled like cross_k.
// @return The token ID (not an index into the language list) of the
//         best-scoring language.
//
// NOTE(review): assumes GetAllLanguageIDs() is non-empty; presumably that
// holds only for multilingual models -- confirm at the call site.
int32_t OfflineWhisperGreedySearchDecoder::DetectLanguage(
    Ort::Value &cross_k, Ort::Value &cross_v) const {  // NOLINT
  // A [1, 1] token tensor that borrows `token_val`; it must outlive the
  // ForwardDecoder() call below.
  int64_t token_val = model_->SOT();
  std::array<int64_t, 2> token_shape{1, 1};

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  Ort::Value tokens = Ort::Value::CreateTensor(
      memory_info, &token_val, 1, token_shape.data(), token_shape.size());

  auto self_kv_cache = model_->GetInitialSelfKVCache();

  // Decoding starts at position 0.
  std::array<int64_t, 1> offset_shape{1};
  Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
      model_->Allocator(), offset_shape.data(), offset_shape.size());
  *(offset.GetTensorMutableData<int64_t>()) = 0;

  auto decoder_out = model_->ForwardDecoder(
      std::move(tokens), std::move(self_kv_cache.first),
      std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
      std::move(offset));

  // Hand the cross-attention caches back to the caller.
  cross_k = std::move(std::get<3>(decoder_out));
  cross_v = std::move(std::get<4>(decoder_out));

  const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
  const auto &all_language_ids = model_->GetAllLanguageIDs();

  // Arg-max over the language tokens only. The strict '>' keeps the first ID
  // on ties. Re-visiting element 0 in the range-for is harmless and avoids a
  // signed/unsigned index comparison.
  int32_t lang_id = all_language_ids[0];
  float best_logit = p_logits[lang_id];

  for (int32_t id : all_language_ids) {
    float p = p_logits[id];

    if (p > best_logit) {
      best_logit = p;
      lang_id = id;
    }
  }
#if 1
  SHERPA_ONNX_LOGE("Detected language: %s",
                   model_->GetID2Lang().at(lang_id).c_str());
#endif

  return lang_id;
}
|
|
||||||
|
|
||||||
std::vector<OfflineWhisperDecoderResult>
|
std::vector<OfflineWhisperDecoderResult>
|
||||||
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||||
Ort::Value cross_v) {
|
Ort::Value cross_v) {
|
||||||
@@ -89,7 +39,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
|||||||
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
|
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
|
||||||
initial_tokens[1] = lang_id;
|
initial_tokens[1] = lang_id;
|
||||||
} else {
|
} else {
|
||||||
int32_t lang_id = DetectLanguage(cross_k, cross_v);
|
int32_t lang_id = model_->DetectLanguage(cross_k, cross_v);
|
||||||
|
|
||||||
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
|
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
|
||||||
initial_tokens[1] = lang_id;
|
initial_tokens[1] = lang_id;
|
||||||
|
|||||||
@@ -22,9 +22,6 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
|
|||||||
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
|
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
|
||||||
Ort::Value cross_v) override;
|
Ort::Value cross_v) override;
|
||||||
|
|
||||||
int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT
|
|
||||||
Ort::Value &cross_v) const; // NOLINT
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
OfflineWhisperModelConfig config_;
|
OfflineWhisperModelConfig config_;
|
||||||
OfflineWhisperModel *model_; // not owned
|
OfflineWhisperModel *model_; // not owned
|
||||||
|
|||||||
@@ -35,19 +35,28 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) {
|
|||||||
|
|
||||||
po->Register(
|
po->Register(
|
||||||
"whisper-tail-paddings", &tail_paddings,
|
"whisper-tail-paddings", &tail_paddings,
|
||||||
"Suggest value: 50 for English models. 300 for multilingual models. "
|
"Suggested value: 50 for English models. 300 for multilingual models. "
|
||||||
"Since we have removed the 30-second constraint, we need to add some "
|
"Since we have removed the 30-second constraint, we need to add some "
|
||||||
"tail padding frames "
|
"tail padding frames "
|
||||||
"so that whisper can detect the eot token. Leave it to -1 to use 50 for "
|
"so that whisper can detect the eot token. Leave it to -1 to use 1000.");
|
||||||
"English models and 300 for multilingual models.");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OfflineWhisperModelConfig::Validate() const {
|
bool OfflineWhisperModelConfig::Validate() const {
|
||||||
|
if (encoder.empty()) {
|
||||||
|
SHERPA_ONNX_LOGE("Please provide --whisper-encoder");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
if (!FileExists(encoder)) {
|
if (!FileExists(encoder)) {
|
||||||
SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
|
SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (decoder.empty()) {
|
||||||
|
SHERPA_ONNX_LOGE("Please provide --whisper-decoder");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
if (!FileExists(decoder)) {
|
if (!FileExists(decoder)) {
|
||||||
SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
|
SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@@ -24,6 +24,24 @@ class OfflineWhisperModel::Impl {
|
|||||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||||
sess_opts_(GetSessionOptions(config)),
|
sess_opts_(GetSessionOptions(config)),
|
||||||
allocator_{} {
|
allocator_{} {
|
||||||
|
debug_ = config_.debug;
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(config.whisper.encoder);
|
||||||
|
InitEncoder(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(config.whisper.decoder);
|
||||||
|
InitDecoder(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
explicit Impl(const SpokenLanguageIdentificationConfig &config)
|
||||||
|
: lid_config_(config),
|
||||||
|
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||||
|
sess_opts_(GetSessionOptions(config)),
|
||||||
|
allocator_{} {
|
||||||
|
debug_ = config_.debug;
|
||||||
{
|
{
|
||||||
auto buf = ReadFile(config.whisper.encoder);
|
auto buf = ReadFile(config.whisper.encoder);
|
||||||
InitEncoder(buf.data(), buf.size());
|
InitEncoder(buf.data(), buf.size());
|
||||||
@@ -41,6 +59,7 @@ class OfflineWhisperModel::Impl {
|
|||||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||||
sess_opts_(GetSessionOptions(config)),
|
sess_opts_(GetSessionOptions(config)),
|
||||||
allocator_{} {
|
allocator_{} {
|
||||||
|
debug_ = config_.debug;
|
||||||
{
|
{
|
||||||
auto buf = ReadFile(mgr, config.whisper.encoder);
|
auto buf = ReadFile(mgr, config.whisper.encoder);
|
||||||
InitEncoder(buf.data(), buf.size());
|
InitEncoder(buf.data(), buf.size());
|
||||||
@@ -85,6 +104,57 @@ class OfflineWhisperModel::Impl {
|
|||||||
std::move(decoder_input[4]), std::move(decoder_input[5])};
|
std::move(decoder_input[4]), std::move(decoder_input[5])};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Runs one decoder step on the SOT token and returns the language token ID
// with the largest logit among GetAllLanguageIDs().
//
// @param cross_k  Cross-attention key cache. It is moved into the decoder
//                 call and then re-assigned from the decoder output, so it is
//                 valid again when this function returns.
// @param cross_v  Cross-attention value cache; handled like cross_k.
// @return The token ID (not an index into the language list) of the
//         best-scoring language.
//
// NOTE(review): assumes GetAllLanguageIDs() is non-empty; presumably that
// holds only for multilingual models -- confirm at the call sites.
int32_t DetectLanguage(Ort::Value &cross_k,    // NOLINT
                       Ort::Value &cross_v) {  // NOLINT
  // A [1, 1] token tensor that borrows `token_val`; it must outlive the
  // ForwardDecoder() call below.
  int64_t token_val = SOT();
  std::array<int64_t, 2> token_shape{1, 1};

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  Ort::Value tokens = Ort::Value::CreateTensor(
      memory_info, &token_val, 1, token_shape.data(), token_shape.size());

  auto self_kv_cache = GetInitialSelfKVCache();

  // Decoding starts at position 0.
  std::array<int64_t, 1> offset_shape{1};
  Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
      Allocator(), offset_shape.data(), offset_shape.size());
  *(offset.GetTensorMutableData<int64_t>()) = 0;

  auto decoder_out =
      ForwardDecoder(std::move(tokens), std::move(self_kv_cache.first),
                     std::move(self_kv_cache.second), std::move(cross_k),
                     std::move(cross_v), std::move(offset));

  // Hand the cross-attention caches back to the caller.
  cross_k = std::move(std::get<3>(decoder_out));
  cross_v = std::move(std::get<4>(decoder_out));

  const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
  const auto &all_language_ids = GetAllLanguageIDs();

  // Arg-max over the language tokens only. The strict '>' keeps the first ID
  // on ties. Re-visiting element 0 in the range-for is harmless and avoids a
  // signed/unsigned index comparison.
  int32_t lang_id = all_language_ids[0];
  float best_logit = p_logits[lang_id];

  for (int32_t id : all_language_ids) {
    float p = p_logits[id];

    if (p > best_logit) {
      best_logit = p;
      lang_id = id;
    }
  }

  if (debug_) {
    SHERPA_ONNX_LOGE("Detected language: %s",
                     GetID2Lang().at(lang_id).c_str());
  }

  return lang_id;
}
|
||||||
|
|
||||||
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() {
|
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() {
|
||||||
std::array<int64_t, 4> shape{n_text_layer_, 1, n_text_ctx_, n_text_state_};
|
std::array<int64_t, 4> shape{n_text_layer_, 1, n_text_ctx_, n_text_state_};
|
||||||
|
|
||||||
@@ -148,7 +218,7 @@ class OfflineWhisperModel::Impl {
|
|||||||
|
|
||||||
// get meta data
|
// get meta data
|
||||||
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
|
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
|
||||||
if (config_.debug) {
|
if (debug_) {
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
os << "---encoder---\n";
|
os << "---encoder---\n";
|
||||||
PrintModelMetadata(os, meta_data);
|
PrintModelMetadata(os, meta_data);
|
||||||
@@ -203,6 +273,8 @@ class OfflineWhisperModel::Impl {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
OfflineModelConfig config_;
|
OfflineModelConfig config_;
|
||||||
|
SpokenLanguageIdentificationConfig lid_config_;
|
||||||
|
bool debug_ = false;
|
||||||
Ort::Env env_;
|
Ort::Env env_;
|
||||||
Ort::SessionOptions sess_opts_;
|
Ort::SessionOptions sess_opts_;
|
||||||
Ort::AllocatorWithDefaultOptions allocator_;
|
Ort::AllocatorWithDefaultOptions allocator_;
|
||||||
@@ -246,6 +318,10 @@ class OfflineWhisperModel::Impl {
|
|||||||
OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
|
OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
|
||||||
: impl_(std::make_unique<Impl>(config)) {}
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
|
||||||
|
// Builds a whisper model from a spoken-language-identification config by
// forwarding `config` to the matching Impl constructor.
OfflineWhisperModel::OfflineWhisperModel(
|
||||||
|
const SpokenLanguageIdentificationConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
|
||||||
#if __ANDROID_API__ >= 9
|
#if __ANDROID_API__ >= 9
|
||||||
OfflineWhisperModel::OfflineWhisperModel(AAssetManager *mgr,
|
OfflineWhisperModel::OfflineWhisperModel(AAssetManager *mgr,
|
||||||
const OfflineModelConfig &config)
|
const OfflineModelConfig &config)
|
||||||
@@ -273,6 +349,11 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
|
|||||||
std::move(n_layer_cross_v), std::move(offset));
|
std::move(n_layer_cross_v), std::move(offset));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Thin pimpl forwarder: passes both cross-attention caches by reference to
// Impl::DetectLanguage() and returns its result unchanged.
int32_t OfflineWhisperModel::DetectLanguage(Ort::Value &cross_k, // NOLINT
|
||||||
|
Ort::Value &cross_v) { // NOLINT
|
||||||
|
return impl_->DetectLanguage(cross_k, cross_v);
|
||||||
|
}
|
||||||
|
|
||||||
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache()
|
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache()
|
||||||
const {
|
const {
|
||||||
return impl_->GetInitialSelfKVCache();
|
return impl_->GetInitialSelfKVCache();
|
||||||
@@ -318,4 +399,35 @@ bool OfflineWhisperModel::IsMultiLingual() const {
|
|||||||
return impl_->IsMultiLingual();
|
return impl_->IsMultiLingual();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Normalizes features in place, mirroring the reference computation:
//
//   log_spec = torch.clamp(features, min=1e-10).log10()
//   log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
//   mel = (log_spec + 4.0) / 4.0
//
// @param features    Buffer of num_frames * feat_dim floats; overwritten.
// @param num_frames  Number of frames in the buffer.
// @param feat_dim    Number of values per frame.
void OfflineWhisperModel::NormalizeFeatures(float *features, int32_t num_frames,
                                            int32_t feat_dim) {
  const int32_t total = num_frames * feat_dim;
  float *const last = features + total;

  // Pass 1: clamp + log10 in place, tracking the largest log value.
  float peak = -1e20;
  for (float *p = features; p != last; ++p) {
    const float clamped = std::max<float>(*p, 1e-10);
    const float log_value = std::log10(clamped);

    peak = std::max(log_value, peak);
    *p = log_value;
  }

  // Nothing may fall more than 8 (log10 units) below the global peak.
  peak -= 8;

  // Pass 2: apply the floor, then shift and scale.
  for (float *p = features; p != last; ++p) {
    const float floored = std::max(*p, peak);
    *p = (floored + 4) / 4;
  }
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -18,6 +18,7 @@
|
|||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
|
#include "sherpa-onnx/csrc/spoken-language-identification.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
@@ -25,6 +26,9 @@ class OfflineWhisperModel {
|
|||||||
public:
|
public:
|
||||||
explicit OfflineWhisperModel(const OfflineModelConfig &config);
|
explicit OfflineWhisperModel(const OfflineModelConfig &config);
|
||||||
|
|
||||||
|
explicit OfflineWhisperModel(
|
||||||
|
const SpokenLanguageIdentificationConfig &config);
|
||||||
|
|
||||||
#if __ANDROID_API__ >= 9
|
#if __ANDROID_API__ >= 9
|
||||||
OfflineWhisperModel(AAssetManager *mgr, const OfflineModelConfig &config);
|
OfflineWhisperModel(AAssetManager *mgr, const OfflineModelConfig &config);
|
||||||
#endif
|
#endif
|
||||||
@@ -72,7 +76,8 @@ class OfflineWhisperModel {
|
|||||||
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
|
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
|
||||||
Ort::Value n_layer_cross_v, Ort::Value offset) const;
|
Ort::Value n_layer_cross_v, Ort::Value offset) const;
|
||||||
|
|
||||||
int32_t DetectLanguage() const;
|
int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT
|
||||||
|
Ort::Value &cross_v); // NOLINT
|
||||||
|
|
||||||
/** Return the initial self kv cache in a pair
|
/** Return the initial self kv cache in a pair
|
||||||
* - n_layer_self_k_cache A 4-D tensor of shape
|
* - n_layer_self_k_cache A 4-D tensor of shape
|
||||||
@@ -98,6 +103,9 @@ class OfflineWhisperModel {
|
|||||||
int32_t Translate() const;
|
int32_t Translate() const;
|
||||||
bool IsMultiLingual() const;
|
bool IsMultiLingual() const;
|
||||||
|
|
||||||
|
static void NormalizeFeatures(float *features, int32_t num_frames,
|
||||||
|
int32_t feat_dim);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
class Impl;
|
class Impl;
|
||||||
std::unique_ptr<Impl> impl_;
|
std::unique_ptr<Impl> impl_;
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ enum class ModelType {
|
|||||||
kLstm,
|
kLstm,
|
||||||
kZipformer,
|
kZipformer,
|
||||||
kZipformer2,
|
kZipformer2,
|
||||||
kUnkown,
|
kUnknown,
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@@ -58,7 +58,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
"No model_type in the metadata!\n"
|
"No model_type in the metadata!\n"
|
||||||
"Please make sure you are using the latest export-onnx.py from icefall "
|
"Please make sure you are using the latest export-onnx.py from icefall "
|
||||||
"to export your transducer models");
|
"to export your transducer models");
|
||||||
return ModelType::kUnkown;
|
return ModelType::kUnknown;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (model_type.get() == std::string("conformer")) {
|
if (model_type.get() == std::string("conformer")) {
|
||||||
@@ -71,7 +71,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
return ModelType::kZipformer2;
|
return ModelType::kZipformer2;
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||||
return ModelType::kUnkown;
|
return ModelType::kUnknown;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -93,7 +93,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
|||||||
model_type.c_str());
|
model_type.c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ModelType model_type = ModelType::kUnkown;
|
ModelType model_type = ModelType::kUnknown;
|
||||||
|
|
||||||
{
|
{
|
||||||
auto buffer = ReadFile(config.transducer.encoder);
|
auto buffer = ReadFile(config.transducer.encoder);
|
||||||
@@ -110,7 +110,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
|||||||
return std::make_unique<OnlineZipformerTransducerModel>(config);
|
return std::make_unique<OnlineZipformerTransducerModel>(config);
|
||||||
case ModelType::kZipformer2:
|
case ModelType::kZipformer2:
|
||||||
return std::make_unique<OnlineZipformer2TransducerModel>(config);
|
return std::make_unique<OnlineZipformer2TransducerModel>(config);
|
||||||
case ModelType::kUnkown:
|
case ModelType::kUnknown:
|
||||||
SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
|
SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
@@ -185,7 +185,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
|||||||
return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
|
return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
|
||||||
case ModelType::kZipformer2:
|
case ModelType::kZipformer2:
|
||||||
return std::make_unique<OnlineZipformer2TransducerModel>(mgr, config);
|
return std::make_unique<OnlineZipformer2TransducerModel>(mgr, config);
|
||||||
case ModelType::kUnkown:
|
case ModelType::kUnknown:
|
||||||
SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
|
SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -149,4 +149,9 @@ Ort::SessionOptions GetSessionOptions(
|
|||||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Overload for spoken language identification: builds session options from
// the config's num_threads and provider fields via GetSessionOptionsImpl().
Ort::SessionOptions GetSessionOptions(
|
||||||
|
const SpokenLanguageIdentificationConfig &config) {
|
||||||
|
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
#include "sherpa-onnx/csrc/online-lm-config.h"
|
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||||
#include "sherpa-onnx/csrc/online-model-config.h"
|
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||||
|
#include "sherpa-onnx/csrc/spoken-language-identification.h"
|
||||||
#include "sherpa-onnx/csrc/vad-model-config.h"
|
#include "sherpa-onnx/csrc/vad-model-config.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
@@ -30,6 +31,10 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
|
|||||||
|
|
||||||
Ort::SessionOptions GetSessionOptions(
|
Ort::SessionOptions GetSessionOptions(
|
||||||
const SpeakerEmbeddingExtractorConfig &config);
|
const SpeakerEmbeddingExtractorConfig &config);
|
||||||
|
|
||||||
|
Ort::SessionOptions GetSessionOptions(
|
||||||
|
const SpokenLanguageIdentificationConfig &config);
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
#endif // SHERPA_ONNX_CSRC_SESSION_H_
|
#endif // SHERPA_ONNX_CSRC_SESSION_H_
|
||||||
|
|||||||
107
sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc
Normal file
107
sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
// sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include <chrono> // NOLINT
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
|
#include "sherpa-onnx/csrc/spoken-language-identification.h"
|
||||||
|
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||||
|
|
||||||
|
int main(int32_t argc, char *argv[]) {
|
||||||
|
const char *kUsageMessage = R"usage(
|
||||||
|
Spoken language identification with sherpa-onnx.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
(1) Use a whisper multilingual model
|
||||||
|
|
||||||
|
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
|
||||||
|
tar xvf sherpa-onnx-whisper-tiny.tar.bz2
|
||||||
|
rm sherpa-onnx-whisper-tiny.tar.bz2
|
||||||
|
|
||||||
|
We only use the int8.onnx models below.
|
||||||
|
|
||||||
|
./bin/sherpa-onnx-offline-spoken-language-identification \
|
||||||
|
--whisper-encoder=sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx \
|
||||||
|
--whisper-decoder=sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx \
|
||||||
|
--num-threads=1 \
|
||||||
|
/path/to/foo.wav
|
||||||
|
|
||||||
|
foo.wav should be of single channel, 16-bit PCM encoded wave file; its
|
||||||
|
sampling rate can be arbitrary and does not need to be 16kHz.
|
||||||
|
You can find test waves for different languages at
|
||||||
|
https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/tree/main/test_wavs
|
||||||
|
|
||||||
|
Please refer to
|
||||||
|
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html
|
||||||
|
Note that only whisper multilingual models are supported. For instance,
|
||||||
|
"tiny" is supported but "tiny.en" is not.
|
||||||
|
for a list of pre-trained models to download.
|
||||||
|
)usage";
|
||||||
|
|
||||||
|
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||||
|
sherpa_onnx::SpokenLanguageIdentificationConfig config;
|
||||||
|
config.Register(&po);
|
||||||
|
|
||||||
|
po.Read(argc, argv);
|
||||||
|
if (po.NumArgs() != 1) {
|
||||||
|
fprintf(stderr, "Error: Please provide 1 wave file.\n\n");
|
||||||
|
po.PrintUsage();
|
||||||
|
exit(EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
|
||||||
|
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||||
|
|
||||||
|
if (!config.Validate()) {
|
||||||
|
fprintf(stderr, "Errors in config!\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
fprintf(stderr, "Creating spoken language identifier ...\n");
|
||||||
|
sherpa_onnx::SpokenLanguageIdentification slid(config);
|
||||||
|
|
||||||
|
fprintf(stderr, "Started\n");
|
||||||
|
const std::string wav_filename = po.GetArg(1);
|
||||||
|
|
||||||
|
int32_t sampling_rate = -1;
|
||||||
|
bool is_ok = false;
|
||||||
|
const std::vector<float> samples =
|
||||||
|
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
|
||||||
|
if (!is_ok) {
|
||||||
|
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
float duration = samples.size() / static_cast<float>(sampling_rate);
|
||||||
|
|
||||||
|
const auto begin = std::chrono::steady_clock::now();
|
||||||
|
|
||||||
|
auto s = slid.CreateStream();
|
||||||
|
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
|
||||||
|
|
||||||
|
auto language = slid.Compute(s.get());
|
||||||
|
|
||||||
|
const auto end = std::chrono::steady_clock::now();
|
||||||
|
|
||||||
|
fprintf(stderr, "Done!\n\n");
|
||||||
|
fprintf(stderr, "%s\nDetected language: %s\n", wav_filename.c_str(),
|
||||||
|
language.c_str());
|
||||||
|
|
||||||
|
float elapsed_seconds =
|
||||||
|
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
|
||||||
|
.count() /
|
||||||
|
1000.;
|
||||||
|
|
||||||
|
fprintf(stderr, "num threads: %d\n", config.num_threads);
|
||||||
|
|
||||||
|
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
||||||
|
float rtf = elapsed_seconds / duration;
|
||||||
|
fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
|
||||||
|
elapsed_seconds, duration, rtf);
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
@@ -16,7 +16,7 @@ enum class ModelType {
|
|||||||
kWeSpeaker,
|
kWeSpeaker,
|
||||||
k3dSpeaker,
|
k3dSpeaker,
|
||||||
kNeMo,
|
kNeMo,
|
||||||
kUnkown,
|
kUnknown,
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@@ -47,7 +47,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wespeaker/"
|
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wespeaker/"
|
||||||
"add_meta_data.py"
|
"add_meta_data.py"
|
||||||
"to add metadata to models from WeSpeaker\n");
|
"to add metadata to models from WeSpeaker\n");
|
||||||
return ModelType::kUnkown;
|
return ModelType::kUnknown;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (model_type.get() == std::string("wespeaker")) {
|
if (model_type.get() == std::string("wespeaker")) {
|
||||||
@@ -58,14 +58,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
return ModelType::kNeMo;
|
return ModelType::kNeMo;
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||||
return ModelType::kUnkown;
|
return ModelType::kUnknown;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<SpeakerEmbeddingExtractorImpl>
|
std::unique_ptr<SpeakerEmbeddingExtractorImpl>
|
||||||
SpeakerEmbeddingExtractorImpl::Create(
|
SpeakerEmbeddingExtractorImpl::Create(
|
||||||
const SpeakerEmbeddingExtractorConfig &config) {
|
const SpeakerEmbeddingExtractorConfig &config) {
|
||||||
ModelType model_type = ModelType::kUnkown;
|
ModelType model_type = ModelType::kUnknown;
|
||||||
|
|
||||||
{
|
{
|
||||||
auto buffer = ReadFile(config.model);
|
auto buffer = ReadFile(config.model);
|
||||||
@@ -80,9 +80,8 @@ SpeakerEmbeddingExtractorImpl::Create(
|
|||||||
return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config);
|
return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config);
|
||||||
case ModelType::kNeMo:
|
case ModelType::kNeMo:
|
||||||
return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(config);
|
return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(config);
|
||||||
case ModelType::kUnkown:
|
case ModelType::kUnknown:
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE("Unknown model type for speaker embedding extractor!");
|
||||||
"Unknown model type in for speaker embedding extractor!");
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,7 +93,7 @@ SpeakerEmbeddingExtractorImpl::Create(
|
|||||||
std::unique_ptr<SpeakerEmbeddingExtractorImpl>
|
std::unique_ptr<SpeakerEmbeddingExtractorImpl>
|
||||||
SpeakerEmbeddingExtractorImpl::Create(
|
SpeakerEmbeddingExtractorImpl::Create(
|
||||||
AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config) {
|
AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config) {
|
||||||
ModelType model_type = ModelType::kUnkown;
|
ModelType model_type = ModelType::kUnknown;
|
||||||
|
|
||||||
{
|
{
|
||||||
auto buffer = ReadFile(mgr, config.model);
|
auto buffer = ReadFile(mgr, config.model);
|
||||||
@@ -110,7 +109,7 @@ SpeakerEmbeddingExtractorImpl::Create(
|
|||||||
config);
|
config);
|
||||||
case ModelType::kNeMo:
|
case ModelType::kNeMo:
|
||||||
return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(mgr, config);
|
return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(mgr, config);
|
||||||
case ModelType::kUnkown:
|
case ModelType::kUnknown:
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"Unknown model type in for speaker embedding extractor!");
|
"Unknown model type in for speaker embedding extractor!");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|||||||
88
sherpa-onnx/csrc/spoken-language-identification-impl.cc
Normal file
88
sherpa-onnx/csrc/spoken-language-identification-impl.cc
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
// sherpa-onnx/csrc/spoken-language-identification-impl.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// File-local helpers.
namespace {
|
||||||
|
|
||||||
|
// Model families that Create() in this file can dispatch on.
enum class ModelType {
|
||||||
|
// The "model_type" metadata value starts with "whisper".
kWhisper,
|
||||||
|
// Metadata is missing or names an unsupported model type.
kUnknown,
|
||||||
|
};
|
||||||
|
|
||||||
|
}  // namespace
|
||||||
|
|
||||||
|
// Loads the ONNX model held in memory and classifies it by its "model_type"
// custom metadata entry.
//
// @param model_data         Start of the serialized model bytes.
// @param model_data_length  Number of bytes in model_data.
// @param debug              When true, the model metadata is logged.
// @return ModelType::kWhisper if the metadata value starts with "whisper";
//         ModelType::kUnknown (after logging an error) if the entry is
//         missing or names anything else.
static ModelType GetModelType(char *model_data, size_t model_data_length,
                              bool debug) {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
  Ort::SessionOptions session_options;

  // The session is only needed long enough to read the metadata.
  Ort::Session session(env, model_data, model_data_length, session_options);
  Ort::ModelMetadata metadata = session.GetModelMetadata();

  if (debug) {
    std::ostringstream dump;
    PrintModelMetadata(dump, metadata);
    SHERPA_ONNX_LOGE("%s", dump.str().c_str());
  }

  Ort::AllocatorWithDefaultOptions allocator;
  auto type_value =
      metadata.LookupCustomMetadataMapAllocated("model_type", allocator);

  if (!type_value) {
    SHERPA_ONNX_LOGE(
        "No model_type in the metadata!\n"
        "Please make sure you have added metadata to the model.\n\n"
        "For instance, you can use\n"
        "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/whisper/"
        "export-onnx.py "
        "to add metadata to models from whisper\n");
    return ModelType::kUnknown;
  }

  // Prefix match: any value beginning with "whisper" is accepted.
  const std::string type_name(type_value.get());
  const bool is_whisper = type_name.find("whisper") == 0;

  if (!is_whisper) {
    SHERPA_ONNX_LOGE("Unsupported model_type: %s", type_value.get());
    return ModelType::kUnknown;
  }

  return ModelType::kWhisper;
}
|
||||||
|
|
||||||
|
std::unique_ptr<SpokenLanguageIdentificationImpl>
|
||||||
|
SpokenLanguageIdentificationImpl::Create(
|
||||||
|
const SpokenLanguageIdentificationConfig &config) {
|
||||||
|
ModelType model_type = ModelType::kUnknown;
|
||||||
|
{
|
||||||
|
if (config.whisper.encoder.empty()) {
|
||||||
|
SHERPA_ONNX_LOGE("Only whisper models are supported at present");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
auto buffer = ReadFile(config.whisper.encoder);
|
||||||
|
|
||||||
|
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (model_type) {
|
||||||
|
case ModelType::kWhisper:
|
||||||
|
return std::make_unique<SpokenLanguageIdentificationWhisperImpl>(config);
|
||||||
|
case ModelType::kUnknown:
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Unknown model type for spoken language identification!");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// unreachable code
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
28
sherpa-onnx/csrc/spoken-language-identification-impl.h
Normal file
28
sherpa-onnx/csrc/spoken-language-identification-impl.h
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
// sherpa-onnx/csrc/spoken-language-identification-impl.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/spoken-language-identification.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Abstract backend interface for spoken language identification. Concrete
// backends are obtained through Create().
class SpokenLanguageIdentificationImpl {
|
||||||
|
public:
|
||||||
|
virtual ~SpokenLanguageIdentificationImpl() = default;
|
||||||
|
|
||||||
|
// Factory that returns an implementation chosen from `config`.
static std::unique_ptr<SpokenLanguageIdentificationImpl> Create(
|
||||||
|
const SpokenLanguageIdentificationConfig &config);
|
||||||
|
|
||||||
|
// Returns a new stream for the caller to feed audio into.
virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
|
||||||
|
|
||||||
|
// Runs identification on stream `s` and returns the result as a string.
virtual std::string Compute(OfflineStream *s) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_
|
||||||
119
sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h
Normal file
119
sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
// sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-whisper-model.h"
|
||||||
|
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
|
||||||
|
#include "sherpa-onnx/csrc/transpose.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class SpokenLanguageIdentificationWhisperImpl
|
||||||
|
: public SpokenLanguageIdentificationImpl {
|
||||||
|
public:
|
||||||
|
explicit SpokenLanguageIdentificationWhisperImpl(
|
||||||
|
const SpokenLanguageIdentificationConfig &config)
|
||||||
|
: config_(config), model_(std::make_unique<OfflineWhisperModel>(config)) {
|
||||||
|
Check();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||||
|
return std::make_unique<OfflineStream>(WhisperTag{});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Compute(OfflineStream *s) const override {
|
||||||
|
int32_t max_num_frames = 3000;
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
int32_t feat_dim = s->FeatureDim();
|
||||||
|
std::vector<float> f = s->GetFrames();
|
||||||
|
int32_t num_frames = f.size() / feat_dim;
|
||||||
|
|
||||||
|
// we use 50 here so that there will be some zero tail paddings
|
||||||
|
if (num_frames >= max_num_frames - 50) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Only waves less than 30 seconds are supported. We process only the "
|
||||||
|
"first 30 seconds and discard the remaining data");
|
||||||
|
num_frames = max_num_frames - 50;
|
||||||
|
}
|
||||||
|
|
||||||
|
model_->NormalizeFeatures(f.data(), num_frames, feat_dim);
|
||||||
|
|
||||||
|
// note that 1000 is an experience-value.
|
||||||
|
// You can replace 1000 by other values, say, 100.
|
||||||
|
//
|
||||||
|
// Since we have removed the 30 seconds constraint, we need
|
||||||
|
// tail_padding_frames so that whisper is able to detect the eot token.
|
||||||
|
int32_t tail_padding_frames = 1000;
|
||||||
|
|
||||||
|
if (config_.whisper.tail_paddings > 0) {
|
||||||
|
tail_padding_frames = config_.whisper.tail_paddings;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t actual_frames =
|
||||||
|
std::min(num_frames + tail_padding_frames, max_num_frames);
|
||||||
|
|
||||||
|
std::array<int64_t, 3> shape{1, actual_frames, feat_dim};
|
||||||
|
|
||||||
|
Ort::Value mel = Ort::Value::CreateTensor<float>(
|
||||||
|
model_->Allocator(), shape.data(), shape.size());
|
||||||
|
|
||||||
|
float *p_mel = mel.GetTensorMutableData<float>();
|
||||||
|
std::copy(f.data(), f.data() + num_frames * feat_dim, p_mel);
|
||||||
|
|
||||||
|
std::fill_n(p_mel + num_frames * feat_dim,
|
||||||
|
(actual_frames - num_frames) * feat_dim, 0);
|
||||||
|
|
||||||
|
mel = Transpose12(model_->Allocator(), &mel);
|
||||||
|
|
||||||
|
try {
|
||||||
|
auto cross_kv = model_->ForwardEncoder(std::move(mel));
|
||||||
|
int32_t lang_id = model_->DetectLanguage(cross_kv.first, cross_kv.second);
|
||||||
|
const auto &id2lang = model_->GetID2Lang();
|
||||||
|
if (id2lang.count(lang_id)) {
|
||||||
|
return id2lang.at(lang_id);
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Unknown language ID: %d. Return an empty string.",
|
||||||
|
lang_id);
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
} catch (const Ort::Exception &ex) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"\n\nCaught exception:\n\n%s\n\nReturn an empty result. Number of "
|
||||||
|
"input frames: %d, Current tail "
|
||||||
|
"paddings: %d. If you see a lot of such exceptions, please consider "
|
||||||
|
"using a larger --whisper-tail-paddings",
|
||||||
|
ex.what(), num_frames, tail_padding_frames);
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void Check() const {
|
||||||
|
if (!model_->IsMultiLingual()) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Only whisper multilingual models can be used for spoken language "
|
||||||
|
"identification. Given: %s,%s",
|
||||||
|
config_.whisper.encoder.c_str(), config_.whisper.decoder.c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
SpokenLanguageIdentificationConfig config_;
|
||||||
|
std::unique_ptr<OfflineWhisperModel> model_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
|
||||||
117
sherpa-onnx/csrc/spoken-language-identification.cc
Normal file
117
sherpa-onnx/csrc/spoken-language-identification.cc
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
// sherpa-onnx/csrc/spoken-language-identification.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/spoken-language-identification.h"

#include <sstream>
#include <string>

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Registers the whisper-specific command-line options with `po`:
// --whisper-encoder, --whisper-decoder and --whisper-tail-paddings.
void SpokenLanguageIdentificationWhisperConfig::Register(ParseOptions *po) {
  po->Register(
      "whisper-encoder", &encoder,
      "Path to the encoder of a whisper multilingual model. Support only "
      "tiny, base, small, medium, large.");

  po->Register(
      "whisper-decoder", &decoder,
      "Path to the decoder of a whisper multilingual model. Support only "
      "tiny, base, small, medium, large.");

  po->Register(
      "whisper-tail-paddings", &tail_paddings,
      "Suggested value: 300 for multilingual models. "
      "Since we have removed the 30-second constraint, we need to add some "
      "tail padding frames "
      "so that whisper can detect the eot token. Leave it to -1 to use 1000");
}
|
||||||
|
|
||||||
|
bool SpokenLanguageIdentificationWhisperConfig::Validate() const {
|
||||||
|
if (encoder.empty()) {
|
||||||
|
SHERPA_ONNX_LOGE("Please provide --whisper-encoder");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!FileExists(encoder)) {
|
||||||
|
SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (decoder.empty()) {
|
||||||
|
SHERPA_ONNX_LOGE("Please provide --whisper-decoder");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!FileExists(decoder)) {
|
||||||
|
SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Renders the whisper options as
// SpokenLanguageIdentificationWhisperConfig(encoder="...", decoder="...",
// tail_paddings=N).
std::string SpokenLanguageIdentificationWhisperConfig::ToString() const {
  std::ostringstream out;

  out << "SpokenLanguageIdentificationWhisperConfig("
      << "encoder=\"" << encoder << "\", "
      << "decoder=\"" << decoder << "\", "
      << "tail_paddings=" << tail_paddings << ")";

  return out.str();
}
|
||||||
|
|
||||||
|
// Registers all command-line options with `po`: first the whisper model
// options, then --num-threads, --debug and --provider.
void SpokenLanguageIdentificationConfig::Register(ParseOptions *po) {
  whisper.Register(po);

  po->Register("num-threads", &num_threads,
               "Number of threads to run the neural network");

  po->Register("debug", &debug,
               "true to print model information while loading it.");

  po->Register("provider", &provider,
               "Specify a provider to use: cpu, cuda, coreml");
}
|
||||||
|
|
||||||
|
bool SpokenLanguageIdentificationConfig::Validate() const {
|
||||||
|
if (!whisper.Validate()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Renders the whole configuration, including the nested whisper options, as
// a single human-readable string.
std::string SpokenLanguageIdentificationConfig::ToString() const {
  std::ostringstream out;

  out << "SpokenLanguageIdentificationConfig("
      << "whisper=\"" << whisper.ToString() << "\", "
      << "num_threads=" << num_threads << ", "
      << "debug=" << (debug ? "True" : "False") << ", "
      << "provider=\"" << provider << "\")";

  return out.str();
}
|
||||||
|
|
||||||
|
SpokenLanguageIdentification::SpokenLanguageIdentification(
|
||||||
|
const SpokenLanguageIdentificationConfig &config)
|
||||||
|
: impl_(SpokenLanguageIdentificationImpl::Create(config)) {}
|
||||||
|
|
||||||
|
SpokenLanguageIdentification::~SpokenLanguageIdentification() = default;
|
||||||
|
|
||||||
|
std::unique_ptr<OfflineStream> SpokenLanguageIdentification::CreateStream()
|
||||||
|
const {
|
||||||
|
return impl_->CreateStream();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SpokenLanguageIdentification::Compute(OfflineStream *s) const {
|
||||||
|
return impl_->Compute(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
89
sherpa-onnx/csrc/spoken-language-identification.h
Normal file
89
sherpa-onnx/csrc/spoken-language-identification.h
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
// sherpa-onnx/csrc/spoken-language-identification.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||||
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct SpokenLanguageIdentificationWhisperConfig {
|
||||||
|
// Requires a multi-lingual whisper model.
|
||||||
|
// That is, it supports only tiny, base, small, medium, large.
|
||||||
|
// Note: It does NOT support tiny.en, base.en, small.en, medium.en
|
||||||
|
std::string encoder;
|
||||||
|
std::string decoder;
|
||||||
|
|
||||||
|
// Number of tail padding frames.
|
||||||
|
//
|
||||||
|
// Since we remove the 30-second constraint, we need to add some paddings
|
||||||
|
// at the end.
|
||||||
|
//
|
||||||
|
// Recommended values:
|
||||||
|
// - 50 for English models
|
||||||
|
// - 300 for multilingual models
|
||||||
|
int32_t tail_paddings = -1;
|
||||||
|
|
||||||
|
SpokenLanguageIdentificationWhisperConfig() = default;
|
||||||
|
|
||||||
|
SpokenLanguageIdentificationWhisperConfig(const std::string &encoder,
|
||||||
|
const std::string &decoder,
|
||||||
|
int32_t tail_paddings)
|
||||||
|
: encoder(encoder), decoder(decoder), tail_paddings(tail_paddings) {}
|
||||||
|
|
||||||
|
void Register(ParseOptions *po);
|
||||||
|
bool Validate() const;
|
||||||
|
std::string ToString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct SpokenLanguageIdentificationConfig {
|
||||||
|
SpokenLanguageIdentificationWhisperConfig whisper;
|
||||||
|
|
||||||
|
int32_t num_threads = 1;
|
||||||
|
bool debug = false;
|
||||||
|
std::string provider = "cpu";
|
||||||
|
|
||||||
|
SpokenLanguageIdentificationConfig() = default;
|
||||||
|
|
||||||
|
SpokenLanguageIdentificationConfig(
|
||||||
|
const SpokenLanguageIdentificationWhisperConfig &whisper,
|
||||||
|
int32_t num_threads, bool debug, const std::string &provider)
|
||||||
|
: whisper(whisper),
|
||||||
|
num_threads(num_threads),
|
||||||
|
debug(debug),
|
||||||
|
provider(provider) {}
|
||||||
|
|
||||||
|
void Register(ParseOptions *po);
|
||||||
|
bool Validate() const;
|
||||||
|
std::string ToString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class SpokenLanguageIdentificationImpl;
|
||||||
|
|
||||||
|
class SpokenLanguageIdentification {
|
||||||
|
public:
|
||||||
|
explicit SpokenLanguageIdentification(
|
||||||
|
const SpokenLanguageIdentificationConfig &config);
|
||||||
|
|
||||||
|
~SpokenLanguageIdentification();
|
||||||
|
|
||||||
|
// Create a stream to accept audio samples and compute features
|
||||||
|
std::unique_ptr<OfflineStream> CreateStream() const;
|
||||||
|
|
||||||
|
// Return a string containing the language, e.g., en, zh, de,
|
||||||
|
// etc.
|
||||||
|
// Note: en is for English, zh is for Chinese, de is for German, etc.
|
||||||
|
std::string Compute(OfflineStream *s) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<SpokenLanguageIdentificationImpl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
|
||||||
@@ -33,6 +33,7 @@ set(srcs
|
|||||||
silero-vad-model-config.cc
|
silero-vad-model-config.cc
|
||||||
speaker-embedding-extractor.cc
|
speaker-embedding-extractor.cc
|
||||||
speaker-embedding-manager.cc
|
speaker-embedding-manager.cc
|
||||||
|
spoken-language-identification.cc
|
||||||
vad-model-config.cc
|
vad-model-config.cc
|
||||||
vad-model.cc
|
vad-model.cc
|
||||||
voice-activity-detector.cc
|
voice-activity-detector.cc
|
||||||
|
|||||||
@@ -22,6 +22,7 @@
|
|||||||
#include "sherpa-onnx/python/csrc/online-stream.h"
|
#include "sherpa-onnx/python/csrc/online-stream.h"
|
||||||
#include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h"
|
#include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h"
|
||||||
#include "sherpa-onnx/python/csrc/speaker-embedding-manager.h"
|
#include "sherpa-onnx/python/csrc/speaker-embedding-manager.h"
|
||||||
|
#include "sherpa-onnx/python/csrc/spoken-language-identification.h"
|
||||||
#include "sherpa-onnx/python/csrc/vad-model-config.h"
|
#include "sherpa-onnx/python/csrc/vad-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/vad-model.h"
|
#include "sherpa-onnx/python/csrc/vad-model.h"
|
||||||
#include "sherpa-onnx/python/csrc/voice-activity-detector.h"
|
#include "sherpa-onnx/python/csrc/voice-activity-detector.h"
|
||||||
@@ -55,6 +56,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
|||||||
PybindOfflineTts(&m);
|
PybindOfflineTts(&m);
|
||||||
PybindSpeakerEmbeddingExtractor(&m);
|
PybindSpeakerEmbeddingExtractor(&m);
|
||||||
PybindSpeakerEmbeddingManager(&m);
|
PybindSpeakerEmbeddingManager(&m);
|
||||||
|
PybindSpokenLanguageIdentification(&m);
|
||||||
|
|
||||||
PybindAlsa(&m);
|
PybindAlsa(&m);
|
||||||
}
|
}
|
||||||
|
|||||||
60
sherpa-onnx/python/csrc/spoken-language-identification.cc
Normal file
60
sherpa-onnx/python/csrc/spoken-language-identification.cc
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
// sherpa-onnx/python/csrc/spoken-language-identification.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/spoken-language-identification.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/spoken-language-identification.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Exposes SpokenLanguageIdentificationWhisperConfig to Python: both
// constructors, the three fields, validate() and __str__.
static void PybindSpokenLanguageIdentificationWhisperConfig(py::module *m) {
  using PyClass = SpokenLanguageIdentificationWhisperConfig;

  py::class_<PyClass> cls(*m, "SpokenLanguageIdentificationWhisperConfig");

  cls.def(py::init<>());
  cls.def(py::init<const std::string &, const std::string &, int32_t>(),
          py::arg("encoder"), py::arg("decoder"),
          py::arg("tail_paddings") = -1);

  cls.def_readwrite("encoder", &PyClass::encoder);
  cls.def_readwrite("decoder", &PyClass::decoder);
  cls.def_readwrite("tail_paddings", &PyClass::tail_paddings);

  cls.def("validate", &PyClass::Validate);
  cls.def("__str__", &PyClass::ToString);
}
|
||||||
|
|
||||||
|
// Exposes SpokenLanguageIdentificationConfig to Python. The nested whisper
// config is bound first because it appears in the constructor signature and
// as a field of this class.
static void PybindSpokenLanguageIdentificationConfig(py::module *m) {
  PybindSpokenLanguageIdentificationWhisperConfig(m);

  using PyClass = SpokenLanguageIdentificationConfig;

  py::class_<PyClass>(*m, "SpokenLanguageIdentificationConfig")
      .def(py::init<>())
      // The argument types mirror the C++ constructor, which takes the
      // provider by const reference.
      .def(py::init<const SpokenLanguageIdentificationWhisperConfig &, int32_t,
                    bool, const std::string &>(),
           py::arg("whisper"), py::arg("num_threads") = 1,
           py::arg("debug") = false, py::arg("provider") = "cpu")
      .def_readwrite("whisper", &PyClass::whisper)
      .def_readwrite("num_threads", &PyClass::num_threads)
      .def_readwrite("debug", &PyClass::debug)
      .def_readwrite("provider", &PyClass::provider)
      .def("validate", &PyClass::Validate)
      .def("__str__", &PyClass::ToString);
}
|
||||||
|
|
||||||
|
// Exposes SpokenLanguageIdentification and its config classes to Python.
// The constructor, create_stream() and compute() all release the GIL while
// the C++ code runs.
void PybindSpokenLanguageIdentification(py::module *m) {
  PybindSpokenLanguageIdentificationConfig(m);

  using PyClass = SpokenLanguageIdentification;

  py::class_<PyClass> cls(*m, "SpokenLanguageIdentification");

  cls.def(py::init<const SpokenLanguageIdentificationConfig &>(),
          py::arg("config"), py::call_guard<py::gil_scoped_release>());
  cls.def("create_stream", &PyClass::CreateStream,
          py::call_guard<py::gil_scoped_release>());
  cls.def("compute", &PyClass::Compute,
          py::call_guard<py::gil_scoped_release>());
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
16
sherpa-onnx/python/csrc/spoken-language-identification.h
Normal file
16
sherpa-onnx/python/csrc/spoken-language-identification.h
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
// sherpa-onnx/python/csrc/spoken-language-identification.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
|
||||||
|
#define SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindSpokenLanguageIdentification(py::module *m);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
|
||||||
@@ -13,6 +13,9 @@ from _sherpa_onnx import (
|
|||||||
SpeakerEmbeddingExtractorConfig,
|
SpeakerEmbeddingExtractorConfig,
|
||||||
SpeakerEmbeddingManager,
|
SpeakerEmbeddingManager,
|
||||||
SpeechSegment,
|
SpeechSegment,
|
||||||
|
SpokenLanguageIdentification,
|
||||||
|
SpokenLanguageIdentificationConfig,
|
||||||
|
SpokenLanguageIdentificationWhisperConfig,
|
||||||
VadModel,
|
VadModel,
|
||||||
VadModelConfig,
|
VadModelConfig,
|
||||||
VoiceActivityDetector,
|
VoiceActivityDetector,
|
||||||
|
|||||||
Reference in New Issue
Block a user