From 6d0364681c8b1abc132cc88f1bb0b7a8a352628f Mon Sep 17 00:00:00 2001
From: Shangming Cai
Date: Wed, 15 Oct 2025 19:11:33 +0800
Subject: [PATCH] Fix 1-step draft model forward (#11653)

Signed-off-by: Shangming Cai
Co-authored-by: Liangsheng Yin
---
 python/sglang/srt/speculative/draft_utils.py  | 10 +------
 .../sglang/srt/speculative/eagle_info_v2.py   |  2 +-
 python/sglang/srt/speculative/eagle_worker.py | 28 +++++++++++--------
 .../sglang/srt/speculative/eagle_worker_v2.py |  5 +++-
 4 files changed, 22 insertions(+), 23 deletions(-)

diff --git a/python/sglang/srt/speculative/draft_utils.py b/python/sglang/srt/speculative/draft_utils.py
index aab54cc62..4f3a4617e 100644
--- a/python/sglang/srt/speculative/draft_utils.py
+++ b/python/sglang/srt/speculative/draft_utils.py
@@ -33,15 +33,7 @@ class DraftBackendFactory:
 
     def create_decode_backend(self):
         if self.speculative_num_steps == 1:
-
-            class DummyAttnBackend:
-                def __init__(self):
-                    pass
-
-                def init_forward_metadata(*args, **kwargs):
-                    pass
-
-            return DummyAttnBackend()
+            return None
 
         backend_map = {
             "flashinfer": self._create_flashinfer_decode_backend,
diff --git a/python/sglang/srt/speculative/eagle_info_v2.py b/python/sglang/srt/speculative/eagle_info_v2.py
index b068abd4e..bf450934e 100644
--- a/python/sglang/srt/speculative/eagle_info_v2.py
+++ b/python/sglang/srt/speculative/eagle_info_v2.py
@@ -276,7 +276,7 @@ class EagleVerifyInputV2Mixin:
             accept_length=accept_length,  # mutable
             simulate_acc_len=SIMULATE_ACC_LEN,
             bs=bs,
-            spec_steps=self.draft_token_num,
+            spec_steps=self.spec_steps,
         )
 
         # Include the bonus token
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index d152bf8fd..736f65dad 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -218,16 +218,17 @@ class EAGLEWorker(TpModelWorker):
             return
 
         # Capture draft
-        tic = time.perf_counter()
-        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
-        logger.info(
-            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
-        )
-        self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
-        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
-        logger.info(
-            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
-        )
+        if self.speculative_num_steps > 1:
+            tic = time.perf_counter()
+            before_mem = get_available_gpu_memory(self.device, self.gpu_id)
+            logger.info(
+                f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
+            )
+            self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
+            after_mem = get_available_gpu_memory(self.device, self.gpu_id)
+            logger.info(
+                f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
+            )
 
         # Capture extend
         if self.draft_extend_attn_backend:
@@ -500,8 +501,11 @@ class EAGLEWorker(TpModelWorker):
             )
         else:
             forward_batch.can_run_dp_cuda_graph = False
-            if not forward_batch.forward_mode.is_idle():
-                # Initialize attention backend
+            if (
+                not forward_batch.forward_mode.is_idle()
+                and self.speculative_num_steps > 1
+            ):
+                # Skip attention backend init for idle mode or 1-step draft
                 self.draft_attn_backend.init_forward_metadata(forward_batch)
             # Run forward steps
             parent_list, top_scores_index, draft_tokens = self.draft_forward(
diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py
index 3ab0784d6..1b67b0e96 100644
--- a/python/sglang/srt/speculative/eagle_worker_v2.py
+++ b/python/sglang/srt/speculative/eagle_worker_v2.py
@@ -97,7 +97,10 @@ class EAGLEWorkerV2(EAGLEWorker):
                 forward_batch,
             )
         else:
-            self.draft_attn_backend.init_forward_metadata(forward_batch)
+            if self.speculative_num_steps > 1:
+                # Skip attention backend init for 1-step draft,
+                # `draft_forward` only does sample in this case.
+                self.draft_attn_backend.init_forward_metadata(forward_batch)
             parent_list, top_scores_index, draft_tokens = self.draft_forward(
                 forward_batch
             )