Reduce whisper decoder file size with onnx export (#328)
This commit is contained in:
@@ -200,10 +200,25 @@ class TextDecoderTensorCache(nn.Module):
|
||||
|
||||
x = self.textDecoder.ln(x)
|
||||
|
||||
logits = (
|
||||
x
|
||||
@ torch.transpose(self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1)
|
||||
).float()
|
||||
if False:
|
||||
# x.shape (1, 3, 384)
|
||||
# weight.shape (51684, 384)
|
||||
|
||||
logits = (
|
||||
x
|
||||
@ torch.transpose(
|
||||
self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1
|
||||
)
|
||||
).float()
|
||||
else:
|
||||
logits = (
|
||||
torch.matmul(
|
||||
self.textDecoder.token_embedding.weight.to(x.dtype),
|
||||
x.permute(0, 2, 1),
|
||||
)
|
||||
.permute(0, 2, 1)
|
||||
.float()
|
||||
)
|
||||
|
||||
return logits, n_layer_self_k_cache, n_layer_self_v_cache
|
||||
|
||||
@@ -246,6 +261,19 @@ def main():
|
||||
opset_version = 13
|
||||
|
||||
model = whisper.load_model(name)
|
||||
print(
|
||||
f"number of model parameters: {name}",
|
||||
sum(p.numel() for p in model.parameters()),
|
||||
)
|
||||
print(
|
||||
f"number of encoder parameters: {name}",
|
||||
sum(p.numel() for p in model.encoder.parameters()),
|
||||
)
|
||||
print(
|
||||
f"number of decoder parameters: {name}",
|
||||
sum(p.numel() for p in model.decoder.parameters()),
|
||||
)
|
||||
|
||||
convert_tokens(name=name, model=model)
|
||||
|
||||
# write tokens
|
||||
@@ -419,7 +447,7 @@ def main():
|
||||
},
|
||||
)
|
||||
|
||||
if 'large' in args.model:
|
||||
if "large" in args.model:
|
||||
# it causes errors for large models, so skip it.
|
||||
return
|
||||
# Generate int8 quantization models
|
||||
|
||||
Reference in New Issue
Block a user