Remove the 30-second constraint from whisper. (#471)
This commit is contained in:
@@ -8,6 +8,9 @@ https://github.com/TadaoYamaoka/whisper/blob/main/to_onnx.py
|
||||
|
||||
Thanks to https://github.com/TadaoYamaoka
|
||||
for making the onnx export script public.
|
||||
|
||||
Note that we have removed the 30 seconds constraint from whisper. You can
|
||||
use any T <= 30.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -17,6 +20,7 @@ from typing import Any, Dict, Optional
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
from torch import Tensor, nn
|
||||
|
||||
@@ -65,6 +69,39 @@ def add_meta_data(filename: str, meta_data: Dict[str, Any]):
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
def modified_audio_encoder_forward(self: AudioEncoder, x: torch.Tensor):
    """Forward pass for Whisper's ``AudioEncoder`` without the fixed
    30-second input constraint.

    Upstream Whisper asserts that the mel input produces exactly
    ``n_audio_ctx`` (1500) frames after the two convolutions. This
    version instead slices a matching prefix of the positional
    embedding, so shorter inputs (any T <= 30 seconds) also work.

    Args:
        x: torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio

    Returns:
        torch.Tensor of shape (batch_size, T, n_audio_state) — the
        encoder output after the transformer blocks and final
        layer norm, where T = x.shape[2] // 2 (conv2 has stride 2).
    """
    x = F.gelu(self.conv1(x))
    x = F.gelu(self.conv2(x))
    # (batch, n_state, T) -> (batch, T, n_state) for the transformer blocks.
    x = x.permute(0, 2, 1)

    # Upstream checked x.shape[1:] == self.positional_embedding.shape;
    # here we validate the two axes separately and add only a prefix of
    # the positional embedding, decoupling the model from a fixed T.
    assert (
        x.shape[2] == self.positional_embedding.shape[1]
    ), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}"
    assert (
        x.shape[1] == self.positional_embedding.shape[0]
    ), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}"
    x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)

    for block in self.blocks:
        x = block(x)

    x = self.ln_post(x)
    return x


# Monkey-patch whisper's AudioEncoder so every model loaded afterwards
# uses the length-flexible forward defined above.
AudioEncoder.forward = modified_audio_encoder_forward
|
||||
|
||||
|
||||
class AudioEncoderTensorCache(nn.Module):
|
||||
def __init__(self, inAudioEncoder: AudioEncoder, inTextDecoder: TextDecoder):
|
||||
super().__init__()
|
||||
@@ -279,6 +316,7 @@ def main():
|
||||
model = whisper.load_model(filename)
|
||||
else:
|
||||
model = whisper.load_model(name)
|
||||
print(model.dims)
|
||||
|
||||
print(
|
||||
f"number of model parameters: {name}",
|
||||
@@ -311,19 +349,20 @@ def main():
|
||||
assert mel.shape == (batch_size, 80, 30 * 100)
|
||||
|
||||
encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
|
||||
|
||||
n_layer_cross_k, n_layer_cross_v = encoder(mel)
|
||||
assert n_layer_cross_k.shape == (
|
||||
model.dims.n_text_layer,
|
||||
batch_size,
|
||||
model.dims.n_audio_ctx,
|
||||
model.dims.n_text_state,
|
||||
), n_layer_cross_k.shape
|
||||
), (n_layer_cross_k.shape, model.dims)
|
||||
assert n_layer_cross_v.shape == (
|
||||
model.dims.n_text_layer,
|
||||
batch_size,
|
||||
model.dims.n_audio_ctx,
|
||||
model.dims.n_text_state,
|
||||
), n_layer_cross_v.shape
|
||||
), (n_layer_cross_v.shape, model.dims)
|
||||
|
||||
encoder_filename = f"{name}-encoder.onnx"
|
||||
torch.onnx.export(
|
||||
@@ -334,9 +373,9 @@ def main():
|
||||
input_names=["mel"],
|
||||
output_names=["n_layer_cross_k", "n_layer_cross_v"],
|
||||
dynamic_axes={
|
||||
"mel": {0: "n_audio"}, # n_audio is also known as batch_size
|
||||
"n_layer_cross_k": {1: "n_audio"},
|
||||
"n_layer_cross_v": {1: "n_audio"},
|
||||
"mel": {0: "n_audio", 2: "T"}, # n_audio is also known as batch_size
|
||||
"n_layer_cross_k": {1: "n_audio", 2: "T"},
|
||||
"n_layer_cross_v": {1: "n_audio", 2: "T"},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -461,8 +500,8 @@ def main():
|
||||
"tokens": {0: "n_audio", 1: "n_tokens"},
|
||||
"in_n_layer_self_k_cache": {1: "n_audio"},
|
||||
"in_n_layer_self_v_cache": {1: "n_audio"},
|
||||
"n_layer_cross_k": {1: "n_audio"},
|
||||
"n_layer_cross_v": {1: "n_audio"},
|
||||
"n_layer_cross_k": {1: "n_audio", 2: "T"},
|
||||
"n_layer_cross_v": {1: "n_audio", 2: "T"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -253,8 +253,21 @@ def compute_features(filename: str) -> torch.Tensor:
|
||||
log_spec = torch.clamp(features, min=1e-10).log10()
|
||||
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||
mel = (log_spec + 4.0) / 4.0
|
||||
# mel (T, 80)
|
||||
|
||||
# We pad 50 frames at the end so that it is able to detect eot
|
||||
# You can use another value instead of 50.
|
||||
mel = torch.nn.functional.pad(mel, (0, 0, 0, 50), "constant", 0)
|
||||
# Note that if it throws for a multilingual model,
|
||||
# please use a larger value, say 300
|
||||
|
||||
target = 3000
|
||||
mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
|
||||
if mel.shape[0] > target:
|
||||
mel = mel[:target]
|
||||
|
||||
# We don't need to pad it to 30 seconds now!
|
||||
# mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
|
||||
|
||||
mel = mel.t().unsqueeze(0)
|
||||
|
||||
return mel
|
||||
|
||||
Reference in New Issue
Block a user