fix(canary): use dynamo export, single input_ids and avoid 0/1 specialization (#2348)

This commit is contained in:
lucaelin
2025-07-06 12:24:06 +02:00
committed by GitHub
parent d70b789582
commit 5ebb71909b
2 changed files with 21 additions and 22 deletions

View File

@@ -263,16 +263,15 @@ def main():
decoder_input_ids.append(token2id["<|notimestamp|>"])
decoder_input_ids.append(token2id["<|nodiarize|>"])
decoder_input_ids.append(0)
decoder_mems_list = [np.zeros((1, 0, 1024), dtype=np.float32) for _ in range(6)]
logits, decoder_mems_list = model.run_decoder(
np.array([decoder_input_ids], dtype=np.int32),
decoder_mems_list,
enc_states,
enc_masks,
)
for pos, decoder_input_id in enumerate(decoder_input_ids):
logits, decoder_mems_list = model.run_decoder(
np.array([[decoder_input_id,pos]], dtype=np.int32),
decoder_mems_list,
enc_states,
enc_masks,
)
tokens = [logits.argmax()]
print("decoder_input_ids", decoder_input_ids)
eos = token2id["<|endoftext|>"]