refactor EAGLE 2 (#3269)

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: merrymercy <lianminzheng@gmail.com>
Co-authored-by: Ying1123 <sqy1415@gmail.com>
This commit is contained in:
Yineng Zhang
2025-02-03 20:52:30 +08:00
committed by GitHub
parent 3c8ac78dc1
commit 013021b6a1
9 changed files with 1271 additions and 687 deletions

View File

@@ -10,6 +10,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
import os
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Union
import torch
@@ -34,6 +35,7 @@ if is_flashinfer_available():
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
from flashinfer.decode import PosEncodingMode
class WrapperDispatch(Enum):
@@ -53,10 +55,19 @@ class PrefillMetadata:
extend_no_prefix: bool
# Reuse this workspace buffer across all flashinfer wrappers
global_workspace_buffer = None
class FlashInferAttnBackend(AttentionBackend):
"""Flashinfer attention kernels."""
def __init__(self, model_runner: ModelRunner):
def __init__(
self,
model_runner: ModelRunner,
skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None,
):
super().__init__()
# Parse constants
@@ -69,6 +80,7 @@ class FlashInferAttnBackend(AttentionBackend):
),
)
self.max_context_len = model_runner.model_config.context_len
self.skip_prefill = skip_prefill
assert not (
model_runner.sliding_window_size is not None
@@ -90,16 +102,26 @@ class FlashInferAttnBackend(AttentionBackend):
global_config.flashinfer_workspace_size = 512 * 1024 * 1024
# Allocate buffers
self.workspace_buffer = torch.empty(
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
device=model_runner.device,
)
global global_workspace_buffer
if global_workspace_buffer is None:
global_workspace_buffer = torch.empty(
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
device=model_runner.device,
)
self.workspace_buffer = global_workspace_buffer
max_bs = model_runner.req_to_token_pool.size
self.kv_indptr = [
torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
for _ in range(self.num_wrappers)
]
if kv_indptr_buf is None:
self.kv_indptr = [
torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
for _ in range(self.num_wrappers)
]
else:
assert self.num_wrappers == 1
self.kv_indptr = [kv_indptr_buf]
self.kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device=model_runner.device
)
@@ -122,12 +144,16 @@ class FlashInferAttnBackend(AttentionBackend):
self.prefill_wrappers_verify = []
self.decode_wrappers = []
for _ in range(self.num_wrappers):
self.prefill_wrappers_paged.append(
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
)
self.prefill_wrappers_verify.append(
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
)
if not skip_prefill:
self.prefill_wrappers_paged.append(
BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
)
)
self.prefill_wrappers_verify.append(
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
)
self.decode_wrappers.append(
BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer,
@@ -137,10 +163,11 @@ class FlashInferAttnBackend(AttentionBackend):
)
# Create indices updater
if not skip_prefill:
self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
model_runner, self
)
self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
model_runner, self
)
# Other metadata
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
@@ -211,23 +238,30 @@ class FlashInferAttnBackend(AttentionBackend):
self.prefill_wrappers_paged, use_ragged, extend_no_prefix
)
def init_cuda_graph_state(self, max_bs: int):
cuda_graph_kv_indices = torch.zeros(
(max_bs * self.max_context_len,),
dtype=torch.int32,
device="cuda",
)
def init_cuda_graph_state(
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
):
if kv_indices_buf is None:
cuda_graph_kv_indices = torch.zeros(
(max_bs * self.max_context_len,),
dtype=torch.int32,
device="cuda",
)
else:
cuda_graph_kv_indices = kv_indices_buf
self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [
cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
]
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.uint8,
device="cuda",
)
self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
if not self.skip_prefill:
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.uint8,
device="cuda",
)
self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
def init_forward_metadata_capture_cuda_graph(
self,
@@ -602,11 +636,8 @@ class FlashInferIndicesUpdaterDecode:
self.req_to_token.shape[1],
)
else:
bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode(
req_pool_indices,
paged_kernel_lens,
self.req_to_token,
)
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
bs = kv_indptr.shape[0] - 1
wrapper.end_forward()
wrapper.begin_forward(
@@ -854,6 +885,132 @@ class FlashInferIndicesUpdaterPrefill:
)
class FlashInferMultiStepDraftBackend:
"""
Wrap multiple flashinfer attention backends as one for multiple consecutive
draft decoding steps.
"""
def __init__(
self,
model_runner: ModelRunner,
topk: int,
speculative_num_steps: int,
):
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
self.topk = topk
self.speculative_num_steps = speculative_num_steps
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
max_bs = model_runner.req_to_token_pool.size
self.kv_indptr = torch.zeros(
(
self.speculative_num_steps,
max_bs + 1,
),
dtype=torch.int32,
device=model_runner.device,
)
self.attn_backends = []
for i in range(self.speculative_num_steps):
self.attn_backends.append(
FlashInferAttnBackend(
model_runner,
skip_prefill=True,
kv_indptr_buf=self.kv_indptr[i],
)
)
self.max_context_len = self.attn_backends[0].max_context_len
# Cached variables for generate_draft_decode_kv_indices
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
self.kv_indptr_stride = self.kv_indptr.shape[1]
def common_template(self, forward_batch: ForwardBatch, call_fn: int):
num_seqs = forward_batch.batch_size
bs = self.topk * num_seqs
seq_lens_sum = forward_batch.seq_lens_sum
self.generate_draft_decode_kv_indices[
(self.speculative_num_steps, num_seqs, self.topk)
](
forward_batch.req_pool_indices,
forward_batch.req_to_token_pool.req_to_token,
forward_batch.seq_lens,
self.cuda_graph_kv_indices,
self.kv_indptr,
forward_batch.positions,
num_seqs,
self.topk,
self.pool_len,
self.kv_indptr_stride,
self.kv_indptr.shape[1],
triton.next_power_of_2(num_seqs),
triton.next_power_of_2(self.speculative_num_steps),
triton.next_power_of_2(bs),
)
for i in range(self.speculative_num_steps):
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
forward_batch.spec_info.kv_indices = self.cuda_graph_kv_indices[i][
: seq_lens_sum * self.topk + bs * (i + 1)
]
call_fn(i, forward_batch)
def init_forward_metadata(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
forward_batch.spec_info.kv_indptr = (
forward_batch.spec_info.kv_indptr.clone()
)
forward_batch.spec_info.kv_indices = (
forward_batch.spec_info.kv_indices.clone()
)
self.attn_backends[i].init_forward_metadata(forward_batch)
self.common_template(forward_batch, call_fn)
def init_cuda_graph_state(self, max_bs: int):
self.cuda_graph_kv_indices = torch.zeros(
(self.speculative_num_steps, max_bs * self.max_context_len),
dtype=torch.int32,
device="cuda",
)
self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1]
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
forward_batch.batch_size,
forward_batch.batch_size * self.topk,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
)
decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[
forward_batch.batch_size
][0]
decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)
self.common_template(forward_batch, call_fn)
def init_forward_metadata_replay_cuda_graph(self, forward_batch):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
forward_batch.batch_size,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
seq_lens_sum=-1,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
)
self.common_template(forward_batch, call_fn)
@triton.jit
def create_flashinfer_kv_indices_triton(
req_to_token_ptr, # [max_batch, max_context_len]
@@ -937,3 +1094,105 @@ def should_use_tensor_core(
return gqa_group_size > 4
else:
return False
def fast_decode_plan(
self,
indptr: torch.Tensor,
indices: torch.Tensor,
last_page_len: torch.Tensor,
num_qo_heads: int,
num_kv_heads: int,
head_dim: int,
page_size: int,
pos_encoding_mode: str = "NONE",
window_left: int = -1,
logits_soft_cap: Optional[float] = None,
data_type: Union[str, torch.dtype] = "float16",
q_data_type: Optional[Union[str, torch.dtype]] = None,
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
) -> None:
"""A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
batch_size = len(last_page_len)
if logits_soft_cap is None:
logits_soft_cap = 0.0
if self.is_cuda_graph_enabled:
if batch_size != self._fixed_batch_size:
raise ValueError(
"The batch size should be fixed in cudagraph mode, the runtime batch size {} "
" mismatches the batch size set during initialization {}".format(
batch_size, self._fixed_batch_size
)
)
if len(indices) > len(self._paged_kv_indices_buf):
raise ValueError(
"The size of indices should be less than or equal to the allocated buffer"
)
else:
self._paged_kv_indptr_buf = indptr
self._paged_kv_indices_buf = indices
self._paged_kv_last_page_len_buf = last_page_len
# NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
if not q_data_type:
q_data_type = data_type
if not hasattr(self, "empty_q_data"):
self.empty_q_data = torch.empty(
0,
dtype=(
getattr(torch, q_data_type)
if isinstance(q_data_type, str)
else q_data_type
),
)
self.empty_kv_cache = torch.empty(
0,
dtype=(
getattr(torch, data_type) if isinstance(data_type, str) else data_type
),
)
self.last_page_len = torch.ones(32768, dtype=torch.int32)
empty_q_data = self.empty_q_data
empty_kv_cache = self.empty_kv_cache
if self.use_tensor_cores:
if not self.is_cuda_graph_enabled:
# when not using cudagraph, we need to create the indptr buffer, otherwise
# the buffer is already created during initialization
self._qo_indptr_buf = torch.arange(
batch_size + 1, dtype=torch.int32, device=indptr.device
)
self._wrapper.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._qo_indptr_buf,
indptr,
batch_size,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
empty_q_data,
)
else:
self._wrapper.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
indptr,
self.last_page_len,
batch_size,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
PosEncodingMode[pos_encoding_mode].value,
logits_soft_cap,
empty_q_data,
empty_kv_cache,
)
self._pos_encoding_mode = pos_encoding_mode
self._window_left = window_left
self._logits_soft_cap = logits_soft_cap
self._sm_scale = sm_scale
self._rope_scale = rope_scale
self._rope_theta = rope_theta