From 6c903611cae2b2edfb6da4a4b2321f2d1284b265 Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Sat, 5 Jul 2025 02:18:16 -0700 Subject: [PATCH] Fix incorrect spec_num_draft_tokens in draft_extend (#7757) --- python/sglang/srt/layers/dp_attention.py | 8 ++++++++ python/sglang/srt/speculative/eagle_worker.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index 1e5038436..ae4041956 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -237,6 +237,10 @@ def _dp_gather( assert ( local_tokens.untyped_storage() is not global_tokens.untyped_storage() ), "aliasing between global_tokens and local_tokens not allowed" + + # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1). + # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the + # actual size of the accepted tokens. if forward_batch.forward_mode.is_draft_extend(): shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0]) local_num_tokens = torch.minimum(local_num_tokens, shape_tensor) @@ -291,6 +295,10 @@ def dp_scatter( assert ( local_tokens.untyped_storage() is not global_tokens.untyped_storage() ), "aliasing between local_tokens and global_tokens not allowed" + + # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1). + # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the + # actual size of the accepted tokens. if forward_batch.forward_mode.is_draft_extend(): shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0]) local_num_tokens = torch.minimum(local_num_tokens, shape_tensor) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index e78da174f..b6a6dace6 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -844,7 +844,7 @@ class EAGLEWorker(TpModelWorker): ) batch.return_hidden_states = False model_worker_batch = batch.get_model_worker_batch() - model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens + model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1 assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner