Beta spec-overlap for EAGLE (#11398)

Co-authored-by: Lianmin Zheng <15100009+merrymercy@users.noreply.github.com>
Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
This commit is contained in:
Liangsheng Yin
2025-10-12 11:02:22 +08:00
committed by GitHub
parent 47c606d3dc
commit 20a6c0a63d
21 changed files with 1567 additions and 108 deletions

View File

@@ -55,6 +55,25 @@ class AttentionBackend(ABC):
"""Get the fill value for padded seq lens. Typically, it is 0 or 1."""
raise NotImplementedError()
def get_verify_buffers_to_fill_after_draft(self):
    """
    Return the verify-attention-kernel buffers that must be (re)filled after
    the draft pass — typically the tree-mask buffer and the position buffer.

    Base implementation: this backend exposes no such buffers, so both slots
    are ``None``.
    """
    # Slot 0: tree mask buffer, slot 1: position buffer.
    return [None, None]
def update_verify_buffers_to_fill_after_draft(
    self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
):
    """
    Update the buffers returned by get_verify_buffers_to_fill_after_draft
    if needed.

    Here, we need to redo the computation of all metadata of the attention
    backend that depends on tree mask and position buffers.

    Args:
        spec_info: Speculative-decoding input carrying the draft results.
        cuda_graph_bs: CUDA-graph batch size, or None when not replaying a
            captured graph.

    Raises:
        NotImplementedError: always, in this base class; backends whose
            verify buffers need refreshing after draft must override this.
    """
    raise NotImplementedError()
def forward(
self,
q: torch.Tensor,

View File

@@ -29,7 +29,6 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import (
get_int_env_var,

View File

@@ -162,6 +162,8 @@ class TritonAttnBackend(AttentionBackend):
# Initialize forward metadata
self.forward_metadata: ForwardMetadata = None
self.cuda_graph_custom_mask = None
def get_num_kv_splits(
self,
num_kv_splits: torch.Tensor,
@@ -755,6 +757,19 @@ class TritonAttnBackend(AttentionBackend):
def get_cuda_graph_seq_len_fill_value(self):
    """Padding value used for seq lens during CUDA graph capture (1 here)."""
    return 1
def get_verify_buffers_to_fill_after_draft(self):
    """
    Return buffers for verify attention kernels that need to be filled after
    draft — typically the tree mask and position buffers.

    This backend exposes only the CUDA-graph custom-mask buffer; the position
    slot is unused and left as ``None``.
    """
    # [tree mask buffer, position buffer]
    return [self.cuda_graph_custom_mask, None]
def update_verify_buffers_to_fill_after_draft(
    self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
):
    """
    No-op override of the base-class hook.

    NOTE(review): presumably this backend's verify metadata does not depend
    on the tree-mask/position buffers beyond the buffer returned by
    get_verify_buffers_to_fill_after_draft being filled in place — confirm
    against the caller that no recomputation is required here.
    """
    pass
def forward_extend(
self,
q: torch.Tensor,

View File

@@ -384,6 +384,7 @@ class LogitsProcessor(nn.Module):
if (
logits_metadata.forward_mode.is_decode_or_idle()
or logits_metadata.forward_mode.is_target_verify()
or logits_metadata.forward_mode.is_draft_extend_v2()
):
pruned_states = hidden_states
if aux_hidden_states is not None: