Support whisper large/large-v1/large-v2/large-v3 and distil-large-v2 (#1114)
This commit is contained in:
63
.github/workflows/export-whisper-to-onnx.yaml
vendored
63
.github/workflows/export-whisper-to-onnx.yaml
vendored
@@ -15,9 +15,9 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [macos-latest]
|
||||||
# model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2", "distil-large-v2"]
|
model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "medium-aishell", "large", "large-v1", "large-v2", "distil-large-v2"]
|
||||||
model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "medium-aishell"]
|
# model: ["large", "large-v1", "large-v2", "large-v3", "distil-large-v2"]
|
||||||
python-version: ["3.8"]
|
python-version: ["3.8"]
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -32,7 +32,7 @@ jobs:
|
|||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install torch==1.13.0 torchaudio==0.13.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html
|
python3 -m pip install torch==1.13.0 torchaudio==0.13.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||||
python3 -m pip install openai-whisper==20230314 onnxruntime onnx
|
python3 -m pip install openai-whisper==20231117 onnxruntime onnx soundfile librosa
|
||||||
|
|
||||||
- name: export ${{ matrix.model }}
|
- name: export ${{ matrix.model }}
|
||||||
shell: bash
|
shell: bash
|
||||||
@@ -62,7 +62,6 @@ jobs:
|
|||||||
rm -fv medium-aishell-decoder.onnx
|
rm -fv medium-aishell-decoder.onnx
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
ls -lh
|
ls -lh
|
||||||
|
|
||||||
ls -lh ~/.cache/whisper || true
|
ls -lh ~/.cache/whisper || true
|
||||||
@@ -74,7 +73,8 @@ jobs:
|
|||||||
src=sherpa-onnx-whisper-${{ matrix.model }}
|
src=sherpa-onnx-whisper-${{ matrix.model }}
|
||||||
|
|
||||||
cd ..
|
cd ..
|
||||||
mv whisper $src
|
mkdir $src
|
||||||
|
mv -v whisper/$model* $src/
|
||||||
|
|
||||||
echo "------------------------------"
|
echo "------------------------------"
|
||||||
|
|
||||||
@@ -97,19 +97,16 @@ jobs:
|
|||||||
ls -lh $src
|
ls -lh $src
|
||||||
echo "--------------------"
|
echo "--------------------"
|
||||||
|
|
||||||
if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
|
if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
|
||||||
#tar cvjf - $src | split --bytes=1024MB - $src.tar.bz2.
|
echo "Don't release model to github for large models. $model"
|
||||||
tar cvjf $src.tar.bz2 $src
|
|
||||||
split -b 1G $src.tar.bz2 $src.tar.bz2.
|
|
||||||
rm $src.tar.bz2
|
|
||||||
# cat $src.tar.gz.* | tar xjf -
|
|
||||||
else
|
else
|
||||||
tar cvjf $src.tar.bz2 $src
|
tar cvjf $src.tar.bz2 $src
|
||||||
fi
|
fi
|
||||||
|
|
||||||
ls -lh
|
ls -lh
|
||||||
|
|
||||||
|
|
||||||
- name: Release
|
- name: Release
|
||||||
|
if: matrix.model != 'large' && matrix.model != 'large-v1' && matrix.model != 'large-v2' && matrix.model != 'large-v3' && matrix.model != 'distil-large-v2'
|
||||||
uses: svenstaro/upload-release-action@v2
|
uses: svenstaro/upload-release-action@v2
|
||||||
with:
|
with:
|
||||||
file_glob: true
|
file_glob: true
|
||||||
@@ -119,19 +116,6 @@ jobs:
|
|||||||
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
|
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
|
||||||
tag: asr-models
|
tag: asr-models
|
||||||
|
|
||||||
- name: Test ${{ matrix.model }}
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
python3 -m pip install kaldi-native-fbank
|
|
||||||
git checkout .
|
|
||||||
model=${{ matrix.model }}
|
|
||||||
src=sherpa-onnx-whisper-$model
|
|
||||||
python3 scripts/whisper/test.py \
|
|
||||||
--encoder $src/$model-encoder.int8.onnx \
|
|
||||||
--decoder $src/$model-decoder.int8.onnx \
|
|
||||||
--tokens $src/$model-tokens.txt \
|
|
||||||
$src/test_wavs/0.wav
|
|
||||||
|
|
||||||
- name: Publish ${{ matrix.model }} to huggingface
|
- name: Publish ${{ matrix.model }} to huggingface
|
||||||
shell: bash
|
shell: bash
|
||||||
env:
|
env:
|
||||||
@@ -144,27 +128,36 @@ jobs:
|
|||||||
|
|
||||||
export GIT_CLONE_PROTECTION_ACTIVE=false
|
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||||
|
|
||||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
|
export GIT_LFS_SKIP_SMUDGE=1
|
||||||
|
|
||||||
|
git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
|
||||||
|
|
||||||
if [[ $model != medium-aishell ]]; then
|
if [[ $model != medium-aishell ]]; then
|
||||||
rm -rf huggingface/*
|
rm -rf huggingface/*
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
|
cp -av $src/* ./huggingface/
|
||||||
mv $src.tar* ./huggingface
|
|
||||||
else
|
|
||||||
cp -v $src/*.onnx ./huggingface
|
|
||||||
cp -v $src/*tokens* ./huggingface
|
|
||||||
cp -av $src/test_wavs ./huggingface
|
|
||||||
fi
|
|
||||||
|
|
||||||
cd huggingface
|
cd huggingface
|
||||||
|
|
||||||
git status
|
git status
|
||||||
ls -lh
|
ls -lh
|
||||||
git lfs track "*gz*"
|
|
||||||
git lfs track "*onnx*"
|
git lfs track "*onnx*"
|
||||||
|
git lfs track "*weights*"
|
||||||
|
|
||||||
git add .
|
git add .
|
||||||
git commit -m "upload ${{ matrix.model }}"
|
git commit -m "upload ${{ matrix.model }}"
|
||||||
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main
|
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main
|
||||||
|
|
||||||
|
- name: Test ${{ matrix.model }}
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
python3 -m pip install kaldi-native-fbank
|
||||||
|
git checkout .
|
||||||
|
model=${{ matrix.model }}
|
||||||
|
src=sherpa-onnx-whisper-$model
|
||||||
|
time python3 scripts/whisper/test.py \
|
||||||
|
--encoder $src/$model-encoder.onnx \
|
||||||
|
--decoder $src/$model-decoder.onnx \
|
||||||
|
--tokens $src/$model-tokens.txt \
|
||||||
|
$src/test_wavs/0.wav
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
## 1.10.14 (to-be-released)
|
## 1.10.14
|
||||||
|
|
||||||
|
* Support whisper large v3
|
||||||
* Update onnxruntime from v1.18.0 to v1.18.1
|
* Update onnxruntime from v1.18.0 to v1.18.1
|
||||||
* Fix invalid utf8 sequence from Whisper for Dart API.
|
* Fix invalid utf8 sequence from Whisper for Dart API.
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ project(sherpa-onnx)
|
|||||||
# ./nodejs-addon-examples
|
# ./nodejs-addon-examples
|
||||||
# ./dart-api-examples/
|
# ./dart-api-examples/
|
||||||
# ./CHANGELOG.md
|
# ./CHANGELOG.md
|
||||||
set(SHERPA_ONNX_VERSION "1.10.13")
|
set(SHERPA_ONNX_VERSION "1.10.14")
|
||||||
|
|
||||||
# Disable warning about
|
# Disable warning about
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
function(download_kaldi_native_fbank)
|
function(download_kaldi_native_fbank)
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
|
|
||||||
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.3.tar.gz")
|
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.20.0.tar.gz")
|
||||||
set(kaldi_native_fbank_URL2 "https://hub.nuaa.cf/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.3.tar.gz")
|
set(kaldi_native_fbank_URL2 "https://hub.nuaa.cf/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.20.0.tar.gz")
|
||||||
set(kaldi_native_fbank_HASH "SHA256=335fe1daf1b9bfb2a7b6bf03b64c4c4686c39077c57fb8058c02611981676638")
|
set(kaldi_native_fbank_HASH "SHA256=c6195b3cf374eef824644061d3c04f6b2a9267ae554169cbaa9865c89c1fe4f9")
|
||||||
|
|
||||||
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||||
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
||||||
@@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
|
|||||||
# If you don't have access to the Internet,
|
# If you don't have access to the Internet,
|
||||||
# please pre-download kaldi-native-fbank
|
# please pre-download kaldi-native-fbank
|
||||||
set(possible_file_locations
|
set(possible_file_locations
|
||||||
$ENV{HOME}/Downloads/kaldi-native-fbank-1.19.3.tar.gz
|
$ENV{HOME}/Downloads/kaldi-native-fbank-1.20.0.tar.gz
|
||||||
${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.19.3.tar.gz
|
${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.20.0.tar.gz
|
||||||
${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.19.3.tar.gz
|
${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.20.0.tar.gz
|
||||||
/tmp/kaldi-native-fbank-1.19.3.tar.gz
|
/tmp/kaldi-native-fbank-1.20.0.tar.gz
|
||||||
/star-fj/fangjun/download/github/kaldi-native-fbank-1.19.3.tar.gz
|
/star-fj/fangjun/download/github/kaldi-native-fbank-1.20.0.tar.gz
|
||||||
)
|
)
|
||||||
|
|
||||||
foreach(f IN LISTS possible_file_locations)
|
foreach(f IN LISTS possible_file_locations)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ environment:
|
|||||||
|
|
||||||
# Add regular dependencies here.
|
# Add regular dependencies here.
|
||||||
dependencies:
|
dependencies:
|
||||||
sherpa_onnx: ^1.10.13
|
sherpa_onnx: ^1.10.14
|
||||||
path: ^1.9.0
|
path: ^1.9.0
|
||||||
args: ^2.5.0
|
args: ^2.5.0
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ environment:
|
|||||||
|
|
||||||
# Add regular dependencies here.
|
# Add regular dependencies here.
|
||||||
dependencies:
|
dependencies:
|
||||||
sherpa_onnx: ^1.10.13
|
sherpa_onnx: ^1.10.14
|
||||||
path: ^1.9.0
|
path: ^1.9.0
|
||||||
args: ^2.5.0
|
args: ^2.5.0
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ environment:
|
|||||||
|
|
||||||
# Add regular dependencies here.
|
# Add regular dependencies here.
|
||||||
dependencies:
|
dependencies:
|
||||||
sherpa_onnx: ^1.10.13
|
sherpa_onnx: ^1.10.14
|
||||||
path: ^1.9.0
|
path: ^1.9.0
|
||||||
args: ^2.5.0
|
args: ^2.5.0
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ environment:
|
|||||||
sdk: ^3.4.0
|
sdk: ^3.4.0
|
||||||
|
|
||||||
dependencies:
|
dependencies:
|
||||||
sherpa_onnx: ^1.10.13
|
sherpa_onnx: ^1.10.14
|
||||||
path: ^1.9.0
|
path: ^1.9.0
|
||||||
args: ^2.5.0
|
args: ^2.5.0
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ description: >
|
|||||||
|
|
||||||
publish_to: 'none'
|
publish_to: 'none'
|
||||||
|
|
||||||
version: 1.10.13
|
version: 1.10.14
|
||||||
|
|
||||||
topics:
|
topics:
|
||||||
- speech-recognition
|
- speech-recognition
|
||||||
@@ -30,7 +30,7 @@ dependencies:
|
|||||||
record: ^5.1.0
|
record: ^5.1.0
|
||||||
url_launcher: ^6.2.6
|
url_launcher: ^6.2.6
|
||||||
|
|
||||||
sherpa_onnx: ^1.10.13
|
sherpa_onnx: ^1.10.14
|
||||||
# sherpa_onnx:
|
# sherpa_onnx:
|
||||||
# path: ../../flutter/sherpa_onnx
|
# path: ../../flutter/sherpa_onnx
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ dependencies:
|
|||||||
cupertino_icons: ^1.0.6
|
cupertino_icons: ^1.0.6
|
||||||
path_provider: ^2.1.3
|
path_provider: ^2.1.3
|
||||||
path: ^1.9.0
|
path: ^1.9.0
|
||||||
sherpa_onnx: ^1.10.13
|
sherpa_onnx: ^1.10.14
|
||||||
url_launcher: ^6.2.6
|
url_launcher: ^6.2.6
|
||||||
audioplayers: ^5.0.0
|
audioplayers: ^5.0.0
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ topics:
|
|||||||
- voice-activity-detection
|
- voice-activity-detection
|
||||||
|
|
||||||
# remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec
|
# remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec
|
||||||
version: 1.10.13
|
version: 1.10.14
|
||||||
|
|
||||||
homepage: https://github.com/k2-fsa/sherpa-onnx
|
homepage: https://github.com/k2-fsa/sherpa-onnx
|
||||||
|
|
||||||
@@ -30,19 +30,19 @@ dependencies:
|
|||||||
flutter:
|
flutter:
|
||||||
sdk: flutter
|
sdk: flutter
|
||||||
|
|
||||||
sherpa_onnx_android: ^1.10.13
|
sherpa_onnx_android: ^1.10.14
|
||||||
# path: ../sherpa_onnx_android
|
# path: ../sherpa_onnx_android
|
||||||
|
|
||||||
sherpa_onnx_macos: ^1.10.13
|
sherpa_onnx_macos: ^1.10.14
|
||||||
# path: ../sherpa_onnx_macos
|
# path: ../sherpa_onnx_macos
|
||||||
|
|
||||||
sherpa_onnx_linux: ^1.10.13
|
sherpa_onnx_linux: ^1.10.14
|
||||||
# path: ../sherpa_onnx_linux
|
# path: ../sherpa_onnx_linux
|
||||||
#
|
#
|
||||||
sherpa_onnx_windows: ^1.10.13
|
sherpa_onnx_windows: ^1.10.14
|
||||||
# path: ../sherpa_onnx_windows
|
# path: ../sherpa_onnx_windows
|
||||||
|
|
||||||
sherpa_onnx_ios: ^1.10.13
|
sherpa_onnx_ios: ^1.10.14
|
||||||
# sherpa_onnx_ios:
|
# sherpa_onnx_ios:
|
||||||
# path: ../sherpa_onnx_ios
|
# path: ../sherpa_onnx_ios
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
# https://groups.google.com/g/dart-ffi/c/nUATMBy7r0c
|
# https://groups.google.com/g/dart-ffi/c/nUATMBy7r0c
|
||||||
Pod::Spec.new do |s|
|
Pod::Spec.new do |s|
|
||||||
s.name = 'sherpa_onnx_ios'
|
s.name = 'sherpa_onnx_ios'
|
||||||
s.version = '1.10.13'
|
s.version = '1.10.14'
|
||||||
s.summary = 'A new Flutter FFI plugin project.'
|
s.summary = 'A new Flutter FFI plugin project.'
|
||||||
s.description = <<-DESC
|
s.description = <<-DESC
|
||||||
A new Flutter FFI plugin project.
|
A new Flutter FFI plugin project.
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
#
|
#
|
||||||
Pod::Spec.new do |s|
|
Pod::Spec.new do |s|
|
||||||
s.name = 'sherpa_onnx_macos'
|
s.name = 'sherpa_onnx_macos'
|
||||||
s.version = '1.10.13'
|
s.version = '1.10.14'
|
||||||
s.summary = 'sherpa-onnx Flutter FFI plugin project.'
|
s.summary = 'sherpa-onnx Flutter FFI plugin project.'
|
||||||
s.description = <<-DESC
|
s.description = <<-DESC
|
||||||
sherpa-onnx Flutter FFI plugin project.
|
sherpa-onnx Flutter FFI plugin project.
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"sherpa-onnx-node": "^1.10.13"
|
"sherpa-onnx-node": "^1.10.14"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ topics:
|
|||||||
- voice-activity-detection
|
- voice-activity-detection
|
||||||
|
|
||||||
# remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx.podspec
|
# remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx.podspec
|
||||||
version: 1.10.13
|
version: 1.10.14
|
||||||
|
|
||||||
homepage: https://github.com/k2-fsa/sherpa-onnx
|
homepage: https://github.com/k2-fsa/sherpa-onnx
|
||||||
|
|
||||||
|
|||||||
6
scripts/whisper/.gitignore
vendored
6
scripts/whisper/.gitignore
vendored
@@ -2,3 +2,9 @@
|
|||||||
*.config
|
*.config
|
||||||
*.ort
|
*.ort
|
||||||
*-tokens.txt
|
*-tokens.txt
|
||||||
|
*.bias
|
||||||
|
*.weights
|
||||||
|
*.weight
|
||||||
|
*.*embedding
|
||||||
|
_Const*
|
||||||
|
onnx__*
|
||||||
|
|||||||
@@ -32,6 +32,9 @@ from whisper.model import (
|
|||||||
TextDecoder,
|
TextDecoder,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@@ -43,8 +46,9 @@ def get_args():
|
|||||||
choices=[
|
choices=[
|
||||||
"tiny", "tiny.en", "base", "base.en",
|
"tiny", "tiny.en", "base", "base.en",
|
||||||
"small", "small.en", "medium", "medium.en",
|
"small", "small.en", "medium", "medium.en",
|
||||||
"large", "large-v1", "large-v2",
|
"large", "large-v1", "large-v2", "large-v3",
|
||||||
"distil-medium.en", "distil-small.en", "distil-large-v2",
|
"distil-medium.en", "distil-small.en", "distil-large-v2",
|
||||||
|
# "distil-large-v3", # distil-large-v3 is not supported!
|
||||||
# for fine-tuned models from icefall
|
# for fine-tuned models from icefall
|
||||||
"medium-aishell",
|
"medium-aishell",
|
||||||
],
|
],
|
||||||
@@ -63,12 +67,26 @@ def add_meta_data(filename: str, meta_data: Dict[str, Any]):
|
|||||||
Key-value pairs.
|
Key-value pairs.
|
||||||
"""
|
"""
|
||||||
model = onnx.load(filename)
|
model = onnx.load(filename)
|
||||||
|
|
||||||
|
while len(model.metadata_props):
|
||||||
|
model.metadata_props.pop()
|
||||||
|
|
||||||
for key, value in meta_data.items():
|
for key, value in meta_data.items():
|
||||||
meta = model.metadata_props.add()
|
meta = model.metadata_props.add()
|
||||||
meta.key = key
|
meta.key = key
|
||||||
meta.value = str(value)
|
meta.value = str(value)
|
||||||
|
|
||||||
onnx.save(model, filename)
|
if "large" in filename:
|
||||||
|
external_filename = filename.split(".onnx")[0]
|
||||||
|
onnx.save(
|
||||||
|
model,
|
||||||
|
filename,
|
||||||
|
save_as_external_data=True,
|
||||||
|
all_tensors_to_one_file=True,
|
||||||
|
location=external_filename + ".weights",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
onnx.save(model, filename)
|
||||||
|
|
||||||
|
|
||||||
def modified_audio_encoder_forward(self: AudioEncoder, x: torch.Tensor):
|
def modified_audio_encoder_forward(self: AudioEncoder, x: torch.Tensor):
|
||||||
@@ -376,7 +394,9 @@ def main():
|
|||||||
|
|
||||||
# write tokens
|
# write tokens
|
||||||
|
|
||||||
tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
|
tokenizer = whisper.tokenizer.get_tokenizer(
|
||||||
|
model.is_multilingual, num_languages=model.num_languages
|
||||||
|
)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
print(model.dims)
|
print(model.dims)
|
||||||
@@ -384,10 +404,15 @@ def main():
|
|||||||
audio = whisper.pad_or_trim(audio)
|
audio = whisper.pad_or_trim(audio)
|
||||||
assert audio.shape == (16000 * 30,), audio.shape
|
assert audio.shape == (16000 * 30,), audio.shape
|
||||||
|
|
||||||
# make log-Mel spectrogram and move to the same device as the model
|
if args.model in ("large", "large-v3"):
|
||||||
mel = whisper.log_mel_spectrogram(audio).to(model.device).unsqueeze(0)
|
n_mels = 128
|
||||||
|
else:
|
||||||
|
n_mels = 80
|
||||||
|
mel = (
|
||||||
|
whisper.log_mel_spectrogram(audio, n_mels=n_mels).to(model.device).unsqueeze(0)
|
||||||
|
)
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
assert mel.shape == (batch_size, 80, 30 * 100)
|
assert mel.shape == (batch_size, n_mels, 30 * 100), mel.shape
|
||||||
|
|
||||||
encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
|
encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
|
||||||
|
|
||||||
@@ -546,6 +571,17 @@ def main():
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "large" in args.model:
|
||||||
|
decoder_external_filename = decoder_filename.split(".onnx")[0]
|
||||||
|
decoder_model = onnx.load(decoder_filename)
|
||||||
|
onnx.save(
|
||||||
|
decoder_model,
|
||||||
|
decoder_filename,
|
||||||
|
save_as_external_data=True,
|
||||||
|
all_tensors_to_one_file=True,
|
||||||
|
location=decoder_external_filename + ".weights",
|
||||||
|
)
|
||||||
|
|
||||||
if "large" in args.model:
|
if "large" in args.model:
|
||||||
# it causes errors for large models, so skip it.
|
# it causes errors for large models, so skip it.
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -9,9 +9,10 @@ import base64
|
|||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import kaldi_native_fbank as knf
|
import kaldi_native_fbank as knf
|
||||||
|
import numpy as np
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
|
import soundfile as sf
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
@@ -98,7 +99,6 @@ class OnnxModel:
|
|||||||
self.blank = int(meta["blank_id"])
|
self.blank = int(meta["blank_id"])
|
||||||
|
|
||||||
self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
|
self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
|
||||||
|
|
||||||
self.sot_sequence.append(self.no_timestamps)
|
self.sot_sequence.append(self.no_timestamps)
|
||||||
|
|
||||||
self.all_language_tokens = list(
|
self.all_language_tokens = list(
|
||||||
@@ -226,7 +226,18 @@ def load_tokens(filename):
|
|||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
def compute_features(filename: str) -> torch.Tensor:
|
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
|
||||||
|
data, sample_rate = sf.read(
|
||||||
|
filename,
|
||||||
|
always_2d=True,
|
||||||
|
dtype="float32",
|
||||||
|
)
|
||||||
|
data = data[:, 0] # use only the first channel
|
||||||
|
samples = np.ascontiguousarray(data)
|
||||||
|
return samples, sample_rate
|
||||||
|
|
||||||
|
|
||||||
|
def compute_features(filename: str, dim: int = 80) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
filename:
|
filename:
|
||||||
@@ -234,16 +245,18 @@ def compute_features(filename: str) -> torch.Tensor:
|
|||||||
Returns:
|
Returns:
|
||||||
Return a 1-D float32 tensor of shape (1, 80, 3000) containing the features.
|
Return a 1-D float32 tensor of shape (1, 80, 3000) containing the features.
|
||||||
"""
|
"""
|
||||||
wave, sample_rate = torchaudio.load(filename)
|
wave, sample_rate = load_audio(filename)
|
||||||
audio = wave[0].contiguous() # only use the first channel
|
|
||||||
if sample_rate != 16000:
|
if sample_rate != 16000:
|
||||||
audio = torchaudio.functional.resample(
|
import librosa
|
||||||
audio, orig_freq=sample_rate, new_freq=16000
|
|
||||||
)
|
wave = librosa.resample(wave, orig_sr=sample_rate, target_sr=16000)
|
||||||
|
sample_rate = 16000
|
||||||
|
|
||||||
features = []
|
features = []
|
||||||
online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
|
opts = knf.WhisperFeatureOptions()
|
||||||
online_whisper_fbank.accept_waveform(16000, audio.numpy())
|
opts.dim = dim
|
||||||
|
online_whisper_fbank = knf.OnlineWhisperFbank(opts)
|
||||||
|
online_whisper_fbank.accept_waveform(16000, wave)
|
||||||
online_whisper_fbank.input_finished()
|
online_whisper_fbank.input_finished()
|
||||||
for i in range(online_whisper_fbank.num_frames_ready):
|
for i in range(online_whisper_fbank.num_frames_ready):
|
||||||
f = online_whisper_fbank.get_frame(i)
|
f = online_whisper_fbank.get_frame(i)
|
||||||
@@ -280,8 +293,9 @@ def compute_features(filename: str) -> torch.Tensor:
|
|||||||
def main():
|
def main():
|
||||||
args = get_args()
|
args = get_args()
|
||||||
|
|
||||||
mel = compute_features(args.sound_file)
|
|
||||||
model = OnnxModel(args.encoder, args.decoder)
|
model = OnnxModel(args.encoder, args.decoder)
|
||||||
|
dim = 80 if "large-v3" not in args.encoder else 128
|
||||||
|
mel = compute_features(args.sound_file, dim=dim)
|
||||||
|
|
||||||
n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
|
n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
|
||||||
|
|
||||||
@@ -313,6 +327,7 @@ def main():
|
|||||||
|
|
||||||
n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()
|
n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()
|
||||||
|
|
||||||
|
print(model.sot_sequence)
|
||||||
tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
|
tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
|
||||||
offset = torch.zeros(1, dtype=torch.int64)
|
offset = torch.zeros(1, dtype=torch.int64)
|
||||||
logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
|
logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
|
||||||
|
|||||||
@@ -88,7 +88,9 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||||
return std::make_unique<OfflineStream>(WhisperTag{});
|
WhisperTag tag;
|
||||||
|
tag.dim = model_->FeatureDim();
|
||||||
|
return std::make_unique<OfflineStream>(tag);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
|
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
|
||||||
|
|||||||
@@ -97,12 +97,16 @@ class OfflineStream::Impl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
explicit Impl(WhisperTag /*tag*/) {
|
explicit Impl(WhisperTag tag) {
|
||||||
config_.normalize_samples = true;
|
config_.normalize_samples = true;
|
||||||
opts_.frame_opts.samp_freq = 16000;
|
opts_.frame_opts.samp_freq = 16000;
|
||||||
opts_.mel_opts.num_bins = 80; // not used
|
opts_.mel_opts.num_bins = tag.dim;
|
||||||
whisper_fbank_ =
|
|
||||||
std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
|
knf::WhisperFeatureOptions whisper_opts;
|
||||||
|
whisper_opts.frame_opts = opts_.frame_opts;
|
||||||
|
whisper_opts.dim = tag.dim;
|
||||||
|
|
||||||
|
whisper_fbank_ = std::make_unique<knf::OnlineWhisperFbank>(whisper_opts);
|
||||||
config_.sampling_rate = opts_.frame_opts.samp_freq;
|
config_.sampling_rate = opts_.frame_opts.samp_freq;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,10 @@ struct OfflineRecognitionResult {
|
|||||||
std::string AsJsonString() const;
|
std::string AsJsonString() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct WhisperTag {};
|
struct WhisperTag {
|
||||||
|
int32_t dim = 80;
|
||||||
|
};
|
||||||
|
|
||||||
struct CEDTag {};
|
struct CEDTag {};
|
||||||
|
|
||||||
class OfflineStream {
|
class OfflineStream {
|
||||||
|
|||||||
@@ -217,6 +217,8 @@ class OfflineWhisperModel::Impl {
|
|||||||
|
|
||||||
int32_t VocabSize() const { return n_vocab_; }
|
int32_t VocabSize() const { return n_vocab_; }
|
||||||
|
|
||||||
|
int32_t FeatureDim() const { return n_mels_; }
|
||||||
|
|
||||||
int32_t Translate() const { return translate_; }
|
int32_t Translate() const { return translate_; }
|
||||||
|
|
||||||
bool IsMultiLingual() const { return is_multilingual_; }
|
bool IsMultiLingual() const { return is_multilingual_; }
|
||||||
@@ -242,6 +244,7 @@ class OfflineWhisperModel::Impl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||||
|
SHERPA_ONNX_READ_META_DATA(n_mels_, "n_mels");
|
||||||
SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer");
|
SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer");
|
||||||
SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx");
|
SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx");
|
||||||
SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state");
|
SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state");
|
||||||
@@ -316,6 +319,7 @@ class OfflineWhisperModel::Impl {
|
|||||||
std::unordered_map<int32_t, std::string> id2lang_;
|
std::unordered_map<int32_t, std::string> id2lang_;
|
||||||
|
|
||||||
// model meta data
|
// model meta data
|
||||||
|
int32_t n_mels_ = 80;
|
||||||
int32_t n_text_layer_ = 0;
|
int32_t n_text_layer_ = 0;
|
||||||
int32_t n_text_ctx_ = 0;
|
int32_t n_text_ctx_ = 0;
|
||||||
int32_t n_text_state_ = 0;
|
int32_t n_text_state_ = 0;
|
||||||
@@ -414,6 +418,8 @@ int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); }
|
|||||||
|
|
||||||
int32_t OfflineWhisperModel::VocabSize() const { return impl_->VocabSize(); }
|
int32_t OfflineWhisperModel::VocabSize() const { return impl_->VocabSize(); }
|
||||||
|
|
||||||
|
int32_t OfflineWhisperModel::FeatureDim() const { return impl_->FeatureDim(); }
|
||||||
|
|
||||||
int32_t OfflineWhisperModel::Translate() const { return impl_->Translate(); }
|
int32_t OfflineWhisperModel::Translate() const { return impl_->Translate(); }
|
||||||
|
|
||||||
bool OfflineWhisperModel::IsMultiLingual() const {
|
bool OfflineWhisperModel::IsMultiLingual() const {
|
||||||
|
|||||||
@@ -102,6 +102,7 @@ class OfflineWhisperModel {
|
|||||||
int32_t SOT() const;
|
int32_t SOT() const;
|
||||||
int32_t TextCtx() const;
|
int32_t TextCtx() const;
|
||||||
int32_t VocabSize() const;
|
int32_t VocabSize() const;
|
||||||
|
int32_t FeatureDim() const;
|
||||||
int32_t Translate() const;
|
int32_t Translate() const;
|
||||||
bool IsMultiLingual() const;
|
bool IsMultiLingual() const;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user