Beta spec-overlap for EAGLE (#11398)
Co-authored-by: Lianmin Zheng <15100009+merrymercy@users.noreply.github.com> Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
This commit is contained in:
@@ -55,6 +55,25 @@ class AttentionBackend(ABC):
|
||||
"""Get the fill value for padded seq lens. Typically, it is 0 or 1."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_verify_buffers_to_fill_after_draft(self):
    """
    Return buffers of verify attention kernels that need to be filled after draft.

    Typically, these are the tree mask and position buffers. This base
    implementation exposes no buffers, so both slots are placeholders.
    """
    # Two slots by convention: [tree_mask_buffer, position_buffer].
    placeholders = [None for _ in range(2)]
    return placeholders
|
||||
|
||||
def update_verify_buffers_to_fill_after_draft(
    self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
):
    """
    Update the buffers returned by get_verify_buffers_to_fill_after_draft if needed.

    Here, we need to redo the computation of all metadata of the attention backend
    that depends on tree mask and position buffers.

    Args:
        spec_info: Speculative-decoding metadata produced by the draft pass.
        cuda_graph_bs: Batch size of the captured CUDA graph being replayed,
            or None when not running under a CUDA graph.

    Raises:
        NotImplementedError: Always; concrete backends must override this.
    """
    raise NotImplementedError()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
|
||||
@@ -29,7 +29,6 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.layers.radix_attention import AttentionType
|
||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
from sglang.srt.utils import (
|
||||
get_int_env_var,
|
||||
|
||||
@@ -162,6 +162,8 @@ class TritonAttnBackend(AttentionBackend):
|
||||
# Initialize forward metadata
|
||||
self.forward_metadata: ForwardMetadata = None
|
||||
|
||||
self.cuda_graph_custom_mask = None
|
||||
|
||||
def get_num_kv_splits(
|
||||
self,
|
||||
num_kv_splits: torch.Tensor,
|
||||
@@ -755,6 +757,19 @@ class TritonAttnBackend(AttentionBackend):
|
||||
def get_cuda_graph_seq_len_fill_value(self):
    """Return the value used to pad seq lens for CUDA-graph capture/replay."""
    # NOTE(review): this backend pads with 1 rather than 0 — presumably its
    # kernels require a nonzero length for padded slots; confirm against kernels.
    fill_value = 1
    return fill_value
|
||||
|
||||
def get_verify_buffers_to_fill_after_draft(self):
    """
    Return buffers for verify attention kernels that need to be filled after draft.

    Typically, these are the tree mask and position buffers. This backend only
    exposes its custom (tree) mask buffer; the position slot is a placeholder.
    """
    mask_buffer = self.cuda_graph_custom_mask
    return [mask_buffer, None]
|
||||
|
||||
def update_verify_buffers_to_fill_after_draft(
    self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
):
    """Intentional no-op: this backend needs no metadata refresh after draft."""
    return None
|
||||
|
||||
def forward_extend(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
|
||||
@@ -384,6 +384,7 @@ class LogitsProcessor(nn.Module):
|
||||
if (
|
||||
logits_metadata.forward_mode.is_decode_or_idle()
|
||||
or logits_metadata.forward_mode.is_target_verify()
|
||||
or logits_metadata.forward_mode.is_draft_extend_v2()
|
||||
):
|
||||
pruned_states = hidden_states
|
||||
if aux_hidden_states is not None:
|
||||
|
||||
Reference in New Issue
Block a user