[Eagle] Refactor eagle speculative decoding (#3986)
Co-authored-by: Ke Bao <ISPObaoke@163.com>
This commit is contained in:
@@ -20,14 +20,15 @@ import triton.language as tl
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.speculative.spec_info import SpecInfo
|
||||
|
||||
if is_flashinfer_available():
|
||||
from flashinfer import (
|
||||
@@ -36,6 +37,7 @@ if is_flashinfer_available():
|
||||
BatchPrefillWithRaggedKVCacheWrapper,
|
||||
)
|
||||
from flashinfer.cascade import merge_state
|
||||
from flashinfer.decode import PosEncodingMode
|
||||
|
||||
|
||||
class WrapperDispatch(Enum):
|
||||
@@ -113,6 +115,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
device=model_runner.device,
|
||||
)
|
||||
self.workspace_buffer = global_workspace_buffer
|
||||
|
||||
max_bs = model_runner.req_to_token_pool.size
|
||||
if kv_indptr_buf is None:
|
||||
self.kv_indptr = [
|
||||
@@ -133,10 +136,13 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
assert self.num_wrappers == 1
|
||||
self.kv_last_page_len = kv_last_page_len_buf
|
||||
|
||||
self.qo_indptr = [
|
||||
torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
|
||||
for _ in range(self.num_wrappers)
|
||||
]
|
||||
if not self.skip_prefill:
|
||||
self.qo_indptr = [
|
||||
torch.zeros(
|
||||
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||
)
|
||||
for _ in range(self.num_wrappers)
|
||||
]
|
||||
|
||||
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
|
||||
self.workspace_buffer, "NHD"
|
||||
@@ -276,7 +282,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
decode_wrappers = []
|
||||
@@ -346,7 +352,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
self.indices_updater_decode.update(
|
||||
@@ -526,7 +532,7 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
# Keep the signature for type checking. It will be assigned during runtime.
|
||||
raise NotImplementedError()
|
||||
@@ -538,7 +544,7 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
decode_wrappers = decode_wrappers or self.decode_wrappers
|
||||
self.call_begin_forward(
|
||||
@@ -558,7 +564,7 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
for wrapper_id in range(2):
|
||||
if wrapper_id == 0:
|
||||
@@ -592,7 +598,7 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
for wrapper_id in range(2):
|
||||
if wrapper_id == 0:
|
||||
@@ -623,7 +629,7 @@ class FlashInferIndicesUpdaterDecode:
|
||||
paged_kernel_lens_sum: int,
|
||||
kv_indptr: torch.Tensor,
|
||||
kv_start_idx: torch.Tensor,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
if spec_info is None:
|
||||
bs = len(req_pool_indices)
|
||||
@@ -642,9 +648,9 @@ class FlashInferIndicesUpdaterDecode:
|
||||
self.req_to_token.shape[1],
|
||||
)
|
||||
else:
|
||||
assert isinstance(spec_info, EagleDraftInput)
|
||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||
bs = kv_indptr.shape[0] - 1
|
||||
|
||||
wrapper.begin_forward(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
@@ -699,7 +705,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
# Keep the signature for type checking. It will be assigned during runtime.
|
||||
raise NotImplementedError()
|
||||
@@ -713,7 +719,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
if use_ragged:
|
||||
paged_kernel_lens = prefix_lens
|
||||
@@ -746,7 +752,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
for wrapper_id in range(2):
|
||||
if wrapper_id == 0:
|
||||
@@ -787,7 +793,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
for wrapper_id in range(2):
|
||||
if wrapper_id == 0:
|
||||
@@ -829,10 +835,11 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
kv_indptr: torch.Tensor,
|
||||
qo_indptr: torch.Tensor,
|
||||
use_ragged: bool,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
bs = len(req_pool_indices)
|
||||
bs = len(seq_lens)
|
||||
if spec_info is None:
|
||||
assert len(seq_lens) == len(req_pool_indices)
|
||||
# Normal extend
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
@@ -855,10 +862,14 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
qo_indptr = qo_indptr[: bs + 1]
|
||||
custom_mask = None
|
||||
else:
|
||||
assert isinstance(spec_info, EagleDraftInput) or isinstance(
|
||||
spec_info, EagleVerifyInput
|
||||
)
|
||||
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
||||
spec_info.generate_attn_arg_prefill(
|
||||
req_pool_indices,
|
||||
paged_kernel_lens,
|
||||
paged_kernel_lens_sum,
|
||||
self.req_to_token,
|
||||
)
|
||||
)
|
||||
@@ -890,6 +901,11 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
)
|
||||
|
||||
|
||||
# Use as a fast path to override the indptr in flashinfer's plan function
|
||||
# This is used to remove some host-to-device copy overhead.
|
||||
global global_override_indptr_cpu
|
||||
|
||||
|
||||
class FlashInferMultiStepDraftBackend:
|
||||
"""
|
||||
Wrap multiple flashinfer attention backends as one for multiple consecutive
|
||||
@@ -907,6 +923,7 @@ class FlashInferMultiStepDraftBackend:
|
||||
self.topk = topk
|
||||
self.speculative_num_steps = speculative_num_steps
|
||||
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
|
||||
|
||||
max_bs = model_runner.req_to_token_pool.size * self.topk
|
||||
self.kv_indptr = torch.zeros(
|
||||
(
|
||||
@@ -929,7 +946,9 @@ class FlashInferMultiStepDraftBackend:
|
||||
kv_last_page_len_buf=self.kv_last_page_len,
|
||||
)
|
||||
)
|
||||
|
||||
self.max_context_len = self.attn_backends[0].max_context_len
|
||||
|
||||
# Cached variables for generate_draft_decode_kv_indices
|
||||
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
|
||||
|
||||
@@ -959,13 +978,23 @@ class FlashInferMultiStepDraftBackend:
|
||||
triton.next_power_of_2(bs),
|
||||
)
|
||||
|
||||
assert forward_batch.spec_info is not None
|
||||
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||
|
||||
# Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
|
||||
indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
|
||||
global global_override_indptr_cpu
|
||||
|
||||
for i in range(self.speculative_num_steps - 1):
|
||||
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
|
||||
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
|
||||
: seq_lens_sum * self.topk + bs * (i + 1)
|
||||
]
|
||||
global_override_indptr_cpu = indptr_cpu_whole[i]
|
||||
call_fn(i, forward_batch)
|
||||
|
||||
global_override_indptr_cpu = None
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
kv_indices = torch.zeros(
|
||||
(
|
||||
@@ -977,6 +1006,8 @@ class FlashInferMultiStepDraftBackend:
|
||||
)
|
||||
|
||||
def call_fn(i, forward_batch):
|
||||
assert forward_batch.spec_info is not None
|
||||
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||
forward_batch.spec_info.kv_indptr = (
|
||||
forward_batch.spec_info.kv_indptr.clone()
|
||||
)
|
||||
@@ -993,6 +1024,7 @@ class FlashInferMultiStepDraftBackend:
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
for i in range(self.speculative_num_steps):
|
||||
self.attn_backends[i].init_cuda_graph_state(
|
||||
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
|
||||
@@ -1031,43 +1063,6 @@ class FlashInferMultiStepDraftBackend:
|
||||
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def create_flashinfer_kv_indices_triton(
|
||||
req_to_token_ptr, # [max_batch, max_context_len]
|
||||
req_pool_indices_ptr,
|
||||
page_kernel_lens_ptr,
|
||||
kv_indptr,
|
||||
kv_start_idx,
|
||||
kv_indices_ptr,
|
||||
req_to_token_ptr_stride: tl.constexpr,
|
||||
):
|
||||
BLOCK_SIZE: tl.constexpr = 512
|
||||
pid = tl.program_id(axis=0)
|
||||
|
||||
req_pool_index = tl.load(req_pool_indices_ptr + pid)
|
||||
kv_indices_offset = tl.load(kv_indptr + pid)
|
||||
|
||||
kv_start = 0
|
||||
kv_end = 0
|
||||
if kv_start_idx:
|
||||
kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
|
||||
kv_end = kv_start
|
||||
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
|
||||
|
||||
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
|
||||
for i in range(num_loop):
|
||||
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
||||
mask = offset < kv_end - kv_start
|
||||
data = tl.load(
|
||||
req_to_token_ptr
|
||||
+ req_pool_index * req_to_token_ptr_stride
|
||||
+ kv_start
|
||||
+ offset,
|
||||
mask=mask,
|
||||
)
|
||||
tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
|
||||
|
||||
|
||||
def should_use_tensor_core(
|
||||
kv_cache_dtype: torch.dtype,
|
||||
num_attention_heads: int,
|
||||
@@ -1089,6 +1084,21 @@ def should_use_tensor_core(
|
||||
if env_override is not None:
|
||||
return env_override.lower() == "true"
|
||||
|
||||
# Try to use _grouped_size_compiled_for_decode_kernels if available
|
||||
# This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
|
||||
try:
|
||||
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
|
||||
|
||||
if not _grouped_size_compiled_for_decode_kernels(
|
||||
num_attention_heads,
|
||||
num_kv_heads,
|
||||
):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except (ImportError, AttributeError):
|
||||
pass
|
||||
|
||||
# Calculate GQA group size
|
||||
gqa_group_size = num_attention_heads // num_kv_heads
|
||||
|
||||
@@ -1118,12 +1128,18 @@ def fast_decode_plan(
|
||||
sm_scale: Optional[float] = None,
|
||||
rope_scale: Optional[float] = None,
|
||||
rope_theta: Optional[float] = None,
|
||||
**kwargs,
|
||||
non_blocking: bool = True,
|
||||
) -> None:
|
||||
"""A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
|
||||
"""
|
||||
A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
|
||||
Modifications:
|
||||
- Remove unnecessary device-to-device copy for the cuda graph buffers.
|
||||
- Remove unnecessary host-to-device copy for the metadata buffers.
|
||||
"""
|
||||
batch_size = len(last_page_len)
|
||||
if logits_soft_cap is None:
|
||||
logits_soft_cap = 0.0
|
||||
|
||||
if self.is_cuda_graph_enabled:
|
||||
if batch_size != self._fixed_batch_size:
|
||||
raise ValueError(
|
||||
@@ -1136,13 +1152,19 @@ def fast_decode_plan(
|
||||
raise ValueError(
|
||||
"The size of indices should be less than or equal to the allocated buffer"
|
||||
)
|
||||
# Skip these copies
|
||||
# self._paged_kv_indptr_buf.copy_(indptr)
|
||||
# self._paged_kv_indices_buf[: len(indices)] = indices
|
||||
# self._paged_kv_last_page_len_buf.copy_(last_page_len)
|
||||
else:
|
||||
self._paged_kv_indptr_buf = indptr
|
||||
self._paged_kv_indices_buf = indices
|
||||
self._paged_kv_last_page_len_buf = last_page_len
|
||||
|
||||
# NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
|
||||
if not q_data_type:
|
||||
q_data_type = data_type
|
||||
|
||||
if not hasattr(self, "empty_q_data"):
|
||||
self.empty_q_data = torch.empty(
|
||||
0,
|
||||
@@ -1159,6 +1181,7 @@ def fast_decode_plan(
|
||||
),
|
||||
)
|
||||
self.last_page_len = torch.ones(32768, dtype=torch.int32)
|
||||
|
||||
empty_q_data = self.empty_q_data
|
||||
empty_kv_cache = self.empty_kv_cache
|
||||
stream = torch.cuda.current_stream()
|
||||
|
||||
Reference in New Issue
Block a user