Remove the 30-second constraint from whisper. (#471)
This commit is contained in:
36
.github/scripts/test-offline-whisper.sh
vendored
36
.github/scripts/test-offline-whisper.sh
vendored
@@ -16,8 +16,12 @@ which $EXE
|
|||||||
names=(
|
names=(
|
||||||
tiny.en
|
tiny.en
|
||||||
base.en
|
base.en
|
||||||
# small.en
|
small.en
|
||||||
# medium.en
|
medium.en
|
||||||
|
tiny
|
||||||
|
base
|
||||||
|
small
|
||||||
|
medium
|
||||||
)
|
)
|
||||||
|
|
||||||
for name in ${names[@]}; do
|
for name in ${names[@]}; do
|
||||||
@@ -33,8 +37,8 @@ for name in ${names[@]}; do
|
|||||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
pushd $repo
|
pushd $repo
|
||||||
git lfs pull --include "*.onnx"
|
git lfs pull --include "*.onnx"
|
||||||
git lfs pull --include "*.ort"
|
# git lfs pull --include "*.ort"
|
||||||
ls -lh *.{onnx,ort}
|
ls -lh *.onnx
|
||||||
popd
|
popd
|
||||||
|
|
||||||
log "test fp32 onnx"
|
log "test fp32 onnx"
|
||||||
@@ -43,6 +47,7 @@ for name in ${names[@]}; do
|
|||||||
--tokens=$repo/${name}-tokens.txt \
|
--tokens=$repo/${name}-tokens.txt \
|
||||||
--whisper-encoder=$repo/${name}-encoder.onnx \
|
--whisper-encoder=$repo/${name}-encoder.onnx \
|
||||||
--whisper-decoder=$repo/${name}-decoder.onnx \
|
--whisper-decoder=$repo/${name}-decoder.onnx \
|
||||||
|
--whisper-tail-paddings=500 \
|
||||||
--num-threads=2 \
|
--num-threads=2 \
|
||||||
$repo/test_wavs/0.wav \
|
$repo/test_wavs/0.wav \
|
||||||
$repo/test_wavs/1.wav \
|
$repo/test_wavs/1.wav \
|
||||||
@@ -54,28 +59,7 @@ for name in ${names[@]}; do
|
|||||||
--tokens=$repo/${name}-tokens.txt \
|
--tokens=$repo/${name}-tokens.txt \
|
||||||
--whisper-encoder=$repo/${name}-encoder.int8.onnx \
|
--whisper-encoder=$repo/${name}-encoder.int8.onnx \
|
||||||
--whisper-decoder=$repo/${name}-decoder.int8.onnx \
|
--whisper-decoder=$repo/${name}-decoder.int8.onnx \
|
||||||
--num-threads=2 \
|
--whisper-tail-paddings=500 \
|
||||||
$repo/test_wavs/0.wav \
|
|
||||||
$repo/test_wavs/1.wav \
|
|
||||||
$repo/test_wavs/8k.wav
|
|
||||||
|
|
||||||
log "test fp32 ort"
|
|
||||||
|
|
||||||
time $EXE \
|
|
||||||
--tokens=$repo/${name}-tokens.txt \
|
|
||||||
--whisper-encoder=$repo/${name}-encoder.ort \
|
|
||||||
--whisper-decoder=$repo/${name}-decoder.ort \
|
|
||||||
--num-threads=2 \
|
|
||||||
$repo/test_wavs/0.wav \
|
|
||||||
$repo/test_wavs/1.wav \
|
|
||||||
$repo/test_wavs/8k.wav
|
|
||||||
|
|
||||||
log "test int8 ort"
|
|
||||||
|
|
||||||
time $EXE \
|
|
||||||
--tokens=$repo/${name}-tokens.txt \
|
|
||||||
--whisper-encoder=$repo/${name}-encoder.int8.ort \
|
|
||||||
--whisper-decoder=$repo/${name}-decoder.int8.ort \
|
|
||||||
--num-threads=2 \
|
--num-threads=2 \
|
||||||
$repo/test_wavs/0.wav \
|
$repo/test_wavs/0.wav \
|
||||||
$repo/test_wavs/1.wav \
|
$repo/test_wavs/1.wav \
|
||||||
|
|||||||
60
.github/workflows/export-whisper-to-onnx.yaml
vendored
60
.github/workflows/export-whisper-to-onnx.yaml
vendored
@@ -15,7 +15,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
os: [macos-latest]
|
os: [ubuntu-latest]
|
||||||
model: ["distil-medium.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
|
model: ["distil-medium.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
|
||||||
python-version: ["3.8"]
|
python-version: ["3.8"]
|
||||||
|
|
||||||
@@ -44,7 +44,7 @@ jobs:
|
|||||||
ls -lh
|
ls -lh
|
||||||
fi
|
fi
|
||||||
python3 ./export-onnx.py --model ${{ matrix.model }}
|
python3 ./export-onnx.py --model ${{ matrix.model }}
|
||||||
python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
|
# python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
|
||||||
|
|
||||||
ls -lh
|
ls -lh
|
||||||
|
|
||||||
@@ -52,41 +52,61 @@ jobs:
|
|||||||
ls -lh ~/.cache/whisper
|
ls -lh ~/.cache/whisper
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
src=sherpa-onnx-whisper-${{ matrix.model }}
|
||||||
|
|
||||||
|
mkdir $src
|
||||||
|
cp *.onnx $src/
|
||||||
|
cp *tokens.txt $src
|
||||||
|
|
||||||
|
cd $src
|
||||||
|
mkdir -p test_wavs
|
||||||
|
cd test_wavs
|
||||||
|
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/0.wav
|
||||||
|
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/1.wav
|
||||||
|
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/8k.wav
|
||||||
|
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/trans.txt
|
||||||
|
cd ../..
|
||||||
|
mv $src ../..
|
||||||
|
|
||||||
|
cd ../..
|
||||||
|
echo "--------------------"
|
||||||
|
ls -lh
|
||||||
|
ls -lh $src
|
||||||
|
echo "--------------------"
|
||||||
|
|
||||||
|
tar cjvf ./$src.tar.bz2 $src
|
||||||
|
|
||||||
|
- name: Release
|
||||||
|
uses: svenstaro/upload-release-action@v2
|
||||||
|
with:
|
||||||
|
file_glob: true
|
||||||
|
file: ./*.tar.bz2
|
||||||
|
overwrite: true
|
||||||
|
repo_name: k2-fsa/sherpa-onnx
|
||||||
|
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
|
||||||
|
tag: asr-models
|
||||||
|
|
||||||
- name: Publish ${{ matrix.model }} to huggingface
|
- name: Publish ${{ matrix.model }} to huggingface
|
||||||
shell: bash
|
shell: bash
|
||||||
env:
|
env:
|
||||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||||
run: |
|
run: |
|
||||||
model=${{ matrix.model }}
|
src=sherpa-onnx-whisper-${{ matrix.model }}
|
||||||
|
|
||||||
cd scripts/whisper
|
|
||||||
|
|
||||||
git config --global user.email "csukuangfj@gmail.com"
|
git config --global user.email "csukuangfj@gmail.com"
|
||||||
git config --global user.name "Fangjun Kuang"
|
git config --global user.name "Fangjun Kuang"
|
||||||
|
|
||||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
|
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
|
||||||
|
rm -rf huggingface/*
|
||||||
|
|
||||||
cp *.onnx ./huggingface
|
cp -av $src/* ./huggingface/
|
||||||
cp *.ort ./huggingface
|
|
||||||
cp *tokens.txt ./huggingface
|
|
||||||
|
|
||||||
cd huggingface
|
cd huggingface
|
||||||
|
|
||||||
if [[ $model == distil-medium.en ]]; then
|
|
||||||
mkdir test_wavs
|
|
||||||
cd test_wavs
|
|
||||||
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/0.wav
|
|
||||||
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/1.wav
|
|
||||||
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/8k.wav
|
|
||||||
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/trans.txt
|
|
||||||
git add .
|
|
||||||
cd ..
|
|
||||||
fi
|
|
||||||
|
|
||||||
git status
|
git status
|
||||||
ls -lh
|
ls -lh
|
||||||
git lfs track "*.onnx"
|
git lfs track "*.onnx"
|
||||||
git lfs track "*.ort"
|
# git lfs track "*.ort"
|
||||||
git add .
|
git add .
|
||||||
git commit -m "upload ${{ matrix.model }}"
|
git commit -m "upload ${{ matrix.model }}"
|
||||||
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main
|
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main
|
||||||
|
|||||||
20
.github/workflows/linux.yaml
vendored
20
.github/workflows/linux.yaml
vendored
@@ -107,6 +107,16 @@ jobs:
|
|||||||
name: release-static
|
name: release-static
|
||||||
path: build/bin/*
|
path: build/bin/*
|
||||||
|
|
||||||
|
- name: Test offline Whisper
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
export EXE=sherpa-onnx-offline
|
||||||
|
|
||||||
|
readelf -d build/bin/sherpa-onnx-offline
|
||||||
|
|
||||||
|
.github/scripts/test-offline-whisper.sh
|
||||||
|
|
||||||
- name: Test online CTC
|
- name: Test online CTC
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
@@ -139,16 +149,6 @@ jobs:
|
|||||||
|
|
||||||
.github/scripts/test-online-paraformer.sh
|
.github/scripts/test-online-paraformer.sh
|
||||||
|
|
||||||
- name: Test offline Whisper
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
export PATH=$PWD/build/bin:$PATH
|
|
||||||
export EXE=sherpa-onnx-offline
|
|
||||||
|
|
||||||
readelf -d build/bin/sherpa-onnx-offline
|
|
||||||
|
|
||||||
.github/scripts/test-offline-whisper.sh
|
|
||||||
|
|
||||||
- name: Test offline transducer
|
- name: Test offline transducer
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
14
.github/workflows/windows-x86.yaml
vendored
14
.github/workflows/windows-x86.yaml
vendored
@@ -93,13 +93,13 @@ jobs:
|
|||||||
|
|
||||||
.github/scripts/test-online-paraformer.sh
|
.github/scripts/test-online-paraformer.sh
|
||||||
|
|
||||||
- name: Test offline Whisper for windows x86
|
# - name: Test offline Whisper for windows x86
|
||||||
shell: bash
|
# shell: bash
|
||||||
run: |
|
# run: |
|
||||||
export PATH=$PWD/build/bin/Release:$PATH
|
# export PATH=$PWD/build/bin/Release:$PATH
|
||||||
export EXE=sherpa-onnx-offline.exe
|
# export EXE=sherpa-onnx-offline.exe
|
||||||
|
#
|
||||||
.github/scripts/test-offline-whisper.sh
|
# .github/scripts/test-offline-whisper.sh
|
||||||
|
|
||||||
- name: Test offline CTC for windows x86
|
- name: Test offline CTC for windows x86
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ https://github.com/TadaoYamaoka/whisper/blob/main/to_onnx.py
|
|||||||
|
|
||||||
Thanks to https://github.com/TadaoYamaoka
|
Thanks to https://github.com/TadaoYamaoka
|
||||||
for making the onnx export script public.
|
for making the onnx export script public.
|
||||||
|
|
||||||
|
Note that we have removed the 30-second constraint from whisper. You can
|
||||||
|
use audio of any duration T <= 30 seconds.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@@ -17,6 +20,7 @@ from typing import Any, Dict, Optional
|
|||||||
|
|
||||||
import onnx
|
import onnx
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
@@ -65,6 +69,39 @@ def add_meta_data(filename: str, meta_data: Dict[str, Any]):
|
|||||||
onnx.save(model, filename)
|
onnx.save(model, filename)
|
||||||
|
|
||||||
|
|
||||||
|
def modified_audio_encoder_forward(self: AudioEncoder, x: torch.Tensor):
|
||||||
|
"""
|
||||||
|
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
||||||
|
the mel spectrogram of the audio
|
||||||
|
"""
|
||||||
|
x = F.gelu(self.conv1(x))
|
||||||
|
x = F.gelu(self.conv2(x))
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
|
||||||
|
if False:
|
||||||
|
# This branch contains the original code
|
||||||
|
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
||||||
|
x = (x + self.positional_embedding).to(x.dtype)
|
||||||
|
else:
|
||||||
|
# This branch contains the actual changes
|
||||||
|
assert (
|
||||||
|
x.shape[2] == self.positional_embedding.shape[1]
|
||||||
|
), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}"
|
||||||
|
assert (
|
||||||
|
x.shape[1] == self.positional_embedding.shape[0]
|
||||||
|
), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}"
|
||||||
|
x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)
|
||||||
|
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
x = self.ln_post(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
AudioEncoder.forward = modified_audio_encoder_forward
|
||||||
|
|
||||||
|
|
||||||
class AudioEncoderTensorCache(nn.Module):
|
class AudioEncoderTensorCache(nn.Module):
|
||||||
def __init__(self, inAudioEncoder: AudioEncoder, inTextDecoder: TextDecoder):
|
def __init__(self, inAudioEncoder: AudioEncoder, inTextDecoder: TextDecoder):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -279,6 +316,7 @@ def main():
|
|||||||
model = whisper.load_model(filename)
|
model = whisper.load_model(filename)
|
||||||
else:
|
else:
|
||||||
model = whisper.load_model(name)
|
model = whisper.load_model(name)
|
||||||
|
print(model.dims)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"number of model parameters: {name}",
|
f"number of model parameters: {name}",
|
||||||
@@ -311,19 +349,20 @@ def main():
|
|||||||
assert mel.shape == (batch_size, 80, 30 * 100)
|
assert mel.shape == (batch_size, 80, 30 * 100)
|
||||||
|
|
||||||
encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
|
encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
|
||||||
|
|
||||||
n_layer_cross_k, n_layer_cross_v = encoder(mel)
|
n_layer_cross_k, n_layer_cross_v = encoder(mel)
|
||||||
assert n_layer_cross_k.shape == (
|
assert n_layer_cross_k.shape == (
|
||||||
model.dims.n_text_layer,
|
model.dims.n_text_layer,
|
||||||
batch_size,
|
batch_size,
|
||||||
model.dims.n_audio_ctx,
|
model.dims.n_audio_ctx,
|
||||||
model.dims.n_text_state,
|
model.dims.n_text_state,
|
||||||
), n_layer_cross_k.shape
|
), (n_layer_cross_k.shape, model.dims)
|
||||||
assert n_layer_cross_v.shape == (
|
assert n_layer_cross_v.shape == (
|
||||||
model.dims.n_text_layer,
|
model.dims.n_text_layer,
|
||||||
batch_size,
|
batch_size,
|
||||||
model.dims.n_audio_ctx,
|
model.dims.n_audio_ctx,
|
||||||
model.dims.n_text_state,
|
model.dims.n_text_state,
|
||||||
), n_layer_cross_v.shape
|
), (n_layer_cross_v.shape, model.dims)
|
||||||
|
|
||||||
encoder_filename = f"{name}-encoder.onnx"
|
encoder_filename = f"{name}-encoder.onnx"
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
@@ -334,9 +373,9 @@ def main():
|
|||||||
input_names=["mel"],
|
input_names=["mel"],
|
||||||
output_names=["n_layer_cross_k", "n_layer_cross_v"],
|
output_names=["n_layer_cross_k", "n_layer_cross_v"],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"mel": {0: "n_audio"}, # n_audio is also known as batch_size
|
"mel": {0: "n_audio", 2: "T"}, # n_audio is also known as batch_size
|
||||||
"n_layer_cross_k": {1: "n_audio"},
|
"n_layer_cross_k": {1: "n_audio", 2: "T"},
|
||||||
"n_layer_cross_v": {1: "n_audio"},
|
"n_layer_cross_v": {1: "n_audio", 2: "T"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -461,8 +500,8 @@ def main():
|
|||||||
"tokens": {0: "n_audio", 1: "n_tokens"},
|
"tokens": {0: "n_audio", 1: "n_tokens"},
|
||||||
"in_n_layer_self_k_cache": {1: "n_audio"},
|
"in_n_layer_self_k_cache": {1: "n_audio"},
|
||||||
"in_n_layer_self_v_cache": {1: "n_audio"},
|
"in_n_layer_self_v_cache": {1: "n_audio"},
|
||||||
"n_layer_cross_k": {1: "n_audio"},
|
"n_layer_cross_k": {1: "n_audio", 2: "T"},
|
||||||
"n_layer_cross_v": {1: "n_audio"},
|
"n_layer_cross_v": {1: "n_audio", 2: "T"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -253,8 +253,21 @@ def compute_features(filename: str) -> torch.Tensor:
|
|||||||
log_spec = torch.clamp(features, min=1e-10).log10()
|
log_spec = torch.clamp(features, min=1e-10).log10()
|
||||||
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||||
mel = (log_spec + 4.0) / 4.0
|
mel = (log_spec + 4.0) / 4.0
|
||||||
|
# mel (T, 80)
|
||||||
|
|
||||||
|
# We pad 50 frames at the end so that it is able to detect eot
|
||||||
|
# You can use another value instead of 50.
|
||||||
|
mel = torch.nn.functional.pad(mel, (0, 0, 0, 50), "constant", 0)
|
||||||
|
# Note that if it throws for a multilingual model,
|
||||||
|
# please use a larger value, say 300
|
||||||
|
|
||||||
target = 3000
|
target = 3000
|
||||||
mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
|
if mel.shape[0] > target:
|
||||||
|
mel = mel[:target]
|
||||||
|
|
||||||
|
# We don't need to pad it to 30 seconds now!
|
||||||
|
# mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
|
||||||
|
|
||||||
mel = mel.t().unsqueeze(0)
|
mel = mel.t().unsqueeze(0)
|
||||||
|
|
||||||
return mel
|
return mel
|
||||||
|
|||||||
@@ -115,7 +115,27 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
|||||||
|
|
||||||
NormalizeFeatures(f.data(), num_frames, feat_dim);
|
NormalizeFeatures(f.data(), num_frames, feat_dim);
|
||||||
|
|
||||||
std::array<int64_t, 3> shape{1, max_num_frames, feat_dim};
|
// Note that 50 is an empirical value.
|
||||||
|
// see also ../../scripts/whisper/test.py
|
||||||
|
//
|
||||||
|
// You can replace 50 by other values, say, 100.
|
||||||
|
//
|
||||||
|
// Since we have removed the 30 seconds constraint, we need
|
||||||
|
// tail_padding_frames so that whisper is able to detect the eot token.
|
||||||
|
int32_t tail_padding_frames = 50;
|
||||||
|
if (model_->IsMultiLingual()) {
|
||||||
|
// 300 is an empirical value. If it throws, please use a larger value.
|
||||||
|
tail_padding_frames = 300;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (config_.model_config.whisper.tail_paddings > 0) {
|
||||||
|
tail_padding_frames = config_.model_config.whisper.tail_paddings;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t actual_frames =
|
||||||
|
std::min(num_frames + tail_padding_frames, max_num_frames);
|
||||||
|
|
||||||
|
std::array<int64_t, 3> shape{1, actual_frames, feat_dim};
|
||||||
|
|
||||||
Ort::Value mel = Ort::Value::CreateTensor<float>(
|
Ort::Value mel = Ort::Value::CreateTensor<float>(
|
||||||
model_->Allocator(), shape.data(), shape.size());
|
model_->Allocator(), shape.data(), shape.size());
|
||||||
@@ -123,7 +143,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
|||||||
std::copy(f.begin(), f.end(), p_mel);
|
std::copy(f.begin(), f.end(), p_mel);
|
||||||
|
|
||||||
memset(p_mel + f.size(), 0,
|
memset(p_mel + f.size(), 0,
|
||||||
(max_num_frames - num_frames) * feat_dim * sizeof(float));
|
(actual_frames - num_frames) * feat_dim * sizeof(float));
|
||||||
mel = Transpose12(model_->Allocator(), &mel);
|
mel = Transpose12(model_->Allocator(), &mel);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|||||||
@@ -32,6 +32,14 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) {
|
|||||||
"Valid values: transcribe, translate. "
|
"Valid values: transcribe, translate. "
|
||||||
"Note that for non-multilingual models, it supports "
|
"Note that for non-multilingual models, it supports "
|
||||||
"only 'transcribe'");
|
"only 'transcribe'");
|
||||||
|
|
||||||
|
po->Register(
|
||||||
|
"whisper-tail-paddings", &tail_paddings,
|
||||||
|
"Suggest value: 50 for English models. 300 for multilingual models. "
|
||||||
|
"Since we have removed the 30-second constraint, we need to add some "
|
||||||
|
"tail padding frames "
|
||||||
|
"so that whisper can detect the eot token. Leave it to -1 to use 50 for "
|
||||||
|
"English models and 300 for multilingual models.");
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OfflineWhisperModelConfig::Validate() const {
|
bool OfflineWhisperModelConfig::Validate() const {
|
||||||
@@ -63,7 +71,8 @@ std::string OfflineWhisperModelConfig::ToString() const {
|
|||||||
os << "encoder=\"" << encoder << "\", ";
|
os << "encoder=\"" << encoder << "\", ";
|
||||||
os << "decoder=\"" << decoder << "\", ";
|
os << "decoder=\"" << decoder << "\", ";
|
||||||
os << "language=\"" << language << "\", ";
|
os << "language=\"" << language << "\", ";
|
||||||
os << "task=\"" << task << "\")";
|
os << "task=\"" << task << "\", ";
|
||||||
|
os << "tail_paddings=" << tail_paddings << ")";
|
||||||
|
|
||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,12 +28,26 @@ struct OfflineWhisperModelConfig {
|
|||||||
// Note: For non-multilingual models, it supports only "transcribe"
|
// Note: For non-multilingual models, it supports only "transcribe"
|
||||||
std::string task = "transcribe";
|
std::string task = "transcribe";
|
||||||
|
|
||||||
|
// Number of tail padding frames.
|
||||||
|
//
|
||||||
|
// Since we have removed the 30-second constraint, we need to add some padding
|
||||||
|
// at the end.
|
||||||
|
//
|
||||||
|
// Recommended values:
|
||||||
|
// - 50 for English models
|
||||||
|
// - 300 for multilingual models
|
||||||
|
int32_t tail_paddings = -1;
|
||||||
|
|
||||||
OfflineWhisperModelConfig() = default;
|
OfflineWhisperModelConfig() = default;
|
||||||
OfflineWhisperModelConfig(const std::string &encoder,
|
OfflineWhisperModelConfig(const std::string &encoder,
|
||||||
const std::string &decoder,
|
const std::string &decoder,
|
||||||
const std::string &language,
|
const std::string &language,
|
||||||
const std::string &task)
|
const std::string &task, int32_t tail_paddings)
|
||||||
: encoder(encoder), decoder(decoder), language(language), task(task) {}
|
: encoder(encoder),
|
||||||
|
decoder(decoder),
|
||||||
|
language(language),
|
||||||
|
task(task),
|
||||||
|
tail_paddings(tail_paddings) {}
|
||||||
|
|
||||||
void Register(ParseOptions *po);
|
void Register(ParseOptions *po);
|
||||||
bool Validate() const;
|
bool Validate() const;
|
||||||
|
|||||||
@@ -15,13 +15,14 @@ void PybindOfflineWhisperModelConfig(py::module *m) {
|
|||||||
using PyClass = OfflineWhisperModelConfig;
|
using PyClass = OfflineWhisperModelConfig;
|
||||||
py::class_<PyClass>(*m, "OfflineWhisperModelConfig")
|
py::class_<PyClass>(*m, "OfflineWhisperModelConfig")
|
||||||
.def(py::init<const std::string &, const std::string &,
|
.def(py::init<const std::string &, const std::string &,
|
||||||
const std::string &, const std::string &>(),
|
const std::string &, const std::string &, int32_t>(),
|
||||||
py::arg("encoder"), py::arg("decoder"), py::arg("language"),
|
py::arg("encoder"), py::arg("decoder"), py::arg("language"),
|
||||||
py::arg("task"))
|
py::arg("task"), py::arg("tail_paddings") = -1)
|
||||||
.def_readwrite("encoder", &PyClass::encoder)
|
.def_readwrite("encoder", &PyClass::encoder)
|
||||||
.def_readwrite("decoder", &PyClass::decoder)
|
.def_readwrite("decoder", &PyClass::decoder)
|
||||||
.def_readwrite("language", &PyClass::language)
|
.def_readwrite("language", &PyClass::language)
|
||||||
.def_readwrite("task", &PyClass::task)
|
.def_readwrite("task", &PyClass::task)
|
||||||
|
.def_readwrite("tail_paddings", &PyClass::tail_paddings)
|
||||||
.def("__str__", &PyClass::ToString);
|
.def("__str__", &PyClass::ToString);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user