Support distil-small.en whisper (#472)
This commit is contained in:
2
.github/scripts/test-offline-whisper.sh
vendored
2
.github/scripts/test-offline-whisper.sh
vendored
@@ -22,6 +22,8 @@ tiny
|
|||||||
base
|
base
|
||||||
small
|
small
|
||||||
medium
|
medium
|
||||||
|
distil-medium.en
|
||||||
|
distil-small.en
|
||||||
)
|
)
|
||||||
|
|
||||||
for name in ${names[@]}; do
|
for name in ${names[@]}; do
|
||||||
|
|||||||
59
.github/workflows/export-whisper-to-onnx.yaml
vendored
59
.github/workflows/export-whisper-to-onnx.yaml
vendored
@@ -15,8 +15,9 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [macos-latest]
|
||||||
model: ["distil-medium.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
|
# model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2", "distil-large-v2"]
|
||||||
|
model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium"]
|
||||||
python-version: ["3.8"]
|
python-version: ["3.8"]
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -42,23 +43,33 @@ jobs:
|
|||||||
if [[ $model == distil-medium.en ]]; then
|
if [[ $model == distil-medium.en ]]; then
|
||||||
wget -q -O distil-medium-en-original-model.bin https://huggingface.co/distil-whisper/distil-medium.en/resolve/main/original-model.bin
|
wget -q -O distil-medium-en-original-model.bin https://huggingface.co/distil-whisper/distil-medium.en/resolve/main/original-model.bin
|
||||||
ls -lh
|
ls -lh
|
||||||
|
elif [[ $model == distil-large-v2 ]]; then
|
||||||
|
wget -q -O distil-large-v2-original-model.bin https://huggingface.co/distil-whisper/distil-large-v2/resolve/main/original-model.bin
|
||||||
|
ls -lh
|
||||||
|
elif [[ $model == distil-small.en ]]; then
|
||||||
|
wget -q -O distil-small-en-original-model.bin https://huggingface.co/distil-whisper/distil-small.en/resolve/main/original-model.bin
|
||||||
|
ls -lh
|
||||||
fi
|
fi
|
||||||
python3 ./export-onnx.py --model ${{ matrix.model }}
|
python3 ./export-onnx.py --model ${{ matrix.model }}
|
||||||
# python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
|
# python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
|
||||||
|
|
||||||
ls -lh
|
ls -lh
|
||||||
|
|
||||||
if [[ $model != distil-medium.en ]]; then
|
ls -lh ~/.cache/whisper || true
|
||||||
ls -lh ~/.cache/whisper
|
ls -lh distil*original-model.bin || true
|
||||||
fi
|
rm -rf ~/.cache/whisper
|
||||||
|
rm -f distil*original-model.bin
|
||||||
|
|
||||||
src=sherpa-onnx-whisper-${{ matrix.model }}
|
src=sherpa-onnx-whisper-${{ matrix.model }}
|
||||||
|
|
||||||
mkdir $src
|
cd ..
|
||||||
cp *.onnx $src/
|
mv whisper $src
|
||||||
cp *tokens.txt $src
|
|
||||||
|
echo "------------------------------"
|
||||||
|
|
||||||
cd $src
|
cd $src
|
||||||
|
du -h -d1 .
|
||||||
|
ls -lh
|
||||||
mkdir -p test_wavs
|
mkdir -p test_wavs
|
||||||
cd test_wavs
|
cd test_wavs
|
||||||
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/0.wav
|
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/0.wav
|
||||||
@@ -66,21 +77,32 @@ jobs:
|
|||||||
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/8k.wav
|
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/8k.wav
|
||||||
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/trans.txt
|
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/trans.txt
|
||||||
cd ../..
|
cd ../..
|
||||||
mv $src ../..
|
mv $src ../
|
||||||
|
echo "pwd: $PWD"
|
||||||
|
|
||||||
cd ../..
|
cd ../
|
||||||
echo "--------------------"
|
echo "--------------------"
|
||||||
ls -lh
|
ls -lh
|
||||||
ls -lh $src
|
ls -lh $src
|
||||||
echo "--------------------"
|
echo "--------------------"
|
||||||
|
|
||||||
tar cjvf ./$src.tar.bz2 $src
|
if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
|
||||||
|
#tar cvjf - $src | split --bytes=1024MB - $src.tar.bz2.
|
||||||
|
tar cvjf $src.tar.bz2 $src
|
||||||
|
split -b 1G $src.tar.bz2 $src.tar.bz2.
|
||||||
|
rm $src.tar.bz2
|
||||||
|
# cat $src.tar.gz.* | tar xjf -
|
||||||
|
else
|
||||||
|
tar cvjf $src.tar.bz2 $src
|
||||||
|
fi
|
||||||
|
ls -lh
|
||||||
|
|
||||||
|
|
||||||
- name: Release
|
- name: Release
|
||||||
uses: svenstaro/upload-release-action@v2
|
uses: svenstaro/upload-release-action@v2
|
||||||
with:
|
with:
|
||||||
file_glob: true
|
file_glob: true
|
||||||
file: ./*.tar.bz2
|
file: ./*.tar*
|
||||||
overwrite: true
|
overwrite: true
|
||||||
repo_name: k2-fsa/sherpa-onnx
|
repo_name: k2-fsa/sherpa-onnx
|
||||||
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
|
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
|
||||||
@@ -99,14 +121,21 @@ jobs:
|
|||||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
|
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
|
||||||
rm -rf huggingface/*
|
rm -rf huggingface/*
|
||||||
|
|
||||||
cp -av $src/* ./huggingface/
|
if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
|
||||||
|
mv $src.tar* ./huggingface
|
||||||
|
else
|
||||||
|
cp -v $src/*.onnx ./huggingface
|
||||||
|
cp -v $src/*tokens* ./huggingface
|
||||||
|
cp -av $src/test_wavs ./huggingface
|
||||||
|
fi
|
||||||
|
|
||||||
cd huggingface
|
cd huggingface
|
||||||
|
|
||||||
git status
|
git status
|
||||||
ls -lh
|
ls -lh
|
||||||
git lfs track "*.onnx"
|
git lfs track "*gz*"
|
||||||
# git lfs track "*.ort"
|
git lfs track "*onnx*"
|
||||||
|
|
||||||
git add .
|
git add .
|
||||||
git commit -m "upload ${{ matrix.model }}"
|
git commit -m "upload ${{ matrix.model }}"
|
||||||
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main
|
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ jobs:
|
|||||||
./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/8k.wav
|
./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/8k.wav
|
||||||
|
|
||||||
- name: Start server for paraformer models
|
- name: Start server for paraformer models
|
||||||
if: matrix.model_type == 'paraformer'
|
if: matrix.model_type == 'paraformer' && matrix.os != 'windows-latest'
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-bilingual-zh-en
|
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-bilingual-zh-en
|
||||||
@@ -106,7 +106,7 @@ jobs:
|
|||||||
sleep 10
|
sleep 10
|
||||||
|
|
||||||
- name: Start client for paraformer models
|
- name: Start client for paraformer models
|
||||||
if: matrix.model_type == 'paraformer'
|
if: matrix.model_type == 'paraformer' && matrix.os != 'windows-latest'
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
|
python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
||||||
project(sherpa-onnx)
|
project(sherpa-onnx)
|
||||||
|
|
||||||
set(SHERPA_ONNX_VERSION "1.9.0")
|
set(SHERPA_ONNX_VERSION "1.9.1")
|
||||||
|
|
||||||
# Disable warning about
|
# Disable warning about
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ def get_args():
|
|||||||
"tiny", "tiny.en", "base", "base.en",
|
"tiny", "tiny.en", "base", "base.en",
|
||||||
"small", "small.en", "medium", "medium.en",
|
"small", "small.en", "medium", "medium.en",
|
||||||
"large", "large-v1", "large-v2",
|
"large", "large-v1", "large-v2",
|
||||||
"distil-medium.en",
|
"distil-medium.en", "distil-small.en", "distil-large-v2"
|
||||||
],
|
],
|
||||||
# fmt: on
|
# fmt: on
|
||||||
)
|
)
|
||||||
@@ -314,6 +314,32 @@ def main():
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
model = whisper.load_model(filename)
|
model = whisper.load_model(filename)
|
||||||
|
elif name == "distil-large-v2":
|
||||||
|
filename = "./distil-large-v2-original-model.bin"
|
||||||
|
if not Path(filename).is_file():
|
||||||
|
raise ValueError(
|
||||||
|
"""
|
||||||
|
Please go to https://huggingface.co/distil-whisper/distil-large-v2
|
||||||
|
to download original-model.bin
|
||||||
|
You can use the following command to do that:
|
||||||
|
|
||||||
|
wget -O distil-large-v2-original-model.bin https://huggingface.co/distil-whisper/distil-large-v2/resolve/main/original-model.bin
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
model = whisper.load_model(filename)
|
||||||
|
elif name == "distil-small.en":
|
||||||
|
filename = "./distil-small-en-original-model.bin"
|
||||||
|
if not Path(filename).is_file():
|
||||||
|
raise ValueError(
|
||||||
|
"""
|
||||||
|
Please go to https://huggingface.co/distil-whisper/distil-small.en
|
||||||
|
to download original-model.bin
|
||||||
|
You can use the following command to do that:
|
||||||
|
|
||||||
|
wget -O distil-small-en-original-model.bin https://huggingface.co/distil-whisper/distil-small.en/resolve/main/original-model.bin
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
model = whisper.load_model(filename)
|
||||||
else:
|
else:
|
||||||
model = whisper.load_model(name)
|
model = whisper.load_model(name)
|
||||||
print(model.dims)
|
print(model.dims)
|
||||||
|
|||||||
@@ -209,7 +209,7 @@ class OnnxModel:
|
|||||||
logits = logits.reshape(-1)
|
logits = logits.reshape(-1)
|
||||||
mask = torch.ones(logits.shape[0], dtype=torch.int64)
|
mask = torch.ones(logits.shape[0], dtype=torch.int64)
|
||||||
mask[self.all_language_tokens] = 0
|
mask[self.all_language_tokens] = 0
|
||||||
logits[mask] = float("-inf")
|
logits[mask != 0] = float("-inf")
|
||||||
lang_id = logits.argmax().item()
|
lang_id = logits.argmax().item()
|
||||||
print("detected language: ", self.id2lang[lang_id])
|
print("detected language: ", self.id2lang[lang_id])
|
||||||
return lang_id
|
return lang_id
|
||||||
@@ -263,7 +263,9 @@ def compute_features(filename: str) -> torch.Tensor:
|
|||||||
|
|
||||||
target = 3000
|
target = 3000
|
||||||
if mel.shape[0] > target:
|
if mel.shape[0] > target:
|
||||||
mel = mel[:target]
|
# -50 so that there are some zero tail paddings.
|
||||||
|
mel = mel[: target - 50]
|
||||||
|
mel = torch.nn.functional.pad(mel, (0, 0, 0, 50), "constant", 0)
|
||||||
|
|
||||||
# We don't need to pad it to 30 seconds now!
|
# We don't need to pad it to 30 seconds now!
|
||||||
# mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
|
# mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
|
||||||
|
|||||||
@@ -106,11 +106,12 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
|||||||
std::vector<float> f = s->GetFrames();
|
std::vector<float> f = s->GetFrames();
|
||||||
int32_t num_frames = f.size() / feat_dim;
|
int32_t num_frames = f.size() / feat_dim;
|
||||||
|
|
||||||
if (num_frames > max_num_frames) {
|
// we use 50 here so that there will be some zero tail paddings
|
||||||
|
if (num_frames >= max_num_frames - 50) {
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"Only waves less than 30 seconds are supported. We process only the "
|
"Only waves less than 30 seconds are supported. We process only the "
|
||||||
"first 30 seconds and discard the remaining data");
|
"first 30 seconds and discard the remaining data");
|
||||||
num_frames = max_num_frames;
|
num_frames = max_num_frames - 50;
|
||||||
}
|
}
|
||||||
|
|
||||||
NormalizeFeatures(f.data(), num_frames, feat_dim);
|
NormalizeFeatures(f.data(), num_frames, feat_dim);
|
||||||
@@ -140,7 +141,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
|||||||
Ort::Value mel = Ort::Value::CreateTensor<float>(
|
Ort::Value mel = Ort::Value::CreateTensor<float>(
|
||||||
model_->Allocator(), shape.data(), shape.size());
|
model_->Allocator(), shape.data(), shape.size());
|
||||||
float *p_mel = mel.GetTensorMutableData<float>();
|
float *p_mel = mel.GetTensorMutableData<float>();
|
||||||
std::copy(f.begin(), f.end(), p_mel);
|
std::copy(f.data(), f.data() + actual_frames * feat_dim, p_mel);
|
||||||
|
|
||||||
memset(p_mel + f.size(), 0,
|
memset(p_mel + f.size(), 0,
|
||||||
(actual_frames - num_frames) * feat_dim * sizeof(float));
|
(actual_frames - num_frames) * feat_dim * sizeof(float));
|
||||||
|
|||||||
Reference in New Issue
Block a user