Remove the 30-second constraint from whisper. (#471)

This commit is contained in:
Fangjun Kuang
2023-12-07 17:47:08 +08:00
committed by GitHub
parent a7d69359c9
commit 3ae984f148
10 changed files with 178 additions and 78 deletions

View File

@@ -8,6 +8,9 @@ https://github.com/TadaoYamaoka/whisper/blob/main/to_onnx.py
Thanks to https://github.com/TadaoYamaoka
for making the onnx export script public.
Note that we have removed the 30 seconds constraint from whisper. You can
use any T <= 30.
"""
import argparse
@@ -17,6 +20,7 @@ from typing import Any, Dict, Optional
import onnx
import torch
import torch.nn.functional as F
from onnxruntime.quantization import QuantType, quantize_dynamic
from torch import Tensor, nn
@@ -65,6 +69,39 @@ def add_meta_data(filename: str, meta_data: Dict[str, Any]):
onnx.save(model, filename)
def modified_audio_encoder_forward(self: AudioEncoder, x: torch.Tensor):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
"""
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)
if False:
# This branch contains the original code
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
x = (x + self.positional_embedding).to(x.dtype)
else:
# This branch contains the actual changes
assert (
x.shape[2] == self.positional_embedding.shape[1]
), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}"
assert (
x.shape[1] == self.positional_embedding.shape[0]
), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}"
x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)
for block in self.blocks:
x = block(x)
x = self.ln_post(x)
return x
AudioEncoder.forward = modified_audio_encoder_forward
class AudioEncoderTensorCache(nn.Module):
def __init__(self, inAudioEncoder: AudioEncoder, inTextDecoder: TextDecoder):
super().__init__()
@@ -279,6 +316,7 @@ def main():
model = whisper.load_model(filename)
else:
model = whisper.load_model(name)
print(model.dims)
print(
f"number of model parameters: {name}",
@@ -311,19 +349,20 @@ def main():
assert mel.shape == (batch_size, 80, 30 * 100)
encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
n_layer_cross_k, n_layer_cross_v = encoder(mel)
assert n_layer_cross_k.shape == (
model.dims.n_text_layer,
batch_size,
model.dims.n_audio_ctx,
model.dims.n_text_state,
), n_layer_cross_k.shape
), (n_layer_cross_k.shape, model.dims)
assert n_layer_cross_v.shape == (
model.dims.n_text_layer,
batch_size,
model.dims.n_audio_ctx,
model.dims.n_text_state,
), n_layer_cross_v.shape
), (n_layer_cross_v.shape, model.dims)
encoder_filename = f"{name}-encoder.onnx"
torch.onnx.export(
@@ -334,9 +373,9 @@ def main():
input_names=["mel"],
output_names=["n_layer_cross_k", "n_layer_cross_v"],
dynamic_axes={
"mel": {0: "n_audio"}, # n_audio is also known as batch_size
"n_layer_cross_k": {1: "n_audio"},
"n_layer_cross_v": {1: "n_audio"},
"mel": {0: "n_audio", 2: "T"}, # n_audio is also known as batch_size
"n_layer_cross_k": {1: "n_audio", 2: "T"},
"n_layer_cross_v": {1: "n_audio", 2: "T"},
},
)
@@ -461,8 +500,8 @@ def main():
"tokens": {0: "n_audio", 1: "n_tokens"},
"in_n_layer_self_k_cache": {1: "n_audio"},
"in_n_layer_self_v_cache": {1: "n_audio"},
"n_layer_cross_k": {1: "n_audio"},
"n_layer_cross_v": {1: "n_audio"},
"n_layer_cross_k": {1: "n_audio", 2: "T"},
"n_layer_cross_v": {1: "n_audio", 2: "T"},
},
)