Log if cuda graph is used & extend cuda graph capture to cuda-graph-max-bs (#6201)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
@@ -246,7 +246,7 @@ def extend(reqs, model_runner):
|
||||
_maybe_prepare_dp_attn_batch(batch, model_runner)
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
||||
logits_output = model_runner.forward(forward_batch)
|
||||
logits_output, _ = model_runner.forward(forward_batch)
|
||||
next_token_ids = model_runner.sample(logits_output, forward_batch)
|
||||
return next_token_ids, logits_output.next_token_logits, batch
|
||||
|
||||
@@ -258,7 +258,7 @@ def decode(input_token_ids, batch, model_runner):
|
||||
_maybe_prepare_dp_attn_batch(batch, model_runner)
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
||||
logits_output = model_runner.forward(forward_batch)
|
||||
logits_output, _ = model_runner.forward(forward_batch)
|
||||
next_token_ids = model_runner.sample(logits_output, forward_batch)
|
||||
return next_token_ids, logits_output.next_token_logits
|
||||
|
||||
|
||||
Reference in New Issue
Block a user