fix(canary): use dynamo export, single input_ids and avoid 0/1 specialization (#2348)

This commit is contained in:
lucaelin
2025-07-06 12:24:06 +02:00
committed by GitHub
parent d70b789582
commit 5ebb71909b
2 changed files with 21 additions and 22 deletions

View File

@@ -197,12 +197,12 @@ def export_decoder(canary_model):
decoder = DecoderWrapper(canary_model)
decoder_input_ids = torch.tensor([[1, 0]], dtype=torch.int32)
decoder_mems_list_0 = torch.zeros(1, 1, 1024)
decoder_mems_list_1 = torch.zeros(1, 1, 1024)
decoder_mems_list_2 = torch.zeros(1, 1, 1024)
decoder_mems_list_3 = torch.zeros(1, 1, 1024)
decoder_mems_list_4 = torch.zeros(1, 1, 1024)
decoder_mems_list_5 = torch.zeros(1, 1, 1024)
decoder_mems_list_0 = torch.zeros(1, 10, 1024)
decoder_mems_list_1 = torch.zeros(1, 10, 1024)
decoder_mems_list_2 = torch.zeros(1, 10, 1024)
decoder_mems_list_3 = torch.zeros(1, 10, 1024)
decoder_mems_list_4 = torch.zeros(1, 10, 1024)
decoder_mems_list_5 = torch.zeros(1, 10, 1024)
enc_states = torch.zeros(1, 1000, 1024)
enc_mask = torch.ones(1, 1000).bool()
@@ -221,7 +221,9 @@ def export_decoder(canary_model):
enc_mask,
),
"decoder.onnx",
opset_version=14,
dynamo=True,
opset_version=18,
external_data=False,
input_names=[
"decoder_input_ids",
"decoder_mems_list_0",
@@ -272,13 +274,11 @@ def main():
export_decoder(canary_model)
for m in ["encoder", "decoder"]:
if m == "encoder":
# we don't quantize the decoder with int8 since the accuracy drops
quantize_dynamic(
model_input=f"./{m}.onnx",
model_output=f"./{m}.int8.onnx",
weight_type=QuantType.QUInt8,
)
quantize_dynamic(
model_input=f"./{m}.onnx",
model_output=f"./{m}.int8.onnx",
weight_type=QuantType.QUInt8,
)
export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx")