Reduce whisper decoder file size with onnx export (#328)
This commit is contained in:
@@ -200,10 +200,25 @@ class TextDecoderTensorCache(nn.Module):
|
|||||||
|
|
||||||
x = self.textDecoder.ln(x)
|
x = self.textDecoder.ln(x)
|
||||||
|
|
||||||
logits = (
|
if False:
|
||||||
x
|
# x.shape (1, 3, 384)
|
||||||
@ torch.transpose(self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1)
|
# weight.shape (51864, 384)
|
||||||
).float()
|
|
||||||
|
logits = (
|
||||||
|
x
|
||||||
|
@ torch.transpose(
|
||||||
|
self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1
|
||||||
|
)
|
||||||
|
).float()
|
||||||
|
else:
|
||||||
|
logits = (
|
||||||
|
torch.matmul(
|
||||||
|
self.textDecoder.token_embedding.weight.to(x.dtype),
|
||||||
|
x.permute(0, 2, 1),
|
||||||
|
)
|
||||||
|
.permute(0, 2, 1)
|
||||||
|
.float()
|
||||||
|
)
|
||||||
|
|
||||||
return logits, n_layer_self_k_cache, n_layer_self_v_cache
|
return logits, n_layer_self_k_cache, n_layer_self_v_cache
|
||||||
|
|
||||||
@@ -246,6 +261,19 @@ def main():
|
|||||||
opset_version = 13
|
opset_version = 13
|
||||||
|
|
||||||
model = whisper.load_model(name)
|
model = whisper.load_model(name)
|
||||||
|
print(
|
||||||
|
f"number of model parameters: {name}",
|
||||||
|
sum(p.numel() for p in model.parameters()),
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"number of encoder parameters: {name}",
|
||||||
|
sum(p.numel() for p in model.encoder.parameters()),
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"number of decoder parameters: {name}",
|
||||||
|
sum(p.numel() for p in model.decoder.parameters()),
|
||||||
|
)
|
||||||
|
|
||||||
convert_tokens(name=name, model=model)
|
convert_tokens(name=name, model=model)
|
||||||
|
|
||||||
# write tokens
|
# write tokens
|
||||||
@@ -419,7 +447,7 @@ def main():
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if 'large' in args.model:
|
if "large" in args.model:
|
||||||
# it causes errors for large models, so skip it.
|
# it causes errors for large models, so skip it.
|
||||||
return
|
return
|
||||||
# Generate int8 quantization models
|
# Generate int8 quantization models
|
||||||
|
|||||||
Reference in New Issue
Block a user