refactor EAGLE 2 (#3269)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: merrymercy <lianminzheng@gmail.com>
Co-authored-by: Ying1123 <sqy1415@gmail.com>
@@ -10,6 +10,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
import os
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Union

import torch
@@ -34,6 +35,7 @@ if is_flashinfer_available():
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
from flashinfer.decode import PosEncodingMode


class WrapperDispatch(Enum):
@@ -53,10 +55,19 @@ class PrefillMetadata:
extend_no_prefix: bool


# Reuse this workspace buffer across all flashinfer wrappers
global_workspace_buffer = None


class FlashInferAttnBackend(AttentionBackend):
"""Flashinfer attention kernels."""

def __init__(self, model_runner: ModelRunner):
def __init__(
self,
model_runner: ModelRunner,
skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None,
):
super().__init__()

# Parse constants
@@ -69,6 +80,7 @@ class FlashInferAttnBackend(AttentionBackend):
),
)
self.max_context_len = model_runner.model_config.context_len
self.skip_prefill = skip_prefill

assert not (
model_runner.sliding_window_size is not None
@@ -90,16 +102,26 @@ class FlashInferAttnBackend(AttentionBackend):
global_config.flashinfer_workspace_size = 512 * 1024 * 1024

# Allocate buffers
self.workspace_buffer = torch.empty(
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
device=model_runner.device,
)
global global_workspace_buffer
if global_workspace_buffer is None:
global_workspace_buffer = torch.empty(
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
device=model_runner.device,
)
self.workspace_buffer = global_workspace_buffer
max_bs = model_runner.req_to_token_pool.size
self.kv_indptr = [
torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
for _ in range(self.num_wrappers)
]
if kv_indptr_buf is None:
self.kv_indptr = [
torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
for _ in range(self.num_wrappers)
]
else:
assert self.num_wrappers == 1
self.kv_indptr = [kv_indptr_buf]

self.kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device=model_runner.device
)
@@ -122,12 +144,16 @@ class FlashInferAttnBackend(AttentionBackend):
self.prefill_wrappers_verify = []
self.decode_wrappers = []
for _ in range(self.num_wrappers):
self.prefill_wrappers_paged.append(
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
)
self.prefill_wrappers_verify.append(
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
)
if not skip_prefill:
self.prefill_wrappers_paged.append(
BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
)
)
self.prefill_wrappers_verify.append(
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
)
self.decode_wrappers.append(
BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer,
@@ -137,10 +163,11 @@ class FlashInferAttnBackend(AttentionBackend):
)

# Create indices updater
if not skip_prefill:
self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
model_runner, self
)
self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
model_runner, self
)

# Other metadata
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
@@ -211,23 +238,30 @@ class FlashInferAttnBackend(AttentionBackend):
self.prefill_wrappers_paged, use_ragged, extend_no_prefix
)

def init_cuda_graph_state(self, max_bs: int):
cuda_graph_kv_indices = torch.zeros(
(max_bs * self.max_context_len,),
dtype=torch.int32,
device="cuda",
)
def init_cuda_graph_state(
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
):
if kv_indices_buf is None:
cuda_graph_kv_indices = torch.zeros(
(max_bs * self.max_context_len,),
dtype=torch.int32,
device="cuda",
)
else:
cuda_graph_kv_indices = kv_indices_buf

self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [
cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
]

self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.uint8,
device="cuda",
)
self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
if not self.skip_prefill:
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.uint8,
device="cuda",
)
self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]

def init_forward_metadata_capture_cuda_graph(
self,
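The constructor and init_cuda_graph_state hunks above thread optional, caller-owned buffers (kv_indptr_buf, kv_indices_buf) through the backend, so several backend instances can share one preallocated index tensor instead of each allocating its own. Below is a minimal, self-contained sketch of that injection pattern; the class and variable names are illustrative, not the real sglang API, and it runs on CPU for simplicity.

import torch
from typing import Optional

class IndexState:
    """Toy stand-in for a per-step attention backend's index buffers."""

    def __init__(self, max_bs: int, kv_indptr_buf: Optional[torch.Tensor] = None):
        if kv_indptr_buf is None:
            # Default path: allocate a private buffer.
            self.kv_indptr = torch.zeros(max_bs + 1, dtype=torch.int32)
        else:
            # Injected path: reuse a caller-owned row; no extra allocation,
            # and the caller can fill all rows for all steps in one kernel.
            self.kv_indptr = kv_indptr_buf

num_steps, max_bs = 4, 8
shared = torch.zeros(num_steps, max_bs + 1, dtype=torch.int32)  # one row per step
states = [IndexState(max_bs, kv_indptr_buf=shared[i]) for i in range(num_steps)]

shared[2, :] = 7                            # writing the parent tensor...
assert states[2].kv_indptr[0].item() == 7   # ...is visible through the shared view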
@@ -602,11 +636,8 @@ class FlashInferIndicesUpdaterDecode:
self.req_to_token.shape[1],
)
else:
bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode(
req_pool_indices,
paged_kernel_lens,
self.req_to_token,
)
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
bs = kv_indptr.shape[0] - 1

wrapper.end_forward()
wrapper.begin_forward(
@@ -854,6 +885,132 @@ class FlashInferIndicesUpdaterPrefill:
)


class FlashInferMultiStepDraftBackend:
"""
Wrap multiple flashinfer attention backends as one for multiple consecutive
draft decoding steps.
"""

def __init__(
self,
model_runner: ModelRunner,
topk: int,
speculative_num_steps: int,
):
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices

self.topk = topk
self.speculative_num_steps = speculative_num_steps
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
max_bs = model_runner.req_to_token_pool.size
self.kv_indptr = torch.zeros(
(
self.speculative_num_steps,
max_bs + 1,
),
dtype=torch.int32,
device=model_runner.device,
)
self.attn_backends = []
for i in range(self.speculative_num_steps):
self.attn_backends.append(
FlashInferAttnBackend(
model_runner,
skip_prefill=True,
kv_indptr_buf=self.kv_indptr[i],
)
)
self.max_context_len = self.attn_backends[0].max_context_len
# Cached variables for generate_draft_decode_kv_indices
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
self.kv_indptr_stride = self.kv_indptr.shape[1]

def common_template(self, forward_batch: ForwardBatch, call_fn: int):
num_seqs = forward_batch.batch_size
bs = self.topk * num_seqs
seq_lens_sum = forward_batch.seq_lens_sum
self.generate_draft_decode_kv_indices[
(self.speculative_num_steps, num_seqs, self.topk)
](
forward_batch.req_pool_indices,
forward_batch.req_to_token_pool.req_to_token,
forward_batch.seq_lens,
self.cuda_graph_kv_indices,
self.kv_indptr,
forward_batch.positions,
num_seqs,
self.topk,
self.pool_len,
self.kv_indptr_stride,
self.kv_indptr.shape[1],
triton.next_power_of_2(num_seqs),
triton.next_power_of_2(self.speculative_num_steps),
triton.next_power_of_2(bs),
)
for i in range(self.speculative_num_steps):
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
forward_batch.spec_info.kv_indices = self.cuda_graph_kv_indices[i][
: seq_lens_sum * self.topk + bs * (i + 1)
]
call_fn(i, forward_batch)

def init_forward_metadata(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
forward_batch.spec_info.kv_indptr = (
forward_batch.spec_info.kv_indptr.clone()
)
forward_batch.spec_info.kv_indices = (
forward_batch.spec_info.kv_indices.clone()
)
self.attn_backends[i].init_forward_metadata(forward_batch)

self.common_template(forward_batch, call_fn)

def init_cuda_graph_state(self, max_bs: int):
self.cuda_graph_kv_indices = torch.zeros(
(self.speculative_num_steps, max_bs * self.max_context_len),
dtype=torch.int32,
device="cuda",
)
self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1]
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
)

def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
forward_batch.batch_size,
forward_batch.batch_size * self.topk,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
)
decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[
forward_batch.batch_size
][0]
decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)

self.common_template(forward_batch, call_fn)

def init_forward_metadata_replay_cuda_graph(self, forward_batch):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
forward_batch.batch_size,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
seq_lens_sum=-1,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
)

self.common_template(forward_batch, call_fn)


@triton.jit
def create_flashinfer_kv_indices_triton(
req_to_token_ptr, # [max_batch, max_context_len]
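The new FlashInferMultiStepDraftBackend above drives one attention backend per draft step through a single common_template helper: the per-step index buffers are filled once up front (with a Triton kernel in the real code), then a step-specific callback runs against backend i with views into row i. The sketch below shows that dispatch shape in isolation; the names are illustrative stand-ins, not the real sglang classes, and the Triton kernel is replaced by a plain loop.

import torch
from typing import Callable, List, Optional

class StepBackend:
    """Toy per-step backend that only records which indices it was given."""

    def __init__(self) -> None:
        self.last_indptr: Optional[torch.Tensor] = None

    def init_forward_metadata(self, kv_indptr: torch.Tensor) -> None:
        self.last_indptr = kv_indptr

class MultiStepDraft:
    def __init__(self, num_steps: int, max_bs: int) -> None:
        self.num_steps = num_steps
        # One indptr row per draft step, allocated once and reused every batch.
        self.kv_indptr = torch.zeros(num_steps, max_bs + 1, dtype=torch.int32)
        self.backends: List[StepBackend] = [StepBackend() for _ in range(num_steps)]

    def common_template(
        self, bs: int, call_fn: Callable[[int, torch.Tensor], None]
    ) -> None:
        # In the real backend a single Triton launch fills all rows here.
        for i in range(self.num_steps):
            call_fn(i, self.kv_indptr[i, : bs + 1])

draft = MultiStepDraft(num_steps=3, max_bs=8)
draft.common_template(
    bs=2, call_fn=lambda i, indptr: draft.backends[i].init_forward_metadata(indptr)
)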
@@ -937,3 +1094,105 @@ def should_use_tensor_core(
return gqa_group_size > 4
else:
return False


def fast_decode_plan(
self,
indptr: torch.Tensor,
indices: torch.Tensor,
last_page_len: torch.Tensor,
num_qo_heads: int,
num_kv_heads: int,
head_dim: int,
page_size: int,
pos_encoding_mode: str = "NONE",
window_left: int = -1,
logits_soft_cap: Optional[float] = None,
data_type: Union[str, torch.dtype] = "float16",
q_data_type: Optional[Union[str, torch.dtype]] = None,
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
) -> None:
"""A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
batch_size = len(last_page_len)
if logits_soft_cap is None:
logits_soft_cap = 0.0
if self.is_cuda_graph_enabled:
if batch_size != self._fixed_batch_size:
raise ValueError(
"The batch size should be fixed in cudagraph mode, the runtime batch size {} "
" mismatches the batch size set during initialization {}".format(
batch_size, self._fixed_batch_size
)
)
if len(indices) > len(self._paged_kv_indices_buf):
raise ValueError(
"The size of indices should be less than or equal to the allocated buffer"
)
else:
self._paged_kv_indptr_buf = indptr
self._paged_kv_indices_buf = indices
self._paged_kv_last_page_len_buf = last_page_len
# NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
if not q_data_type:
q_data_type = data_type
if not hasattr(self, "empty_q_data"):
self.empty_q_data = torch.empty(
0,
dtype=(
getattr(torch, q_data_type)
if isinstance(q_data_type, str)
else q_data_type
),
)
self.empty_kv_cache = torch.empty(
0,
dtype=(
getattr(torch, data_type) if isinstance(data_type, str) else data_type
),
)
self.last_page_len = torch.ones(32768, dtype=torch.int32)
empty_q_data = self.empty_q_data
empty_kv_cache = self.empty_kv_cache
if self.use_tensor_cores:
if not self.is_cuda_graph_enabled:
# when not using cudagraph, we need to create the indptr buffer, otherwise
# the buffer is already created during initialization
self._qo_indptr_buf = torch.arange(
batch_size + 1, dtype=torch.int32, device=indptr.device
)
self._wrapper.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._qo_indptr_buf,
indptr,
batch_size,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
empty_q_data,
)
else:
self._wrapper.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
indptr,
self.last_page_len,
batch_size,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
PosEncodingMode[pos_encoding_mode].value,
logits_soft_cap,
empty_q_data,
empty_kv_cache,
)
self._pos_encoding_mode = pos_encoding_mode
self._window_left = window_left
self._logits_soft_cap = logits_soft_cap
self._sm_scale = sm_scale
self._rope_scale = rope_scale
self._rope_theta = rope_theta
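fast_decode_plan is not called directly; as the CUDA-graph capture hook earlier in the diff shows, the draft backend swaps it in for a wrapper's begin_forward on a per-instance basis via functools.partial. Below is a small, self-contained sketch of that per-instance method override; DecodeWrapper and fast_plan are illustrative names, not the flashinfer API.

from functools import partial

class DecodeWrapper:
    def begin_forward(self, batch_size: int) -> str:
        return f"full plan for batch {batch_size}"

def fast_plan(self: "DecodeWrapper", batch_size: int) -> str:
    # Same call signature as the method it replaces; the wrapper instance is
    # bound explicitly through partial, so only this one object is affected.
    return f"fast plan for batch {batch_size}"

w = DecodeWrapper()
w.begin_forward = partial(fast_plan, w)    # instance-level override, not class-wide
print(w.begin_forward(16))                 # -> fast plan for batch 16
print(DecodeWrapper().begin_forward(16))   # other instances keep the full plan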