Add CI test for Whisper models (#239)
This commit is contained in:
85
.github/scripts/test-offline-whisper.sh
vendored
Executable file
85
.github/scripts/test-offline-whisper.sh
vendored
Executable file
@@ -0,0 +1,85 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
# Timestamped logger: prints "<date> (<file>:<line>:<caller>) <message>".
log() {
  # This function is from espnet
  local caller stamp
  caller=${BASH_SOURCE[1]##*/}
  stamp=$(date '+%Y-%m-%d %H:%M:%S')
  echo -e "${stamp} (${caller}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
|
||||
|
||||
echo "EXE is $EXE"
echo "PATH: $PATH"

# Fail early if the test binary is not on PATH.
command -v "$EXE"

# Whisper models to test.  The larger ones are commented out to keep
# the CI run time reasonable.
names=(
  tiny.en
  base.en
  # small.en
  # medium.en
)

# run_decode SUFFIX
#
# Decode the three test waves bundled with the model repo using the
# encoder/decoder files whose names end in SUFFIX (one of: onnx,
# int8.onnx, ort, int8.ort).  Reads $name and $repo set by the
# enclosing loop.
run_decode() {
  local suffix=$1
  time "$EXE" \
    --tokens="$repo/${name}-tokens.txt" \
    --whisper-encoder="$repo/${name}-encoder.${suffix}" \
    --whisper-decoder="$repo/${name}-decoder.${suffix}" \
    --num-threads=2 \
    "$repo/test_wavs/0.wav" \
    "$repo/test_wavs/1.wav" \
    "$repo/test_wavs/8k.wav"
}

for name in "${names[@]}"; do
  log "------------------------------------------------------------"
  log "Run $name"
  log "------------------------------------------------------------"

  repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-whisper-$name
  log "Start testing ${repo_url}"
  repo=$(basename "$repo_url")
  log "Download pretrained model and test-data from $repo_url"

  # Clone without LFS blobs, then fetch only the model files we need.
  GIT_LFS_SKIP_SMUDGE=1 git clone "$repo_url"
  pushd "$repo"
  git lfs pull --include "*.onnx"
  git lfs pull --include "*.ort"
  ls -lh -- *.{onnx,ort}
  popd

  # Exercise every exported model format: fp32/int8 x onnx/ort.
  for suffix in onnx int8.onnx ort int8.ort; do
    if [[ $suffix == int8.* ]]; then
      log "test int8 ${suffix#int8.}"
    else
      log "test fp32 ${suffix}"
    fi
    run_decode "$suffix"
  done

  # Free disk space before the next model.
  rm -rf "$repo"
done
|
||||
8
.github/workflows/linux.yaml
vendored
8
.github/workflows/linux.yaml
vendored
@@ -84,6 +84,14 @@ jobs:
|
||||
file build/bin/sherpa-onnx
|
||||
readelf -d build/bin/sherpa-onnx
|
||||
|
||||
- name: Test offline Whisper
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline
|
||||
|
||||
.github/scripts/test-offline-whisper.sh
|
||||
|
||||
- name: Test offline CTC
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
8
.github/workflows/macos.yaml
vendored
8
.github/workflows/macos.yaml
vendored
@@ -82,6 +82,14 @@ jobs:
|
||||
otool -L build/bin/sherpa-onnx
|
||||
otool -l build/bin/sherpa-onnx
|
||||
|
||||
- name: Test offline Whisper
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline
|
||||
|
||||
.github/scripts/test-offline-whisper.sh
|
||||
|
||||
- name: Test offline CTC
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
8
.github/workflows/windows-x64-cuda.yaml
vendored
8
.github/workflows/windows-x64-cuda.yaml
vendored
@@ -74,6 +74,14 @@ jobs:
|
||||
|
||||
ls -lh ./bin/Release/sherpa-onnx.exe
|
||||
|
||||
- name: Test offline Whisper for windows x64
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin/Release:$PATH
|
||||
export EXE=sherpa-onnx-offline.exe
|
||||
|
||||
.github/scripts/test-offline-whisper.sh
|
||||
|
||||
- name: Test offline CTC for windows x64
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
8
.github/workflows/windows-x64.yaml
vendored
8
.github/workflows/windows-x64.yaml
vendored
@@ -75,6 +75,14 @@ jobs:
|
||||
|
||||
ls -lh ./bin/Release/sherpa-onnx.exe
|
||||
|
||||
- name: Test offline Whisper for windows x64
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin/Release:$PATH
|
||||
export EXE=sherpa-onnx-offline.exe
|
||||
|
||||
.github/scripts/test-offline-whisper.sh
|
||||
|
||||
- name: Test offline CTC for windows x64
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
8
.github/workflows/windows-x86.yaml
vendored
8
.github/workflows/windows-x86.yaml
vendored
@@ -73,6 +73,14 @@ jobs:
|
||||
|
||||
ls -lh ./bin/Release/sherpa-onnx.exe
|
||||
|
||||
- name: Test offline Whisper for windows x86
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin/Release:$PATH
|
||||
export EXE=sherpa-onnx-offline.exe
|
||||
|
||||
.github/scripts/test-offline-whisper.sh
|
||||
|
||||
- name: Test offline CTC for windows x86
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
@@ -5,5 +5,9 @@ and use onnxruntime to replace PyTorch for speech recognition.
|
||||
|
||||
You can use [sherpa-onnx][sherpa-onnx] to run the converted model.
|
||||
|
||||
Please see
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/export-onnx.html
|
||||
for details.
|
||||
|
||||
[whisper]: https://github.com/openai/whisper
|
||||
[sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
|
||||
|
||||
@@ -18,15 +18,30 @@ import argparse
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
"--encoder",
|
||||
type=str,
|
||||
required=True,
|
||||
# fmt: off
|
||||
choices=[
|
||||
"tiny", "tiny.en", "base", "base.en",
|
||||
"small", "small.en", "medium", "medium.en",
|
||||
"large", "large-v1", "large-v2"],
|
||||
# fmt: on
|
||||
help="Path to the encoder",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the tokens",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_file",
|
||||
type=str,
|
||||
help="Path to the test wave",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -161,11 +176,10 @@ def load_tokens(filename):
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
name = args.model
|
||||
encoder = args.encoder
|
||||
decoder = args.decoder
|
||||
|
||||
encoder = f"./{name}-encoder.onnx"
|
||||
decoder = f"./{name}-decoder.onnx"
|
||||
audio = whisper.load_audio("0.wav")
|
||||
audio = whisper.load_audio(args.sound_file)
|
||||
|
||||
features = []
|
||||
online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
|
||||
@@ -224,17 +238,13 @@ def main():
|
||||
logits = logits[0, -1]
|
||||
model.suppress_tokens(logits, is_initial=False)
|
||||
max_token_id = logits.argmax(dim=-1)
|
||||
token_table = load_tokens(f"./{name}-tokens.txt")
|
||||
token_table = load_tokens(args.tokens)
|
||||
s = b""
|
||||
for i in results:
|
||||
if i in token_table:
|
||||
s += base64.b64decode(token_table[i])
|
||||
else:
|
||||
print("oov", i)
|
||||
|
||||
print(s.decode().strip())
|
||||
print(results)
|
||||
print(model.sot_sequence)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user