Support EAGLE draft extend CUDA graph (#6606)
Co-authored-by: Sehoon Kim <sehoonkim@berkeley.edu>
This commit is contained in:
@@ -26,6 +26,9 @@ from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
|
||||
EAGLEDraftCudaGraphRunner,
|
||||
)
|
||||
from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
|
||||
EAGLEDraftExtendCudaGraphRunner,
|
||||
)
|
||||
from sglang.srt.speculative.eagle_utils import (
|
||||
EagleDraftInput,
|
||||
EagleVerifyInput,
|
||||
@@ -189,6 +192,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.has_prefill_wrapper_verify = False
|
||||
elif self.server_args.attention_backend == "fa3":
|
||||
from sglang.srt.layers.attention.flashattention_backend import (
|
||||
FlashAttentionBackend,
|
||||
FlashAttentionMultiStepBackend,
|
||||
)
|
||||
|
||||
@@ -197,7 +201,10 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
self.draft_extend_attn_backend = None
|
||||
self.draft_extend_attn_backend = FlashAttentionBackend(
|
||||
self.draft_model_runner,
|
||||
skip_prefill=False,
|
||||
)
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.has_prefill_wrapper_verify = False
|
||||
elif self.server_args.attention_backend == "flashmla":
|
||||
@@ -242,7 +249,18 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
# Capture extend
|
||||
if self.draft_extend_attn_backend:
|
||||
raise NotImplementedError()
|
||||
tic = time.perf_counter()
|
||||
before_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
logger.info(
|
||||
f"Capture draft extend cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
|
||||
)
|
||||
self.cuda_graph_runner_for_draft_extend = EAGLEDraftExtendCudaGraphRunner(
|
||||
self
|
||||
)
|
||||
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
logger.info(
|
||||
f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
|
||||
)
|
||||
|
||||
@property
|
||||
def draft_model_runner(self):
|
||||
@@ -656,6 +674,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.spec_info.prepare_extend_after_decode(
|
||||
batch,
|
||||
self.speculative_num_steps,
|
||||
pad_input=self.cuda_graph_runner_for_draft_extend is not None,
|
||||
)
|
||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
batch.return_logprob = False
|
||||
@@ -665,7 +684,19 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
|
||||
# Run
|
||||
logits_output, _ = self.draft_model_runner.forward(forward_batch)
|
||||
can_cuda_graph = (
|
||||
self.cuda_graph_runner_for_draft_extend
|
||||
and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
|
||||
)
|
||||
if can_cuda_graph:
|
||||
logits_output = self.cuda_graph_runner_for_draft_extend.replay(
|
||||
forward_batch
|
||||
)
|
||||
else:
|
||||
self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
|
||||
logits_output = self.draft_model_runner.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
|
||||
self._detect_nan_if_needed(logits_output)
|
||||
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
||||
|
||||
Reference in New Issue
Block a user