diff --git a/.github/workflows/export-whisper-to-onnx.yaml b/.github/workflows/export-whisper-to-onnx.yaml
index 554e2065..e1592d84 100644
--- a/.github/workflows/export-whisper-to-onnx.yaml
+++ b/.github/workflows/export-whisper-to-onnx.yaml
@@ -16,32 +16,49 @@ jobs:
       fail-fast: false
       matrix:
         os: [macos-latest]
-        model: ["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
+        model: ["distil-medium.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
+        python-version: ["3.8"]

     steps:
       - uses: actions/checkout@v4

+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+
       - name: Install dependencies
         shell: bash
         run: |
-          python3 -m pip install openai-whisper torch onnxruntime onnx
+          python3 -m pip install torch==1.13.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html
+          python3 -m pip install openai-whisper==20230314 onnxruntime onnx

       - name: export ${{ matrix.model }}
         shell: bash
         run: |
           cd scripts/whisper
+          model=${{ matrix.model }}
+          echo "model: $model"
+          if [[ $model == distil-medium.en ]]; then
+            wget -q -O distil-medium-en-original-model.bin https://huggingface.co/distil-whisper/distil-medium.en/resolve/main/original-model.bin
+            ls -lh
+          fi
           python3 ./export-onnx.py --model ${{ matrix.model }}
           python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./

           ls -lh

-          ls -lh ~/.cache/whisper
+          if [[ $model != distil-medium.en ]]; then
+            ls -lh ~/.cache/whisper
+          fi

       - name: Publish ${{ matrix.model }} to huggingface
         shell: bash
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
         run: |
+          model=${{ matrix.model }}
+
           cd scripts/whisper

           git config --global user.email "csukuangfj@gmail.com"
@@ -54,6 +71,18 @@ jobs:
           cp *tokens.txt ./huggingface

           cd huggingface
+
+          if [[ $model == distil-medium.en ]]; then
+            mkdir test_wavs
+            cd test_wavs
+            wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/0.wav
+            wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/1.wav
+            wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/8k.wav
+            wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/trans.txt
+            git add .
+            cd ..
+          fi
+
           git status
           ls -lh
           git lfs track "*.onnx"
diff --git a/scripts/whisper/export-onnx.py b/scripts/whisper/export-onnx.py
index 46594d12..945d39cc 100755
--- a/scripts/whisper/export-onnx.py
+++ b/scripts/whisper/export-onnx.py
@@ -39,7 +39,9 @@ def get_args():
         choices=[
             "tiny", "tiny.en", "base", "base.en",
             "small", "small.en", "medium", "medium.en",
-            "large", "large-v1", "large-v2"],
+            "large", "large-v1", "large-v2",
+            "distil-medium.en",
+        ],
         # fmt: on
     )
     return parser.parse_args()
@@ -257,10 +259,27 @@ def convert_tokens(name, model):
 def main():
     args = get_args()
     name = args.model
+    print(args)
+    print(name)

     opset_version = 13

-    model = whisper.load_model(name)
+    if name == "distil-medium.en":
+        filename = "./distil-medium-en-original-model.bin"
+        if not Path(filename).is_file():
+            raise ValueError(
+                """
+                Please go to https://huggingface.co/distil-whisper/distil-medium.en
+                to download original-model.bin
+                You can use the following command to do that:
+
+                wget -O distil-medium-en-original-model.bin https://huggingface.co/distil-whisper/distil-medium.en/resolve/main/original-model.bin
+                """
+            )
+        model = whisper.load_model(filename)
+    else:
+        model = whisper.load_model(name)
+
     print(
         f"number of model parameters: {name}",
         sum(p.numel() for p in model.parameters()),
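
For anyone reviewing or reproducing the export step locally, below is a minimal sanity-check sketch, not part of this patch. It assumes export-onnx.py writes files named "<model>-encoder.onnx" and "<model>-decoder.onnx" into the current directory; the actual filenames come from the script, so adjust as needed. It verifies each exported graph passes the ONNX checker and that onnxruntime can create a session from it.

#!/usr/bin/env python3
# sanity_check_export.py -- illustration only, not part of this patch.
# Assumes export-onnx.py produced "<model>-encoder.onnx" and
# "<model>-decoder.onnx" in the current directory (hypothetical names;
# adjust to whatever the export script actually writes).
import sys

import onnx
import onnxruntime


def main():
    model = sys.argv[1] if len(sys.argv) > 1 else "distil-medium.en"
    for part in ["encoder", "decoder"]:
        filename = f"{model}-{part}.onnx"

        # Structural validation of the exported graph.
        onnx.checker.check_model(onnx.load(filename))

        # Confirm onnxruntime can build a CPU session from the file.
        session = onnxruntime.InferenceSession(
            filename, providers=["CPUExecutionProvider"]
        )
        print(filename, "inputs:", [i.name for i in session.get_inputs()])


if __name__ == "__main__":
    main()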