Log whether CUDA graph is used & extend CUDA graph capture to cuda-graph-max-bs (#6201)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
Lianmin Zheng
2025-05-12 00:17:33 -07:00
committed by GitHub
parent 7d3a3d4510
commit fba8eccd7e
27 changed files with 293 additions and 121 deletions

View File

@@ -251,8 +251,8 @@ class EAGLEWorker(TpModelWorker):
if batch.forward_mode.is_decode():
with self.draft_tp_context(self.draft_model_runner.tp_group):
spec_info = self.draft(batch)
logits_output, verify_output, model_worker_batch = self.verify(
batch, spec_info
logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
self.verify(batch, spec_info)
)
# If it is None, it means all requests are finished
@@ -264,21 +264,22 @@ class EAGLEWorker(TpModelWorker):
verify_output.verified_id,
model_worker_batch.bid,
sum(verify_output.accept_length_per_req_cpu),
can_run_cuda_graph,
)
elif batch.forward_mode.is_idle():
model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
model_worker_batch
logits_output, next_token_ids, _ = (
self.target_worker.forward_batch_generation(model_worker_batch)
)
return logits_output, next_token_ids, model_worker_batch.bid, 0
return logits_output, next_token_ids, model_worker_batch.bid, 0, False
else:
logits_output, next_token_ids, bid = self.forward_target_extend(batch)
with self.draft_tp_context(self.draft_model_runner.tp_group):
self.forward_draft_extend(
batch, logits_output.hidden_states, next_token_ids
)
return logits_output, next_token_ids, bid, 0
return logits_output, next_token_ids, bid, 0, False
def forward_target_extend(
self, batch: ScheduleBatch
@@ -297,7 +298,7 @@ class EAGLEWorker(TpModelWorker):
# We need the full hidden states to prefill the KV cache of the draft model.
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
model_worker_batch
)
return logits_output, next_token_ids, model_worker_batch.bid
@@ -478,8 +479,10 @@ class EAGLEWorker(TpModelWorker):
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.spec_info = spec_info
model_worker_batch = batch.get_model_worker_batch()
logits_output, _ = self.target_worker.forward_batch_generation(
model_worker_batch, skip_sample=True
logits_output, _, can_run_cuda_graph = (
self.target_worker.forward_batch_generation(
model_worker_batch, skip_sample=True
)
)
self._detect_nan_if_needed(logits_output)
spec_info.hidden_states = logits_output.hidden_states
@@ -504,7 +507,7 @@ class EAGLEWorker(TpModelWorker):
if batch.return_logprob:
self.add_logprob_values(batch, res, logits_output)
return logits_output, res, model_worker_batch
return logits_output, res, model_worker_batch, can_run_cuda_graph
def add_logprob_values(
self,
@@ -590,7 +593,7 @@ class EAGLEWorker(TpModelWorker):
model_worker_batch, self.draft_model_runner
)
forward_batch.return_logprob = False
logits_output = self.draft_model_runner.forward(forward_batch)
logits_output, _ = self.draft_model_runner.forward(forward_batch)
self._detect_nan_if_needed(logits_output)
assert isinstance(forward_batch.spec_info, EagleDraftInput)
assert forward_batch.spec_info is batch.spec_info
@@ -617,7 +620,7 @@ class EAGLEWorker(TpModelWorker):
)
# Run
logits_output = self.draft_model_runner.forward(forward_batch)
logits_output, _ = self.draft_model_runner.forward(forward_batch)
self._detect_nan_if_needed(logits_output)
self.capture_for_decode(logits_output, forward_batch.spec_info)