Perormance: Enable cuda graph for dp idle batch (#7269)

Co-authored-by: austindeng <austindeng@tencent.com>
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
u4lr451
2025-06-24 08:34:13 +08:00
committed by GitHub
parent fa42e41962
commit ed0a0b692c
5 changed files with 51 additions and 50 deletions

View File

@@ -99,6 +99,8 @@ class EagleDraftInput:
topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
capture_hidden_mode=capture_hidden_mode,
accept_length=torch.empty((0,), device=device, dtype=torch.int32),
accept_length_cpu=[],
)
def prepare_extend_after_decode(

View File

@@ -322,13 +322,11 @@ class EAGLEWorker(TpModelWorker):
logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
self.verify(batch, spec_info)
)
need_forward, can_run_draft_extend_cuda_graph = (
self.check_forward_draft_extend_after_decode(batch)
)
if need_forward:
if self.check_forward_draft_extend_after_decode(batch):
with self.draft_tp_context(self.draft_model_runner.tp_group):
self.forward_draft_extend_after_decode(
batch, can_run_draft_extend_cuda_graph
batch,
)
return (
logits_output,
@@ -344,7 +342,7 @@ class EAGLEWorker(TpModelWorker):
and batch.spec_info.verified_id.shape[0] > 0
)
if not self.server_args.enable_dp_attention:
return local_need_forward, True
return local_need_forward
global_need_forward = torch.tensor(
[
@@ -357,10 +355,7 @@ class EAGLEWorker(TpModelWorker):
)
global_need_forward_cnt = global_need_forward[0].item()
need_forward = global_need_forward_cnt > 0
can_run_draft_extend_cuda_graph = (
global_need_forward_cnt == get_tensor_model_parallel_world_size()
)
return need_forward, can_run_draft_extend_cuda_graph
return need_forward
def forward_target_extend(
self, batch: ScheduleBatch
@@ -816,15 +811,12 @@ class EAGLEWorker(TpModelWorker):
assert forward_batch.spec_info is batch.spec_info
self.capture_for_decode(logits_output, forward_batch.spec_info)
def forward_draft_extend_after_decode(
self, batch: ScheduleBatch, can_run_draft_extend_cuda_graph: bool
):
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
# Backup fields that will be modified in-place
seq_lens_backup = batch.seq_lens.clone()
req_pool_indices_backup = batch.req_pool_indices
accept_length_backup = batch.spec_info.accept_length
return_logprob_backup = batch.return_logprob
input_is_idle = batch.forward_mode.is_idle()
if not input_is_idle:
# Prepare metadata
@@ -836,14 +828,18 @@ class EAGLEWorker(TpModelWorker):
else:
batch = batch.copy()
batch.prepare_for_idle()
hidden_size = (
self.model_config.hidden_size * 3
if self.speculative_algorithm.is_eagle3()
else self.model_config.hidden_size
)
batch.spec_info = EagleDraftInput.create_idle_input(
device=self.device,
hidden_size=self.model_config.hidden_size,
hidden_size=hidden_size,
dtype=self.model_config.dtype,
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
batch.return_hidden_states = False
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
@@ -858,8 +854,7 @@ class EAGLEWorker(TpModelWorker):
# Run
can_cuda_graph = (
can_run_draft_extend_cuda_graph
and self.cuda_graph_runner_for_draft_extend
self.cuda_graph_runner_for_draft_extend
and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
)
if can_cuda_graph: