Remove the 30-second constraint from whisper. (#471)

This commit is contained in:
Fangjun Kuang
2023-12-07 17:47:08 +08:00
committed by GitHub
parent a7d69359c9
commit 3ae984f148
10 changed files with 178 additions and 78 deletions

View File

@@ -8,6 +8,9 @@ https://github.com/TadaoYamaoka/whisper/blob/main/to_onnx.py
Thanks to https://github.com/TadaoYamaoka
for making the onnx export script public.
Note that we have removed the 30 seconds constraint from whisper. You can
use any T <= 30.
"""
import argparse
@@ -17,6 +20,7 @@ from typing import Any, Dict, Optional
import onnx
import torch
import torch.nn.functional as F
from onnxruntime.quantization import QuantType, quantize_dynamic
from torch import Tensor, nn
@@ -65,6 +69,39 @@ def add_meta_data(filename: str, meta_data: Dict[str, Any]):
onnx.save(model, filename)
def modified_audio_encoder_forward(self: AudioEncoder, x: torch.Tensor):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
"""
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)
if False:
# This branch contains the original code
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
x = (x + self.positional_embedding).to(x.dtype)
else:
# This branch contains the actual changes
assert (
x.shape[2] == self.positional_embedding.shape[1]
), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}"
assert (
x.shape[1] == self.positional_embedding.shape[0]
), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}"
x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)
for block in self.blocks:
x = block(x)
x = self.ln_post(x)
return x
AudioEncoder.forward = modified_audio_encoder_forward
class AudioEncoderTensorCache(nn.Module):
def __init__(self, inAudioEncoder: AudioEncoder, inTextDecoder: TextDecoder):
super().__init__()
@@ -279,6 +316,7 @@ def main():
model = whisper.load_model(filename)
else:
model = whisper.load_model(name)
print(model.dims)
print(
f"number of model parameters: {name}",
@@ -311,19 +349,20 @@ def main():
assert mel.shape == (batch_size, 80, 30 * 100)
encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
n_layer_cross_k, n_layer_cross_v = encoder(mel)
assert n_layer_cross_k.shape == (
model.dims.n_text_layer,
batch_size,
model.dims.n_audio_ctx,
model.dims.n_text_state,
), n_layer_cross_k.shape
), (n_layer_cross_k.shape, model.dims)
assert n_layer_cross_v.shape == (
model.dims.n_text_layer,
batch_size,
model.dims.n_audio_ctx,
model.dims.n_text_state,
), n_layer_cross_v.shape
), (n_layer_cross_v.shape, model.dims)
encoder_filename = f"{name}-encoder.onnx"
torch.onnx.export(
@@ -334,9 +373,9 @@ def main():
input_names=["mel"],
output_names=["n_layer_cross_k", "n_layer_cross_v"],
dynamic_axes={
"mel": {0: "n_audio"}, # n_audio is also known as batch_size
"n_layer_cross_k": {1: "n_audio"},
"n_layer_cross_v": {1: "n_audio"},
"mel": {0: "n_audio", 2: "T"}, # n_audio is also known as batch_size
"n_layer_cross_k": {1: "n_audio", 2: "T"},
"n_layer_cross_v": {1: "n_audio", 2: "T"},
},
)
@@ -461,8 +500,8 @@ def main():
"tokens": {0: "n_audio", 1: "n_tokens"},
"in_n_layer_self_k_cache": {1: "n_audio"},
"in_n_layer_self_v_cache": {1: "n_audio"},
"n_layer_cross_k": {1: "n_audio"},
"n_layer_cross_v": {1: "n_audio"},
"n_layer_cross_k": {1: "n_audio", 2: "T"},
"n_layer_cross_v": {1: "n_audio", 2: "T"},
},
)

View File

@@ -253,8 +253,21 @@ def compute_features(filename: str) -> torch.Tensor:
log_spec = torch.clamp(features, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
mel = (log_spec + 4.0) / 4.0
# mel (T, 80)
# We pad 50 frames at the end so that it is able to detect eot
# You can use another value instead of 50.
mel = torch.nn.functional.pad(mel, (0, 0, 0, 50), "constant", 0)
# Note that if it throws for a multilingual model,
# please use a larger value, say 300
target = 3000
mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
if mel.shape[0] > target:
mel = mel[:target]
# We don't need to pad it to 30 seconds now!
# mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
mel = mel.t().unsqueeze(0)
return mel