Support EAGLE draft extend CUDA graph (#6606)

Co-authored-by: Sehoon Kim <sehoonkim@berkeley.edu>
This commit is contained in:
Ke Bao
2025-05-27 17:35:17 +08:00
committed by GitHub
parent a3d7f4b673
commit 631950280a
5 changed files with 406 additions and 5 deletions

View File

@@ -26,6 +26,9 @@ from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
EAGLEDraftCudaGraphRunner,
)
from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
EAGLEDraftExtendCudaGraphRunner,
)
from sglang.srt.speculative.eagle_utils import (
EagleDraftInput,
EagleVerifyInput,
@@ -189,6 +192,7 @@ class EAGLEWorker(TpModelWorker):
self.has_prefill_wrapper_verify = False
elif self.server_args.attention_backend == "fa3":
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
FlashAttentionMultiStepBackend,
)
@@ -197,7 +201,10 @@ class EAGLEWorker(TpModelWorker):
self.topk,
self.speculative_num_steps,
)
self.draft_extend_attn_backend = None
self.draft_extend_attn_backend = FlashAttentionBackend(
self.draft_model_runner,
skip_prefill=False,
)
self.padded_static_len = self.speculative_num_steps + 1
self.has_prefill_wrapper_verify = False
elif self.server_args.attention_backend == "flashmla":
@@ -242,7 +249,18 @@ class EAGLEWorker(TpModelWorker):
# Capture extend
if self.draft_extend_attn_backend:
raise NotImplementedError()
tic = time.perf_counter()
before_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture draft extend cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
)
self.cuda_graph_runner_for_draft_extend = EAGLEDraftExtendCudaGraphRunner(
self
)
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
)
@property
def draft_model_runner(self):
@@ -656,6 +674,7 @@ class EAGLEWorker(TpModelWorker):
batch.spec_info.prepare_extend_after_decode(
batch,
self.speculative_num_steps,
pad_input=self.cuda_graph_runner_for_draft_extend is not None,
)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
batch.return_logprob = False
@@ -665,7 +684,19 @@ class EAGLEWorker(TpModelWorker):
)
# Run
logits_output, _ = self.draft_model_runner.forward(forward_batch)
can_cuda_graph = (
self.cuda_graph_runner_for_draft_extend
and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
)
if can_cuda_graph:
logits_output = self.cuda_graph_runner_for_draft_extend.replay(
forward_batch
)
else:
self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
logits_output = self.draft_model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
self._detect_nan_if_needed(logits_output)
self.capture_for_decode(logits_output, forward_batch.spec_info)