diff --git a/scripts/nemo/canary/export_onnx_180m_flash.py b/scripts/nemo/canary/export_onnx_180m_flash.py index 7585c18d..2767a01e 100755 --- a/scripts/nemo/canary/export_onnx_180m_flash.py +++ b/scripts/nemo/canary/export_onnx_180m_flash.py @@ -197,12 +197,12 @@ def export_decoder(canary_model): decoder = DecoderWrapper(canary_model) decoder_input_ids = torch.tensor([[1, 0]], dtype=torch.int32) - decoder_mems_list_0 = torch.zeros(1, 1, 1024) - decoder_mems_list_1 = torch.zeros(1, 1, 1024) - decoder_mems_list_2 = torch.zeros(1, 1, 1024) - decoder_mems_list_3 = torch.zeros(1, 1, 1024) - decoder_mems_list_4 = torch.zeros(1, 1, 1024) - decoder_mems_list_5 = torch.zeros(1, 1, 1024) + decoder_mems_list_0 = torch.zeros(1, 10, 1024) + decoder_mems_list_1 = torch.zeros(1, 10, 1024) + decoder_mems_list_2 = torch.zeros(1, 10, 1024) + decoder_mems_list_3 = torch.zeros(1, 10, 1024) + decoder_mems_list_4 = torch.zeros(1, 10, 1024) + decoder_mems_list_5 = torch.zeros(1, 10, 1024) enc_states = torch.zeros(1, 1000, 1024) enc_mask = torch.ones(1, 1000).bool() @@ -221,7 +221,9 @@ def export_decoder(canary_model): enc_mask, ), "decoder.onnx", - opset_version=14, + dynamo=True, + opset_version=18, + external_data=False, input_names=[ "decoder_input_ids", "decoder_mems_list_0", @@ -272,13 +274,11 @@ def main(): export_decoder(canary_model) for m in ["encoder", "decoder"]: - if m == "encoder": - # we don't quantize the decoder with int8 since the accuracy drops - quantize_dynamic( - model_input=f"./{m}.onnx", - model_output=f"./{m}.int8.onnx", - weight_type=QuantType.QUInt8, - ) + quantize_dynamic( + model_input=f"./{m}.onnx", + model_output=f"./{m}.int8.onnx", + weight_type=QuantType.QUInt8, + ) export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx") diff --git a/scripts/nemo/canary/test_180m_flash.py b/scripts/nemo/canary/test_180m_flash.py index cfa04250..690ed898 100755 --- a/scripts/nemo/canary/test_180m_flash.py +++ b/scripts/nemo/canary/test_180m_flash.py @@ -263,16 +263,15 @@ def main(): decoder_input_ids.append(token2id["<|notimestamp|>"]) decoder_input_ids.append(token2id["<|nodiarize|>"]) - decoder_input_ids.append(0) - decoder_mems_list = [np.zeros((1, 0, 1024), dtype=np.float32) for _ in range(6)] - logits, decoder_mems_list = model.run_decoder( - np.array([decoder_input_ids], dtype=np.int32), - decoder_mems_list, - enc_states, - enc_masks, - ) + for pos, decoder_input_id in enumerate(decoder_input_ids): + logits, decoder_mems_list = model.run_decoder( + np.array([[decoder_input_id,pos]], dtype=np.int32), + decoder_mems_list, + enc_states, + enc_masks, + ) tokens = [logits.argmax()] print("decoder_input_ids", decoder_input_ids) eos = token2id["<|endoftext|>"]