Optimize conflicts between CUDA graph and vocab mask tensors (#1392)

This commit is contained in:
Liangsheng Yin
2024-09-13 20:27:53 -07:00
committed by GitHub
parent f3d32f888a
commit 70b6802982
32 changed files with 103 additions and 224 deletions

View File

@@ -207,15 +207,15 @@ def extend(reqs, model_runner):
tree_cache=None,
)
batch.prepare_for_extend(model_runner.model_config.vocab_size)
sample_output, logits_output = model_runner.forward(batch)
next_token_ids = sample_output.batch_next_token_ids.tolist()
logits_output = model_runner.forward(batch)
next_token_ids = model_runner.sample(logits_output, batch).tolist()
return next_token_ids, logits_output.next_token_logits, batch
def decode(input_token_ids, batch, model_runner):
    """Run a single decode step and return the sampled tokens and logits.

    Args:
        input_token_ids: Token ids passed to ``batch.prepare_for_decode``.
        batch: Batch object exposing ``prepare_for_decode``; it is also
            handed to the model runner for the forward pass and sampling.
        model_runner: Object exposing ``forward(batch)`` and
            ``sample(logits_output, batch)``.

    Returns:
        A ``(next_token_ids, next_token_logits)`` tuple: the ``.tolist()``
        of the sampler result, and the ``next_token_logits`` attribute of
        the forward output.
    """
    batch.prepare_for_decode(input_token_ids)
    # The forward pass runs exactly once and yields only the logits output;
    # sampling is a separate call. The stale pre-change lines that unpacked
    # ``forward`` as a ``(sample_output, logits_output)`` pair are dropped.
    logits_output = model_runner.forward(batch)
    next_token_ids = model_runner.sample(logits_output, batch).tolist()
    return next_token_ids, logits_output.next_token_logits