fix(canary): use dynamo export, single input_ids and avoid 0/1 specialization (#2348)
This commit is contained in:
@@ -263,16 +263,15 @@ def main():
|
||||
decoder_input_ids.append(token2id["<|notimestamp|>"])
|
||||
decoder_input_ids.append(token2id["<|nodiarize|>"])
|
||||
|
||||
decoder_input_ids.append(0)
|
||||
|
||||
decoder_mems_list = [np.zeros((1, 0, 1024), dtype=np.float32) for _ in range(6)]
|
||||
|
||||
logits, decoder_mems_list = model.run_decoder(
|
||||
np.array([decoder_input_ids], dtype=np.int32),
|
||||
decoder_mems_list,
|
||||
enc_states,
|
||||
enc_masks,
|
||||
)
|
||||
for pos, decoder_input_id in enumerate(decoder_input_ids):
|
||||
logits, decoder_mems_list = model.run_decoder(
|
||||
np.array([[decoder_input_id,pos]], dtype=np.int32),
|
||||
decoder_mems_list,
|
||||
enc_states,
|
||||
enc_masks,
|
||||
)
|
||||
tokens = [logits.argmax()]
|
||||
print("decoder_input_ids", decoder_input_ids)
|
||||
eos = token2id["<|endoftext|>"]
|
||||
|
||||
Reference in New Issue
Block a user