fix(canary): use dynamo export, single input_ids and avoid 0/1 specialization (#2348)
This commit is contained in:
@@ -197,12 +197,12 @@ def export_decoder(canary_model):
|
|||||||
decoder = DecoderWrapper(canary_model)
|
decoder = DecoderWrapper(canary_model)
|
||||||
decoder_input_ids = torch.tensor([[1, 0]], dtype=torch.int32)
|
decoder_input_ids = torch.tensor([[1, 0]], dtype=torch.int32)
|
||||||
|
|
||||||
decoder_mems_list_0 = torch.zeros(1, 1, 1024)
|
decoder_mems_list_0 = torch.zeros(1, 10, 1024)
|
||||||
decoder_mems_list_1 = torch.zeros(1, 1, 1024)
|
decoder_mems_list_1 = torch.zeros(1, 10, 1024)
|
||||||
decoder_mems_list_2 = torch.zeros(1, 1, 1024)
|
decoder_mems_list_2 = torch.zeros(1, 10, 1024)
|
||||||
decoder_mems_list_3 = torch.zeros(1, 1, 1024)
|
decoder_mems_list_3 = torch.zeros(1, 10, 1024)
|
||||||
decoder_mems_list_4 = torch.zeros(1, 1, 1024)
|
decoder_mems_list_4 = torch.zeros(1, 10, 1024)
|
||||||
decoder_mems_list_5 = torch.zeros(1, 1, 1024)
|
decoder_mems_list_5 = torch.zeros(1, 10, 1024)
|
||||||
|
|
||||||
enc_states = torch.zeros(1, 1000, 1024)
|
enc_states = torch.zeros(1, 1000, 1024)
|
||||||
enc_mask = torch.ones(1, 1000).bool()
|
enc_mask = torch.ones(1, 1000).bool()
|
||||||
@@ -221,7 +221,9 @@ def export_decoder(canary_model):
|
|||||||
enc_mask,
|
enc_mask,
|
||||||
),
|
),
|
||||||
"decoder.onnx",
|
"decoder.onnx",
|
||||||
opset_version=14,
|
dynamo=True,
|
||||||
|
opset_version=18,
|
||||||
|
external_data=False,
|
||||||
input_names=[
|
input_names=[
|
||||||
"decoder_input_ids",
|
"decoder_input_ids",
|
||||||
"decoder_mems_list_0",
|
"decoder_mems_list_0",
|
||||||
@@ -272,13 +274,11 @@ def main():
|
|||||||
export_decoder(canary_model)
|
export_decoder(canary_model)
|
||||||
|
|
||||||
for m in ["encoder", "decoder"]:
|
for m in ["encoder", "decoder"]:
|
||||||
if m == "encoder":
|
quantize_dynamic(
|
||||||
# we don't quantize the decoder with int8 since the accuracy drops
|
model_input=f"./{m}.onnx",
|
||||||
quantize_dynamic(
|
model_output=f"./{m}.int8.onnx",
|
||||||
model_input=f"./{m}.onnx",
|
weight_type=QuantType.QUInt8,
|
||||||
model_output=f"./{m}.int8.onnx",
|
)
|
||||||
weight_type=QuantType.QUInt8,
|
|
||||||
)
|
|
||||||
|
|
||||||
export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx")
|
export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx")
|
||||||
|
|
||||||
|
|||||||
@@ -263,16 +263,15 @@ def main():
|
|||||||
decoder_input_ids.append(token2id["<|notimestamp|>"])
|
decoder_input_ids.append(token2id["<|notimestamp|>"])
|
||||||
decoder_input_ids.append(token2id["<|nodiarize|>"])
|
decoder_input_ids.append(token2id["<|nodiarize|>"])
|
||||||
|
|
||||||
decoder_input_ids.append(0)
|
|
||||||
|
|
||||||
decoder_mems_list = [np.zeros((1, 0, 1024), dtype=np.float32) for _ in range(6)]
|
decoder_mems_list = [np.zeros((1, 0, 1024), dtype=np.float32) for _ in range(6)]
|
||||||
|
|
||||||
logits, decoder_mems_list = model.run_decoder(
|
for pos, decoder_input_id in enumerate(decoder_input_ids):
|
||||||
np.array([decoder_input_ids], dtype=np.int32),
|
logits, decoder_mems_list = model.run_decoder(
|
||||||
decoder_mems_list,
|
np.array([[decoder_input_id,pos]], dtype=np.int32),
|
||||||
enc_states,
|
decoder_mems_list,
|
||||||
enc_masks,
|
enc_states,
|
||||||
)
|
enc_masks,
|
||||||
|
)
|
||||||
tokens = [logits.argmax()]
|
tokens = [logits.argmax()]
|
||||||
print("decoder_input_ids", decoder_input_ids)
|
print("decoder_input_ids", decoder_input_ids)
|
||||||
eos = token2id["<|endoftext|>"]
|
eos = token2id["<|endoftext|>"]
|
||||||
|
|||||||
Reference in New Issue
Block a user