Sampler cudagraph (#1253)

This commit is contained in:
Liangsheng Yin
2024-08-28 18:58:52 -07:00
committed by GitHub
parent 8153168c96
commit 381dd57bd6
29 changed files with 342 additions and 116 deletions

View File

@@ -200,16 +200,16 @@ def extend(reqs, model_runner):
tree_cache=None,
)
batch.prepare_for_extend(model_runner.model_config.vocab_size)
output = model_runner.forward(batch, ForwardMode.EXTEND)
next_token_ids = batch.sample(output.next_token_logits)
return next_token_ids, output.next_token_logits, batch
sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND)
next_token_ids = sample_output.batch_next_token_ids.tolist()
return next_token_ids, logits_output.next_token_logits, batch
def decode(input_token_ids, batch, model_runner):
batch.prepare_for_decode(input_token_ids.cpu().numpy())
output = model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids = batch.sample(output.next_token_logits)
return next_token_ids, output.next_token_logits
batch.prepare_for_decode(input_token_ids)
sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids = sample_output.batch_next_token_ids.tolist()
return next_token_ids, logits_output.next_token_logits
@torch.inference_mode()