Log if cuda graph is used & extend cuda graph capture to cuda-graph-max-bs (#6201)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
@@ -251,8 +251,8 @@ class EAGLEWorker(TpModelWorker):
|
||||
if batch.forward_mode.is_decode():
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
spec_info = self.draft(batch)
|
||||
logits_output, verify_output, model_worker_batch = self.verify(
|
||||
batch, spec_info
|
||||
logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
|
||||
self.verify(batch, spec_info)
|
||||
)
|
||||
|
||||
# If it is None, it means all requests are finished
|
||||
@@ -264,21 +264,22 @@ class EAGLEWorker(TpModelWorker):
|
||||
verify_output.verified_id,
|
||||
model_worker_batch.bid,
|
||||
sum(verify_output.accept_length_per_req_cpu),
|
||||
can_run_cuda_graph,
|
||||
)
|
||||
elif batch.forward_mode.is_idle():
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
|
||||
model_worker_batch
|
||||
logits_output, next_token_ids, _ = (
|
||||
self.target_worker.forward_batch_generation(model_worker_batch)
|
||||
)
|
||||
|
||||
return logits_output, next_token_ids, model_worker_batch.bid, 0
|
||||
return logits_output, next_token_ids, model_worker_batch.bid, 0, False
|
||||
else:
|
||||
logits_output, next_token_ids, bid = self.forward_target_extend(batch)
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
self.forward_draft_extend(
|
||||
batch, logits_output.hidden_states, next_token_ids
|
||||
)
|
||||
return logits_output, next_token_ids, bid, 0
|
||||
return logits_output, next_token_ids, bid, 0, False
|
||||
|
||||
def forward_target_extend(
|
||||
self, batch: ScheduleBatch
|
||||
@@ -297,7 +298,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
# We need the full hidden states to prefill the KV cache of the draft model.
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
|
||||
logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
|
||||
model_worker_batch
|
||||
)
|
||||
return logits_output, next_token_ids, model_worker_batch.bid
|
||||
@@ -478,8 +479,10 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
||||
batch.spec_info = spec_info
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
logits_output, _ = self.target_worker.forward_batch_generation(
|
||||
model_worker_batch, skip_sample=True
|
||||
logits_output, _, can_run_cuda_graph = (
|
||||
self.target_worker.forward_batch_generation(
|
||||
model_worker_batch, skip_sample=True
|
||||
)
|
||||
)
|
||||
self._detect_nan_if_needed(logits_output)
|
||||
spec_info.hidden_states = logits_output.hidden_states
|
||||
@@ -504,7 +507,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
if batch.return_logprob:
|
||||
self.add_logprob_values(batch, res, logits_output)
|
||||
|
||||
return logits_output, res, model_worker_batch
|
||||
return logits_output, res, model_worker_batch, can_run_cuda_graph
|
||||
|
||||
def add_logprob_values(
|
||||
self,
|
||||
@@ -590,7 +593,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
forward_batch.return_logprob = False
|
||||
logits_output = self.draft_model_runner.forward(forward_batch)
|
||||
logits_output, _ = self.draft_model_runner.forward(forward_batch)
|
||||
self._detect_nan_if_needed(logits_output)
|
||||
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||
assert forward_batch.spec_info is batch.spec_info
|
||||
@@ -617,7 +620,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
|
||||
# Run
|
||||
logits_output = self.draft_model_runner.forward(forward_batch)
|
||||
logits_output, _ = self.draft_model_runner.forward(forward_batch)
|
||||
|
||||
self._detect_nan_if_needed(logits_output)
|
||||
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
||||
|
||||
Reference in New Issue
Block a user