From ef8d112aaaa0509bba31f6bbb8eaf8e805c21613 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 20 Dec 2023 11:12:12 +0800 Subject: [PATCH] Fix whisper test script for the latest onnxruntime. (#494) --- .github/workflows/export-whisper-to-onnx.yaml | 15 ++++++++++++++- build-apk-two-pass.sh | 4 ++-- scripts/whisper/test.py | 2 ++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/.github/workflows/export-whisper-to-onnx.yaml b/.github/workflows/export-whisper-to-onnx.yaml index d77a6fe9..d0001861 100644 --- a/.github/workflows/export-whisper-to-onnx.yaml +++ b/.github/workflows/export-whisper-to-onnx.yaml @@ -31,7 +31,7 @@ jobs: - name: Install dependencies shell: bash run: | - python3 -m pip install torch==1.13.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html + python3 -m pip install torch==1.13.0 torchaudio==0.13.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html python3 -m pip install openai-whisper==20230314 onnxruntime onnx - name: export ${{ matrix.model }} @@ -108,6 +108,19 @@ jobs: repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} tag: asr-models + - name: Test ${{ matrix.model }} + shell: bash + run: | + python3 -m pip install kaldi-native-fbank + git checkout . + model=${{ matrix.model }} + src=sherpa-onnx-whisper-$model + python3 scripts/whisper/test.py \ + --encoder $src/$model-encoder.int8.onnx \ + --decoder $src/$model-decoder.int8.onnx \ + --tokens $src/$model-tokens.txt \ + $src/test_wavs/0.wav + - name: Publish ${{ matrix.model }} to huggingface shell: bash env: diff --git a/build-apk-two-pass.sh b/build-apk-two-pass.sh index e83c0706..20bd6d43 100755 --- a/build-apk-two-pass.sh +++ b/build-apk-two-pass.sh @@ -74,11 +74,11 @@ git lfs pull --include "*.onnx" # remove .git to save spaces rm -rf .git -rm README.md +rm -fv README.md rm -rf test_wavs rm .gitattributes -rm *.ort +rm -fv *.ort rm tiny.en-encoder.onnx rm tiny.en-decoder.onnx diff --git a/scripts/whisper/test.py b/scripts/whisper/test.py index 7acb57fe..5f3d02c2 100755 --- a/scripts/whisper/test.py +++ b/scripts/whisper/test.py @@ -82,6 +82,7 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) meta = self.encoder.get_modelmeta().custom_metadata_map @@ -113,6 +114,7 @@ class OnnxModel: self.decoder = ort.InferenceSession( decoder, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def run_encoder(