Remove the 30-second constraint from whisper. (#471)
This commit is contained in:
36
.github/scripts/test-offline-whisper.sh
vendored
36
.github/scripts/test-offline-whisper.sh
vendored
@@ -16,8 +16,12 @@ which $EXE
|
||||
names=(
|
||||
tiny.en
|
||||
base.en
|
||||
# small.en
|
||||
# medium.en
|
||||
small.en
|
||||
medium.en
|
||||
tiny
|
||||
base
|
||||
small
|
||||
medium
|
||||
)
|
||||
|
||||
for name in ${names[@]}; do
|
||||
@@ -33,8 +37,8 @@ for name in ${names[@]}; do
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
pushd $repo
|
||||
git lfs pull --include "*.onnx"
|
||||
git lfs pull --include "*.ort"
|
||||
ls -lh *.{onnx,ort}
|
||||
# git lfs pull --include "*.ort"
|
||||
ls -lh *.onnx
|
||||
popd
|
||||
|
||||
log "test fp32 onnx"
|
||||
@@ -43,6 +47,7 @@ for name in ${names[@]}; do
|
||||
--tokens=$repo/${name}-tokens.txt \
|
||||
--whisper-encoder=$repo/${name}-encoder.onnx \
|
||||
--whisper-decoder=$repo/${name}-decoder.onnx \
|
||||
--whisper-tail-paddings=500 \
|
||||
--num-threads=2 \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
@@ -54,28 +59,7 @@ for name in ${names[@]}; do
|
||||
--tokens=$repo/${name}-tokens.txt \
|
||||
--whisper-encoder=$repo/${name}-encoder.int8.onnx \
|
||||
--whisper-decoder=$repo/${name}-decoder.int8.onnx \
|
||||
--num-threads=2 \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
log "test fp32 ort"
|
||||
|
||||
time $EXE \
|
||||
--tokens=$repo/${name}-tokens.txt \
|
||||
--whisper-encoder=$repo/${name}-encoder.ort \
|
||||
--whisper-decoder=$repo/${name}-decoder.ort \
|
||||
--num-threads=2 \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
log "test int8 ort"
|
||||
|
||||
time $EXE \
|
||||
--tokens=$repo/${name}-tokens.txt \
|
||||
--whisper-encoder=$repo/${name}-encoder.int8.ort \
|
||||
--whisper-decoder=$repo/${name}-decoder.int8.ort \
|
||||
--whisper-tail-paddings=500 \
|
||||
--num-threads=2 \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
|
||||
60
.github/workflows/export-whisper-to-onnx.yaml
vendored
60
.github/workflows/export-whisper-to-onnx.yaml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos-latest]
|
||||
os: [ubuntu-latest]
|
||||
model: ["distil-medium.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
|
||||
python-version: ["3.8"]
|
||||
|
||||
@@ -44,7 +44,7 @@ jobs:
|
||||
ls -lh
|
||||
fi
|
||||
python3 ./export-onnx.py --model ${{ matrix.model }}
|
||||
python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
|
||||
# python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
|
||||
|
||||
ls -lh
|
||||
|
||||
@@ -52,41 +52,61 @@ jobs:
|
||||
ls -lh ~/.cache/whisper
|
||||
fi
|
||||
|
||||
src=sherpa-onnx-whisper-${{ matrix.model }}
|
||||
|
||||
mkdir $src
|
||||
cp *.onnx $src/
|
||||
cp *tokens.txt $src
|
||||
|
||||
cd $src
|
||||
mkdir -p test_wavs
|
||||
cd test_wavs
|
||||
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/0.wav
|
||||
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/1.wav
|
||||
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/8k.wav
|
||||
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/trans.txt
|
||||
cd ../..
|
||||
mv $src ../..
|
||||
|
||||
cd ../..
|
||||
echo "--------------------"
|
||||
ls -lh
|
||||
ls -lh $src
|
||||
echo "--------------------"
|
||||
|
||||
tar cjvf ./$src.tar.bz2 $src
|
||||
|
||||
- name: Release
|
||||
uses: svenstaro/upload-release-action@v2
|
||||
with:
|
||||
file_glob: true
|
||||
file: ./*.tar.bz2
|
||||
overwrite: true
|
||||
repo_name: k2-fsa/sherpa-onnx
|
||||
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
|
||||
tag: asr-models
|
||||
|
||||
- name: Publish ${{ matrix.model }} to huggingface
|
||||
shell: bash
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: |
|
||||
model=${{ matrix.model }}
|
||||
|
||||
cd scripts/whisper
|
||||
src=sherpa-onnx-whisper-${{ matrix.model }}
|
||||
|
||||
git config --global user.email "csukuangfj@gmail.com"
|
||||
git config --global user.name "Fangjun Kuang"
|
||||
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
|
||||
rm -rf huggingface/*
|
||||
|
||||
cp *.onnx ./huggingface
|
||||
cp *.ort ./huggingface
|
||||
cp *tokens.txt ./huggingface
|
||||
cp -av $src/* ./huggingface/
|
||||
|
||||
cd huggingface
|
||||
|
||||
if [[ $model == distil-medium.en ]]; then
|
||||
mkdir test_wavs
|
||||
cd test_wavs
|
||||
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/0.wav
|
||||
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/1.wav
|
||||
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/8k.wav
|
||||
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/trans.txt
|
||||
git add .
|
||||
cd ..
|
||||
fi
|
||||
|
||||
git status
|
||||
ls -lh
|
||||
git lfs track "*.onnx"
|
||||
git lfs track "*.ort"
|
||||
# git lfs track "*.ort"
|
||||
git add .
|
||||
git commit -m "upload ${{ matrix.model }}"
|
||||
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main
|
||||
|
||||
20
.github/workflows/linux.yaml
vendored
20
.github/workflows/linux.yaml
vendored
@@ -107,6 +107,16 @@ jobs:
|
||||
name: release-static
|
||||
path: build/bin/*
|
||||
|
||||
- name: Test offline Whisper
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline
|
||||
|
||||
readelf -d build/bin/sherpa-onnx-offline
|
||||
|
||||
.github/scripts/test-offline-whisper.sh
|
||||
|
||||
- name: Test online CTC
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -139,16 +149,6 @@ jobs:
|
||||
|
||||
.github/scripts/test-online-paraformer.sh
|
||||
|
||||
- name: Test offline Whisper
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline
|
||||
|
||||
readelf -d build/bin/sherpa-onnx-offline
|
||||
|
||||
.github/scripts/test-offline-whisper.sh
|
||||
|
||||
- name: Test offline transducer
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
14
.github/workflows/windows-x86.yaml
vendored
14
.github/workflows/windows-x86.yaml
vendored
@@ -93,13 +93,13 @@ jobs:
|
||||
|
||||
.github/scripts/test-online-paraformer.sh
|
||||
|
||||
- name: Test offline Whisper for windows x86
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin/Release:$PATH
|
||||
export EXE=sherpa-onnx-offline.exe
|
||||
|
||||
.github/scripts/test-offline-whisper.sh
|
||||
# - name: Test offline Whisper for windows x86
|
||||
# shell: bash
|
||||
# run: |
|
||||
# export PATH=$PWD/build/bin/Release:$PATH
|
||||
# export EXE=sherpa-onnx-offline.exe
|
||||
#
|
||||
# .github/scripts/test-offline-whisper.sh
|
||||
|
||||
- name: Test offline CTC for windows x86
|
||||
shell: bash
|
||||
|
||||
@@ -8,6 +8,9 @@ https://github.com/TadaoYamaoka/whisper/blob/main/to_onnx.py
|
||||
|
||||
Thanks to https://github.com/TadaoYamaoka
|
||||
for making the onnx export script public.
|
||||
|
||||
Note that we have removed the 30 seconds constraint from whisper. You can
|
||||
use any T <= 30.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -17,6 +20,7 @@ from typing import Any, Dict, Optional
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
from torch import Tensor, nn
|
||||
|
||||
@@ -65,6 +69,39 @@ def add_meta_data(filename: str, meta_data: Dict[str, Any]):
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
def modified_audio_encoder_forward(self: AudioEncoder, x: torch.Tensor):
|
||||
"""
|
||||
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
||||
the mel spectrogram of the audio
|
||||
"""
|
||||
x = F.gelu(self.conv1(x))
|
||||
x = F.gelu(self.conv2(x))
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
if False:
|
||||
# This branch contains the original code
|
||||
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
||||
x = (x + self.positional_embedding).to(x.dtype)
|
||||
else:
|
||||
# This branch contains the actual changes
|
||||
assert (
|
||||
x.shape[2] == self.positional_embedding.shape[1]
|
||||
), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}"
|
||||
assert (
|
||||
x.shape[1] == self.positional_embedding.shape[0]
|
||||
), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}"
|
||||
x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.ln_post(x)
|
||||
return x
|
||||
|
||||
|
||||
AudioEncoder.forward = modified_audio_encoder_forward
|
||||
|
||||
|
||||
class AudioEncoderTensorCache(nn.Module):
|
||||
def __init__(self, inAudioEncoder: AudioEncoder, inTextDecoder: TextDecoder):
|
||||
super().__init__()
|
||||
@@ -279,6 +316,7 @@ def main():
|
||||
model = whisper.load_model(filename)
|
||||
else:
|
||||
model = whisper.load_model(name)
|
||||
print(model.dims)
|
||||
|
||||
print(
|
||||
f"number of model parameters: {name}",
|
||||
@@ -311,19 +349,20 @@ def main():
|
||||
assert mel.shape == (batch_size, 80, 30 * 100)
|
||||
|
||||
encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
|
||||
|
||||
n_layer_cross_k, n_layer_cross_v = encoder(mel)
|
||||
assert n_layer_cross_k.shape == (
|
||||
model.dims.n_text_layer,
|
||||
batch_size,
|
||||
model.dims.n_audio_ctx,
|
||||
model.dims.n_text_state,
|
||||
), n_layer_cross_k.shape
|
||||
), (n_layer_cross_k.shape, model.dims)
|
||||
assert n_layer_cross_v.shape == (
|
||||
model.dims.n_text_layer,
|
||||
batch_size,
|
||||
model.dims.n_audio_ctx,
|
||||
model.dims.n_text_state,
|
||||
), n_layer_cross_v.shape
|
||||
), (n_layer_cross_v.shape, model.dims)
|
||||
|
||||
encoder_filename = f"{name}-encoder.onnx"
|
||||
torch.onnx.export(
|
||||
@@ -334,9 +373,9 @@ def main():
|
||||
input_names=["mel"],
|
||||
output_names=["n_layer_cross_k", "n_layer_cross_v"],
|
||||
dynamic_axes={
|
||||
"mel": {0: "n_audio"}, # n_audio is also known as batch_size
|
||||
"n_layer_cross_k": {1: "n_audio"},
|
||||
"n_layer_cross_v": {1: "n_audio"},
|
||||
"mel": {0: "n_audio", 2: "T"}, # n_audio is also known as batch_size
|
||||
"n_layer_cross_k": {1: "n_audio", 2: "T"},
|
||||
"n_layer_cross_v": {1: "n_audio", 2: "T"},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -461,8 +500,8 @@ def main():
|
||||
"tokens": {0: "n_audio", 1: "n_tokens"},
|
||||
"in_n_layer_self_k_cache": {1: "n_audio"},
|
||||
"in_n_layer_self_v_cache": {1: "n_audio"},
|
||||
"n_layer_cross_k": {1: "n_audio"},
|
||||
"n_layer_cross_v": {1: "n_audio"},
|
||||
"n_layer_cross_k": {1: "n_audio", 2: "T"},
|
||||
"n_layer_cross_v": {1: "n_audio", 2: "T"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -253,8 +253,21 @@ def compute_features(filename: str) -> torch.Tensor:
|
||||
log_spec = torch.clamp(features, min=1e-10).log10()
|
||||
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||
mel = (log_spec + 4.0) / 4.0
|
||||
# mel (T, 80)
|
||||
|
||||
# We pad 50 frames at the end so that it is able to detect eot
|
||||
# You can use another value instead of 50.
|
||||
mel = torch.nn.functional.pad(mel, (0, 0, 0, 50), "constant", 0)
|
||||
# Note that if it throws for a multilingual model,
|
||||
# please use a larger value, say 300
|
||||
|
||||
target = 3000
|
||||
mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
|
||||
if mel.shape[0] > target:
|
||||
mel = mel[:target]
|
||||
|
||||
# We don't need to pad it to 30 seconds now!
|
||||
# mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
|
||||
|
||||
mel = mel.t().unsqueeze(0)
|
||||
|
||||
return mel
|
||||
|
||||
@@ -115,7 +115,27 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||
|
||||
NormalizeFeatures(f.data(), num_frames, feat_dim);
|
||||
|
||||
std::array<int64_t, 3> shape{1, max_num_frames, feat_dim};
|
||||
// note that 50 is an experience value.
|
||||
// see also ../../scripts/whisper/test.py
|
||||
//
|
||||
// You can replace 50 by other values, say, 100.
|
||||
//
|
||||
// Since we have removed the 30 seconds constraint, we need
|
||||
// tail_padding_frames so that whisper is able to detect the eot token.
|
||||
int32_t tail_padding_frames = 50;
|
||||
if (model_->IsMultiLingual()) {
|
||||
// 300 is an experience value. If it throws, please use a larger value.
|
||||
tail_padding_frames = 300;
|
||||
}
|
||||
|
||||
if (config_.model_config.whisper.tail_paddings > 0) {
|
||||
tail_padding_frames = config_.model_config.whisper.tail_paddings;
|
||||
}
|
||||
|
||||
int32_t actual_frames =
|
||||
std::min(num_frames + tail_padding_frames, max_num_frames);
|
||||
|
||||
std::array<int64_t, 3> shape{1, actual_frames, feat_dim};
|
||||
|
||||
Ort::Value mel = Ort::Value::CreateTensor<float>(
|
||||
model_->Allocator(), shape.data(), shape.size());
|
||||
@@ -123,7 +143,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||
std::copy(f.begin(), f.end(), p_mel);
|
||||
|
||||
memset(p_mel + f.size(), 0,
|
||||
(max_num_frames - num_frames) * feat_dim * sizeof(float));
|
||||
(actual_frames - num_frames) * feat_dim * sizeof(float));
|
||||
mel = Transpose12(model_->Allocator(), &mel);
|
||||
|
||||
try {
|
||||
|
||||
@@ -32,6 +32,14 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) {
|
||||
"Valid values: transcribe, translate. "
|
||||
"Note that for non-multilingual models, it supports "
|
||||
"only 'transcribe'");
|
||||
|
||||
po->Register(
|
||||
"whisper-tail-paddings", &tail_paddings,
|
||||
"Suggest value: 50 for English models. 300 for multilingual models. "
|
||||
"Since we have removed the 30-second constraint, we need to add some "
|
||||
"tail padding frames "
|
||||
"so that whisper can detect the eot token. Leave it to -1 to use 50 for "
|
||||
"English models and 300 for multilingual models.");
|
||||
}
|
||||
|
||||
bool OfflineWhisperModelConfig::Validate() const {
|
||||
@@ -63,7 +71,8 @@ std::string OfflineWhisperModelConfig::ToString() const {
|
||||
os << "encoder=\"" << encoder << "\", ";
|
||||
os << "decoder=\"" << decoder << "\", ";
|
||||
os << "language=\"" << language << "\", ";
|
||||
os << "task=\"" << task << "\")";
|
||||
os << "task=\"" << task << "\", ";
|
||||
os << "tail_paddings=" << tail_paddings << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -28,12 +28,26 @@ struct OfflineWhisperModelConfig {
|
||||
// Note: For non-multilingual models, it supports only "transcribe"
|
||||
std::string task = "transcribe";
|
||||
|
||||
// Number of tail padding frames.
|
||||
//
|
||||
// Since we remove the 30-second constraint, we need to add some paddings
|
||||
// at the end.
|
||||
//
|
||||
// Recommended values:
|
||||
// - 50 for English models
|
||||
// - 300 for multilingual models
|
||||
int32_t tail_paddings = -1;
|
||||
|
||||
OfflineWhisperModelConfig() = default;
|
||||
OfflineWhisperModelConfig(const std::string &encoder,
|
||||
const std::string &decoder,
|
||||
const std::string &language,
|
||||
const std::string &task)
|
||||
: encoder(encoder), decoder(decoder), language(language), task(task) {}
|
||||
const std::string &task, int32_t tail_paddings)
|
||||
: encoder(encoder),
|
||||
decoder(decoder),
|
||||
language(language),
|
||||
task(task),
|
||||
tail_paddings(tail_paddings) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -15,13 +15,14 @@ void PybindOfflineWhisperModelConfig(py::module *m) {
|
||||
using PyClass = OfflineWhisperModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineWhisperModelConfig")
|
||||
.def(py::init<const std::string &, const std::string &,
|
||||
const std::string &, const std::string &>(),
|
||||
const std::string &, const std::string &, int32_t>(),
|
||||
py::arg("encoder"), py::arg("decoder"), py::arg("language"),
|
||||
py::arg("task"))
|
||||
py::arg("task"), py::arg("tail_paddings") = -1)
|
||||
.def_readwrite("encoder", &PyClass::encoder)
|
||||
.def_readwrite("decoder", &PyClass::decoder)
|
||||
.def_readwrite("language", &PyClass::language)
|
||||
.def_readwrite("task", &PyClass::task)
|
||||
.def_readwrite("tail_paddings", &PyClass::tail_paddings)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user