refactor EAGLE 2 (#3269)
Co-authored-by: Ying Sheng <sqy1415@gmail.com> Co-authored-by: merrymercy <lianminzheng@gmail.com> Co-authored-by: Ying1123 <sqy1415@gmail.com>
This commit is contained in:
@@ -21,6 +21,7 @@ def main():
|
|||||||
speculative_num_steps=3,
|
speculative_num_steps=3,
|
||||||
speculative_eagle_topk=4,
|
speculative_eagle_topk=4,
|
||||||
speculative_num_draft_tokens=16,
|
speculative_num_draft_tokens=16,
|
||||||
|
cuda_graph_max_bs=8,
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = llm.generate(prompts, sampling_params)
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
|
|||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, List, Optional, Union
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -34,6 +35,7 @@ if is_flashinfer_available():
|
|||||||
BatchPrefillWithRaggedKVCacheWrapper,
|
BatchPrefillWithRaggedKVCacheWrapper,
|
||||||
)
|
)
|
||||||
from flashinfer.cascade import merge_state
|
from flashinfer.cascade import merge_state
|
||||||
|
from flashinfer.decode import PosEncodingMode
|
||||||
|
|
||||||
|
|
||||||
class WrapperDispatch(Enum):
|
class WrapperDispatch(Enum):
|
||||||
@@ -53,10 +55,19 @@ class PrefillMetadata:
|
|||||||
extend_no_prefix: bool
|
extend_no_prefix: bool
|
||||||
|
|
||||||
|
|
||||||
|
# Reuse this workspace buffer across all flashinfer wrappers
|
||||||
|
global_workspace_buffer = None
|
||||||
|
|
||||||
|
|
||||||
class FlashInferAttnBackend(AttentionBackend):
|
class FlashInferAttnBackend(AttentionBackend):
|
||||||
"""Flashinfer attention kernels."""
|
"""Flashinfer attention kernels."""
|
||||||
|
|
||||||
def __init__(self, model_runner: ModelRunner):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_runner: ModelRunner,
|
||||||
|
skip_prefill: bool = False,
|
||||||
|
kv_indptr_buf: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Parse constants
|
# Parse constants
|
||||||
@@ -69,6 +80,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.max_context_len = model_runner.model_config.context_len
|
self.max_context_len = model_runner.model_config.context_len
|
||||||
|
self.skip_prefill = skip_prefill
|
||||||
|
|
||||||
assert not (
|
assert not (
|
||||||
model_runner.sliding_window_size is not None
|
model_runner.sliding_window_size is not None
|
||||||
@@ -90,16 +102,26 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
global_config.flashinfer_workspace_size = 512 * 1024 * 1024
|
global_config.flashinfer_workspace_size = 512 * 1024 * 1024
|
||||||
|
|
||||||
# Allocate buffers
|
# Allocate buffers
|
||||||
self.workspace_buffer = torch.empty(
|
global global_workspace_buffer
|
||||||
global_config.flashinfer_workspace_size,
|
if global_workspace_buffer is None:
|
||||||
dtype=torch.uint8,
|
global_workspace_buffer = torch.empty(
|
||||||
device=model_runner.device,
|
global_config.flashinfer_workspace_size,
|
||||||
)
|
dtype=torch.uint8,
|
||||||
|
device=model_runner.device,
|
||||||
|
)
|
||||||
|
self.workspace_buffer = global_workspace_buffer
|
||||||
max_bs = model_runner.req_to_token_pool.size
|
max_bs = model_runner.req_to_token_pool.size
|
||||||
self.kv_indptr = [
|
if kv_indptr_buf is None:
|
||||||
torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
|
self.kv_indptr = [
|
||||||
for _ in range(self.num_wrappers)
|
torch.zeros(
|
||||||
]
|
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||||
|
)
|
||||||
|
for _ in range(self.num_wrappers)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
assert self.num_wrappers == 1
|
||||||
|
self.kv_indptr = [kv_indptr_buf]
|
||||||
|
|
||||||
self.kv_last_page_len = torch.ones(
|
self.kv_last_page_len = torch.ones(
|
||||||
(max_bs,), dtype=torch.int32, device=model_runner.device
|
(max_bs,), dtype=torch.int32, device=model_runner.device
|
||||||
)
|
)
|
||||||
@@ -122,12 +144,16 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self.prefill_wrappers_verify = []
|
self.prefill_wrappers_verify = []
|
||||||
self.decode_wrappers = []
|
self.decode_wrappers = []
|
||||||
for _ in range(self.num_wrappers):
|
for _ in range(self.num_wrappers):
|
||||||
self.prefill_wrappers_paged.append(
|
if not skip_prefill:
|
||||||
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
|
self.prefill_wrappers_paged.append(
|
||||||
)
|
BatchPrefillWithPagedKVCacheWrapper(
|
||||||
self.prefill_wrappers_verify.append(
|
self.workspace_buffer,
|
||||||
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
|
"NHD",
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
self.prefill_wrappers_verify.append(
|
||||||
|
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
|
||||||
|
)
|
||||||
self.decode_wrappers.append(
|
self.decode_wrappers.append(
|
||||||
BatchDecodeWithPagedKVCacheWrapper(
|
BatchDecodeWithPagedKVCacheWrapper(
|
||||||
self.workspace_buffer,
|
self.workspace_buffer,
|
||||||
@@ -137,10 +163,11 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create indices updater
|
# Create indices updater
|
||||||
|
if not skip_prefill:
|
||||||
|
self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
|
||||||
|
model_runner, self
|
||||||
|
)
|
||||||
self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
|
self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
|
||||||
self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
|
|
||||||
model_runner, self
|
|
||||||
)
|
|
||||||
|
|
||||||
# Other metadata
|
# Other metadata
|
||||||
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
|
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
|
||||||
@@ -211,23 +238,30 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self.prefill_wrappers_paged, use_ragged, extend_no_prefix
|
self.prefill_wrappers_paged, use_ragged, extend_no_prefix
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_cuda_graph_state(self, max_bs: int):
|
def init_cuda_graph_state(
|
||||||
cuda_graph_kv_indices = torch.zeros(
|
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
|
||||||
(max_bs * self.max_context_len,),
|
):
|
||||||
dtype=torch.int32,
|
if kv_indices_buf is None:
|
||||||
device="cuda",
|
cuda_graph_kv_indices = torch.zeros(
|
||||||
)
|
(max_bs * self.max_context_len,),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cuda_graph_kv_indices = kv_indices_buf
|
||||||
|
|
||||||
self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [
|
self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [
|
||||||
cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
|
cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
|
||||||
]
|
]
|
||||||
|
|
||||||
self.cuda_graph_custom_mask = torch.zeros(
|
if not self.skip_prefill:
|
||||||
(max_bs * self.max_context_len),
|
self.cuda_graph_custom_mask = torch.zeros(
|
||||||
dtype=torch.uint8,
|
(max_bs * self.max_context_len),
|
||||||
device="cuda",
|
dtype=torch.uint8,
|
||||||
)
|
device="cuda",
|
||||||
self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
|
)
|
||||||
self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
|
self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
|
||||||
|
self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
self,
|
self,
|
||||||
@@ -602,11 +636,8 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
self.req_to_token.shape[1],
|
self.req_to_token.shape[1],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode(
|
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||||
req_pool_indices,
|
bs = kv_indptr.shape[0] - 1
|
||||||
paged_kernel_lens,
|
|
||||||
self.req_to_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
wrapper.end_forward()
|
wrapper.end_forward()
|
||||||
wrapper.begin_forward(
|
wrapper.begin_forward(
|
||||||
@@ -854,6 +885,132 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FlashInferMultiStepDraftBackend:
|
||||||
|
"""
|
||||||
|
Wrap multiple flashinfer attention backends as one for multiple consecutive
|
||||||
|
draft decoding steps.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_runner: ModelRunner,
|
||||||
|
topk: int,
|
||||||
|
speculative_num_steps: int,
|
||||||
|
):
|
||||||
|
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
|
||||||
|
|
||||||
|
self.topk = topk
|
||||||
|
self.speculative_num_steps = speculative_num_steps
|
||||||
|
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
|
||||||
|
max_bs = model_runner.req_to_token_pool.size
|
||||||
|
self.kv_indptr = torch.zeros(
|
||||||
|
(
|
||||||
|
self.speculative_num_steps,
|
||||||
|
max_bs + 1,
|
||||||
|
),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=model_runner.device,
|
||||||
|
)
|
||||||
|
self.attn_backends = []
|
||||||
|
for i in range(self.speculative_num_steps):
|
||||||
|
self.attn_backends.append(
|
||||||
|
FlashInferAttnBackend(
|
||||||
|
model_runner,
|
||||||
|
skip_prefill=True,
|
||||||
|
kv_indptr_buf=self.kv_indptr[i],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.max_context_len = self.attn_backends[0].max_context_len
|
||||||
|
# Cached variables for generate_draft_decode_kv_indices
|
||||||
|
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
|
||||||
|
self.kv_indptr_stride = self.kv_indptr.shape[1]
|
||||||
|
|
||||||
|
def common_template(self, forward_batch: ForwardBatch, call_fn: int):
|
||||||
|
num_seqs = forward_batch.batch_size
|
||||||
|
bs = self.topk * num_seqs
|
||||||
|
seq_lens_sum = forward_batch.seq_lens_sum
|
||||||
|
self.generate_draft_decode_kv_indices[
|
||||||
|
(self.speculative_num_steps, num_seqs, self.topk)
|
||||||
|
](
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
forward_batch.req_to_token_pool.req_to_token,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
self.cuda_graph_kv_indices,
|
||||||
|
self.kv_indptr,
|
||||||
|
forward_batch.positions,
|
||||||
|
num_seqs,
|
||||||
|
self.topk,
|
||||||
|
self.pool_len,
|
||||||
|
self.kv_indptr_stride,
|
||||||
|
self.kv_indptr.shape[1],
|
||||||
|
triton.next_power_of_2(num_seqs),
|
||||||
|
triton.next_power_of_2(self.speculative_num_steps),
|
||||||
|
triton.next_power_of_2(bs),
|
||||||
|
)
|
||||||
|
for i in range(self.speculative_num_steps):
|
||||||
|
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
|
||||||
|
forward_batch.spec_info.kv_indices = self.cuda_graph_kv_indices[i][
|
||||||
|
: seq_lens_sum * self.topk + bs * (i + 1)
|
||||||
|
]
|
||||||
|
call_fn(i, forward_batch)
|
||||||
|
|
||||||
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
|
def call_fn(i, forward_batch):
|
||||||
|
forward_batch.spec_info.kv_indptr = (
|
||||||
|
forward_batch.spec_info.kv_indptr.clone()
|
||||||
|
)
|
||||||
|
forward_batch.spec_info.kv_indices = (
|
||||||
|
forward_batch.spec_info.kv_indices.clone()
|
||||||
|
)
|
||||||
|
self.attn_backends[i].init_forward_metadata(forward_batch)
|
||||||
|
|
||||||
|
self.common_template(forward_batch, call_fn)
|
||||||
|
|
||||||
|
def init_cuda_graph_state(self, max_bs: int):
|
||||||
|
self.cuda_graph_kv_indices = torch.zeros(
|
||||||
|
(self.speculative_num_steps, max_bs * self.max_context_len),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1]
|
||||||
|
for i in range(self.speculative_num_steps):
|
||||||
|
self.attn_backends[i].init_cuda_graph_state(
|
||||||
|
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
|
||||||
|
def call_fn(i, forward_batch):
|
||||||
|
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
|
||||||
|
forward_batch.batch_size,
|
||||||
|
forward_batch.batch_size * self.topk,
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
encoder_lens=None,
|
||||||
|
forward_mode=ForwardMode.DECODE,
|
||||||
|
spec_info=forward_batch.spec_info,
|
||||||
|
)
|
||||||
|
decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[
|
||||||
|
forward_batch.batch_size
|
||||||
|
][0]
|
||||||
|
decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)
|
||||||
|
|
||||||
|
self.common_template(forward_batch, call_fn)
|
||||||
|
|
||||||
|
def init_forward_metadata_replay_cuda_graph(self, forward_batch):
|
||||||
|
def call_fn(i, forward_batch):
|
||||||
|
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
|
||||||
|
forward_batch.batch_size,
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
seq_lens_sum=-1,
|
||||||
|
encoder_lens=None,
|
||||||
|
forward_mode=ForwardMode.DECODE,
|
||||||
|
spec_info=forward_batch.spec_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.common_template(forward_batch, call_fn)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def create_flashinfer_kv_indices_triton(
|
def create_flashinfer_kv_indices_triton(
|
||||||
req_to_token_ptr, # [max_batch, max_context_len]
|
req_to_token_ptr, # [max_batch, max_context_len]
|
||||||
@@ -937,3 +1094,105 @@ def should_use_tensor_core(
|
|||||||
return gqa_group_size > 4
|
return gqa_group_size > 4
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def fast_decode_plan(
|
||||||
|
self,
|
||||||
|
indptr: torch.Tensor,
|
||||||
|
indices: torch.Tensor,
|
||||||
|
last_page_len: torch.Tensor,
|
||||||
|
num_qo_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_dim: int,
|
||||||
|
page_size: int,
|
||||||
|
pos_encoding_mode: str = "NONE",
|
||||||
|
window_left: int = -1,
|
||||||
|
logits_soft_cap: Optional[float] = None,
|
||||||
|
data_type: Union[str, torch.dtype] = "float16",
|
||||||
|
q_data_type: Optional[Union[str, torch.dtype]] = None,
|
||||||
|
sm_scale: Optional[float] = None,
|
||||||
|
rope_scale: Optional[float] = None,
|
||||||
|
rope_theta: Optional[float] = None,
|
||||||
|
) -> None:
|
||||||
|
"""A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
|
||||||
|
batch_size = len(last_page_len)
|
||||||
|
if logits_soft_cap is None:
|
||||||
|
logits_soft_cap = 0.0
|
||||||
|
if self.is_cuda_graph_enabled:
|
||||||
|
if batch_size != self._fixed_batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
"The batch size should be fixed in cudagraph mode, the runtime batch size {} "
|
||||||
|
" mismatches the batch size set during initialization {}".format(
|
||||||
|
batch_size, self._fixed_batch_size
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if len(indices) > len(self._paged_kv_indices_buf):
|
||||||
|
raise ValueError(
|
||||||
|
"The size of indices should be less than or equal to the allocated buffer"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._paged_kv_indptr_buf = indptr
|
||||||
|
self._paged_kv_indices_buf = indices
|
||||||
|
self._paged_kv_last_page_len_buf = last_page_len
|
||||||
|
# NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
|
||||||
|
if not q_data_type:
|
||||||
|
q_data_type = data_type
|
||||||
|
if not hasattr(self, "empty_q_data"):
|
||||||
|
self.empty_q_data = torch.empty(
|
||||||
|
0,
|
||||||
|
dtype=(
|
||||||
|
getattr(torch, q_data_type)
|
||||||
|
if isinstance(q_data_type, str)
|
||||||
|
else q_data_type
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.empty_kv_cache = torch.empty(
|
||||||
|
0,
|
||||||
|
dtype=(
|
||||||
|
getattr(torch, data_type) if isinstance(data_type, str) else data_type
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.last_page_len = torch.ones(32768, dtype=torch.int32)
|
||||||
|
empty_q_data = self.empty_q_data
|
||||||
|
empty_kv_cache = self.empty_kv_cache
|
||||||
|
if self.use_tensor_cores:
|
||||||
|
if not self.is_cuda_graph_enabled:
|
||||||
|
# when not using cudagraph, we need to create the indptr buffer, otherwise
|
||||||
|
# the buffer is already created during initialization
|
||||||
|
self._qo_indptr_buf = torch.arange(
|
||||||
|
batch_size + 1, dtype=torch.int32, device=indptr.device
|
||||||
|
)
|
||||||
|
self._wrapper.plan(
|
||||||
|
self._float_workspace_buffer,
|
||||||
|
self._int_workspace_buffer,
|
||||||
|
self._qo_indptr_buf,
|
||||||
|
indptr,
|
||||||
|
batch_size,
|
||||||
|
num_qo_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_dim,
|
||||||
|
page_size,
|
||||||
|
empty_q_data,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._wrapper.plan(
|
||||||
|
self._float_workspace_buffer,
|
||||||
|
self._int_workspace_buffer,
|
||||||
|
indptr,
|
||||||
|
self.last_page_len,
|
||||||
|
batch_size,
|
||||||
|
num_qo_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_dim,
|
||||||
|
page_size,
|
||||||
|
PosEncodingMode[pos_encoding_mode].value,
|
||||||
|
logits_soft_cap,
|
||||||
|
empty_q_data,
|
||||||
|
empty_kv_cache,
|
||||||
|
)
|
||||||
|
self._pos_encoding_mode = pos_encoding_mode
|
||||||
|
self._window_left = window_left
|
||||||
|
self._logits_soft_cap = logits_soft_cap
|
||||||
|
self._sm_scale = sm_scale
|
||||||
|
self._rope_scale = rope_scale
|
||||||
|
self._rope_theta = rope_theta
|
||||||
|
|||||||
@@ -103,69 +103,75 @@ def set_torch_compile_config():
|
|||||||
torch._dynamo.config.cache_size_limit = 1024
|
torch._dynamo.config.cache_size_limit = 1024
|
||||||
|
|
||||||
|
|
||||||
|
def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
||||||
|
server_args = model_runner.server_args
|
||||||
|
capture_bs = server_args.cuda_graph_bs
|
||||||
|
if capture_bs is None:
|
||||||
|
if server_args.disable_cuda_graph_padding:
|
||||||
|
capture_bs = list(range(1, 33)) + [64, 128]
|
||||||
|
else:
|
||||||
|
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
|
||||||
|
if max(capture_bs) > model_runner.req_to_token_pool.size:
|
||||||
|
# In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
|
||||||
|
# is very samll. We add more values here to make sure we capture the maximum bs.
|
||||||
|
capture_bs = list(
|
||||||
|
sorted(
|
||||||
|
set(
|
||||||
|
capture_bs
|
||||||
|
+ [model_runner.req_to_token_pool.size - 1]
|
||||||
|
+ [model_runner.req_to_token_pool.size]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
capture_bs = [
|
||||||
|
bs
|
||||||
|
for bs in capture_bs
|
||||||
|
if bs <= model_runner.req_to_token_pool.size
|
||||||
|
and bs <= server_args.cuda_graph_max_bs
|
||||||
|
]
|
||||||
|
compile_bs = (
|
||||||
|
[bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
|
||||||
|
if server_args.enable_torch_compile
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
return capture_bs, compile_bs
|
||||||
|
|
||||||
|
|
||||||
|
# Reuse this memory pool across all cuda graph runners.
|
||||||
|
global_graph_memory_pool = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_global_graph_memory_pool():
|
||||||
|
return global_graph_memory_pool
|
||||||
|
|
||||||
|
|
||||||
|
def set_global_graph_memory_pool(val):
|
||||||
|
global global_graph_memory_pool
|
||||||
|
global_graph_memory_pool = val
|
||||||
|
|
||||||
|
|
||||||
class CudaGraphRunner:
|
class CudaGraphRunner:
|
||||||
"""A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
|
"""A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
|
||||||
|
|
||||||
def __init__(self, model_runner: "ModelRunner"):
|
def __init__(self, model_runner: ModelRunner):
|
||||||
# Parse args
|
# Parse args
|
||||||
self.model_runner = model_runner
|
self.model_runner = model_runner
|
||||||
self.graphs = {}
|
self.graphs = {}
|
||||||
self.input_buffers = {}
|
|
||||||
self.output_buffers = {}
|
self.output_buffers = {}
|
||||||
self.flashinfer_handlers = {}
|
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
||||||
self.graph_memory_pool = None
|
|
||||||
self.use_torch_compile = model_runner.server_args.enable_torch_compile
|
|
||||||
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||||
self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
|
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
|
||||||
self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
|
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
|
||||||
self.tp_size = self.model_runner.tp_size
|
self.tp_size = model_runner.server_args.tp_size
|
||||||
self.dp_size = self.model_runner.server_args.dp_size
|
self.dp_size = model_runner.server_args.dp_size
|
||||||
|
|
||||||
# Batch sizes to capture
|
# Batch sizes to capture
|
||||||
self.capture_bs = self.model_runner.server_args.cuda_graph_bs
|
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
|
||||||
if self.capture_bs is None:
|
|
||||||
if model_runner.server_args.disable_cuda_graph_padding:
|
|
||||||
self.capture_bs = list(range(1, 33)) + [64, 128]
|
|
||||||
else:
|
|
||||||
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
|
|
||||||
|
|
||||||
if max(self.capture_bs) > model_runner.req_to_token_pool.size:
|
|
||||||
# In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
|
|
||||||
# is very samll. We add more values here to make sure we capture the maximum bs.
|
|
||||||
self.capture_bs = list(
|
|
||||||
sorted(
|
|
||||||
set(
|
|
||||||
self.capture_bs
|
|
||||||
+ [model_runner.req_to_token_pool.size - 1]
|
|
||||||
+ [model_runner.req_to_token_pool.size]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.capture_bs = [
|
|
||||||
bs
|
|
||||||
for bs in self.capture_bs
|
|
||||||
if bs <= model_runner.req_to_token_pool.size
|
|
||||||
and bs <= model_runner.server_args.cuda_graph_max_bs
|
|
||||||
]
|
|
||||||
|
|
||||||
self.compile_bs = (
|
|
||||||
[
|
|
||||||
bs
|
|
||||||
for bs in self.capture_bs
|
|
||||||
if bs <= self.model_runner.server_args.torch_compile_max_bs
|
|
||||||
]
|
|
||||||
if self.use_torch_compile
|
|
||||||
else []
|
|
||||||
)
|
|
||||||
|
|
||||||
self.capture_forward_mode = ForwardMode.DECODE
|
self.capture_forward_mode = ForwardMode.DECODE
|
||||||
self.num_tokens_per_bs = 1
|
self.num_tokens_per_bs = 1
|
||||||
if model_runner.spec_algorithm.is_eagle():
|
if model_runner.spec_algorithm.is_eagle():
|
||||||
if self.model_runner.is_draft_worker:
|
if self.model_runner.is_draft_worker:
|
||||||
self.num_tokens_per_bs = (
|
raise RuntimeError("This should not happen")
|
||||||
self.model_runner.server_args.speculative_eagle_topk
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.capture_forward_mode = ForwardMode.TARGET_VERIFY
|
self.capture_forward_mode = ForwardMode.TARGET_VERIFY
|
||||||
self.num_tokens_per_bs = (
|
self.num_tokens_per_bs = (
|
||||||
@@ -182,10 +188,10 @@ class CudaGraphRunner:
|
|||||||
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
|
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
|
||||||
self.encoder_len_fill_value = 0
|
self.encoder_len_fill_value = 0
|
||||||
|
|
||||||
if self.use_torch_compile:
|
if self.enable_torch_compile:
|
||||||
set_torch_compile_config()
|
set_torch_compile_config()
|
||||||
|
|
||||||
# Common inputs
|
# Graph inputs
|
||||||
with torch.device("cuda"):
|
with torch.device("cuda"):
|
||||||
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
|
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
|
||||||
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
|
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
|
||||||
@@ -301,7 +307,7 @@ class CudaGraphRunner:
|
|||||||
stream = self.stream
|
stream = self.stream
|
||||||
num_tokens = bs * self.num_tokens_per_bs
|
num_tokens = bs * self.num_tokens_per_bs
|
||||||
|
|
||||||
# Common inputs
|
# Graph inputs
|
||||||
input_ids = self.input_ids[:num_tokens]
|
input_ids = self.input_ids[:num_tokens]
|
||||||
req_pool_indices = self.req_pool_indices[:bs]
|
req_pool_indices = self.req_pool_indices[:bs]
|
||||||
seq_lens = self.seq_lens[:bs]
|
seq_lens = self.seq_lens[:bs]
|
||||||
@@ -320,7 +326,7 @@ class CudaGraphRunner:
|
|||||||
global_num_tokens = None
|
global_num_tokens = None
|
||||||
gathered_buffer = None
|
gathered_buffer = None
|
||||||
|
|
||||||
spec_info = self.get_spec_info(num_tokens, positions)
|
spec_info = self.get_spec_info(num_tokens)
|
||||||
|
|
||||||
forward_batch = ForwardBatch(
|
forward_batch = ForwardBatch(
|
||||||
forward_mode=self.capture_forward_mode,
|
forward_mode=self.capture_forward_mode,
|
||||||
@@ -335,7 +341,6 @@ class CudaGraphRunner:
|
|||||||
seq_lens_sum=seq_lens.sum(),
|
seq_lens_sum=seq_lens.sum(),
|
||||||
encoder_lens=encoder_lens,
|
encoder_lens=encoder_lens,
|
||||||
return_logprob=False,
|
return_logprob=False,
|
||||||
top_logprobs_nums=[0] * bs,
|
|
||||||
positions=positions,
|
positions=positions,
|
||||||
global_num_tokens=global_num_tokens,
|
global_num_tokens=global_num_tokens,
|
||||||
gathered_buffer=gathered_buffer,
|
gathered_buffer=gathered_buffer,
|
||||||
@@ -375,13 +380,14 @@ class CudaGraphRunner:
|
|||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
self.model_runner.tp_group.barrier()
|
self.model_runner.tp_group.barrier()
|
||||||
|
|
||||||
with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
|
global global_graph_memory_pool
|
||||||
|
with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
|
||||||
out = run_once()
|
out = run_once()
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
self.model_runner.tp_group.barrier()
|
self.model_runner.tp_group.barrier()
|
||||||
|
|
||||||
self.graph_memory_pool = graph.pool()
|
global_graph_memory_pool = graph.pool()
|
||||||
return graph, out
|
return graph, out
|
||||||
|
|
||||||
def replay(self, forward_batch: ForwardBatch):
|
def replay(self, forward_batch: ForwardBatch):
|
||||||
@@ -439,35 +445,26 @@ class CudaGraphRunner:
|
|||||||
)
|
)
|
||||||
return logits_output
|
return logits_output
|
||||||
|
|
||||||
def get_spec_info(self, num_tokens: int, positions: torch.Tensor):
|
def get_spec_info(self, num_tokens: int):
|
||||||
spec_info = None
|
spec_info = None
|
||||||
if self.model_runner.spec_algorithm.is_eagle():
|
if self.model_runner.spec_algorithm.is_eagle():
|
||||||
from sglang.srt.speculative.eagle_utils import (
|
from sglang.srt.speculative.eagle_utils import EagleVerifyInput
|
||||||
EAGLEDraftInput,
|
|
||||||
EagleVerifyInput,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.model_runner.is_draft_worker:
|
if self.model_runner.is_draft_worker:
|
||||||
spec_info = EAGLEDraftInput()
|
raise RuntimeError("This should not happen.")
|
||||||
spec_info.load_server_args(self.model_runner.server_args)
|
|
||||||
spec_info.hidden_states = self.hidden_states[:num_tokens]
|
|
||||||
spec_info.positions = positions
|
|
||||||
spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
|
|
||||||
else:
|
else:
|
||||||
spec_info = EagleVerifyInput(
|
spec_info = EagleVerifyInput(
|
||||||
None,
|
draft_token=None,
|
||||||
None,
|
custom_mask=torch.zeros(
|
||||||
None,
|
(num_tokens * self.model_runner.model_config.context_len),
|
||||||
None,
|
dtype=torch.bool,
|
||||||
None,
|
device="cuda",
|
||||||
None,
|
),
|
||||||
self.model_runner.server_args.speculative_num_draft_tokens,
|
positions=None,
|
||||||
|
retrive_index=None,
|
||||||
|
retrive_cum_len=None,
|
||||||
|
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
|
||||||
|
capture_hidden_mode=CaptureHiddenMode.FULL,
|
||||||
)
|
)
|
||||||
spec_info.custom_mask = torch.zeros(
|
|
||||||
(num_tokens * self.model_runner.model_config.context_len),
|
|
||||||
dtype=torch.bool,
|
|
||||||
device="cuda",
|
|
||||||
)
|
|
||||||
spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
|
|
||||||
|
|
||||||
return spec_info
|
return spec_info
|
||||||
|
|||||||
@@ -197,64 +197,6 @@ class ForwardBatch:
|
|||||||
# For Qwen2-VL
|
# For Qwen2-VL
|
||||||
mrope_positions: torch.Tensor = None
|
mrope_positions: torch.Tensor = None
|
||||||
|
|
||||||
def compute_mrope_positions(
|
|
||||||
self, model_runner: ModelRunner, batch: ModelWorkerBatch
|
|
||||||
):
|
|
||||||
device = model_runner.device
|
|
||||||
hf_config = model_runner.model_config.hf_config
|
|
||||||
mrope_positions_list = [None] * self.seq_lens.shape[0]
|
|
||||||
if self.forward_mode.is_decode():
|
|
||||||
for i, _ in enumerate(mrope_positions_list):
|
|
||||||
mrope_position_delta = (
|
|
||||||
0
|
|
||||||
if batch.image_inputs[i] is None
|
|
||||||
else batch.image_inputs[i].mrope_position_delta
|
|
||||||
)
|
|
||||||
mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
|
|
||||||
mrope_position_delta,
|
|
||||||
int(self.seq_lens[i]) - 1,
|
|
||||||
int(self.seq_lens[i]),
|
|
||||||
)
|
|
||||||
elif self.forward_mode.is_extend():
|
|
||||||
extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
|
|
||||||
for i, image_inputs in enumerate(batch.image_inputs):
|
|
||||||
extend_start_loc, extend_seq_len, extend_prefix_len = (
|
|
||||||
extend_start_loc_cpu[i],
|
|
||||||
batch.extend_seq_lens[i],
|
|
||||||
batch.extend_prefix_lens[i],
|
|
||||||
)
|
|
||||||
if image_inputs is None:
|
|
||||||
# text only
|
|
||||||
mrope_positions = [
|
|
||||||
[
|
|
||||||
pos
|
|
||||||
for pos in range(
|
|
||||||
extend_prefix_len, extend_prefix_len + extend_seq_len
|
|
||||||
)
|
|
||||||
]
|
|
||||||
] * 3
|
|
||||||
else:
|
|
||||||
# TODO: current qwen2-vl do not support radix cache since mrope position calculation
|
|
||||||
mrope_positions, mrope_position_delta = (
|
|
||||||
MRotaryEmbedding.get_input_positions(
|
|
||||||
input_tokens=self.input_ids[
|
|
||||||
extend_start_loc : extend_start_loc + extend_seq_len
|
|
||||||
],
|
|
||||||
image_grid_thw=image_inputs.image_grid_thws,
|
|
||||||
vision_start_token_id=hf_config.vision_start_token_id,
|
|
||||||
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
|
|
||||||
context_len=0,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
batch.image_inputs[i].mrope_position_delta = mrope_position_delta
|
|
||||||
mrope_positions_list[i] = mrope_positions
|
|
||||||
|
|
||||||
self.mrope_positions = torch.concat(
|
|
||||||
[torch.tensor(pos, device=device) for pos in mrope_positions_list],
|
|
||||||
axis=1,
|
|
||||||
)
|
|
||||||
self.mrope_positions = self.mrope_positions.to(torch.int64)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_new(
|
def init_new(
|
||||||
cls,
|
cls,
|
||||||
@@ -337,7 +279,7 @@ class ForwardBatch:
|
|||||||
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
|
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
|
||||||
|
|
||||||
if model_runner.model_is_mrope:
|
if model_runner.model_is_mrope:
|
||||||
ret.compute_mrope_positions(model_runner, batch)
|
ret._compute_mrope_positions(model_runner, batch)
|
||||||
|
|
||||||
# Init lora information
|
# Init lora information
|
||||||
if model_runner.server_args.lora_paths is not None:
|
if model_runner.server_args.lora_paths is not None:
|
||||||
@@ -345,6 +287,63 @@ class ForwardBatch:
|
|||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
def _compute_mrope_positions(
|
||||||
|
self, model_runner: ModelRunner, batch: ModelWorkerBatch
|
||||||
|
):
|
||||||
|
device = model_runner.device
|
||||||
|
hf_config = model_runner.model_config.hf_config
|
||||||
|
mrope_positions_list = [None] * self.seq_lens.shape[0]
|
||||||
|
if self.forward_mode.is_decode():
|
||||||
|
for i, _ in enumerate(mrope_positions_list):
|
||||||
|
mrope_position_delta = (
|
||||||
|
0
|
||||||
|
if batch.image_inputs[i] is None
|
||||||
|
else batch.image_inputs[i].mrope_position_delta
|
||||||
|
)
|
||||||
|
mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
|
||||||
|
mrope_position_delta,
|
||||||
|
int(self.seq_lens[i]) - 1,
|
||||||
|
int(self.seq_lens[i]),
|
||||||
|
)
|
||||||
|
elif self.forward_mode.is_extend():
|
||||||
|
extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
|
||||||
|
for i, image_inputs in enumerate(batch.image_inputs):
|
||||||
|
extend_start_loc, extend_seq_len, extend_prefix_len = (
|
||||||
|
extend_start_loc_cpu[i],
|
||||||
|
batch.extend_seq_lens[i],
|
||||||
|
batch.extend_prefix_lens[i],
|
||||||
|
)
|
||||||
|
if image_inputs is None:
|
||||||
|
# text only
|
||||||
|
mrope_positions = [
|
||||||
|
[
|
||||||
|
pos
|
||||||
|
for pos in range(
|
||||||
|
extend_prefix_len, extend_prefix_len + extend_seq_len
|
||||||
|
)
|
||||||
|
]
|
||||||
|
] * 3
|
||||||
|
else:
|
||||||
|
# TODO: current qwen2-vl do not support radix cache since mrope position calculation
|
||||||
|
mrope_positions, mrope_position_delta = (
|
||||||
|
MRotaryEmbedding.get_input_positions(
|
||||||
|
input_tokens=self.input_ids[
|
||||||
|
extend_start_loc : extend_start_loc + extend_seq_len
|
||||||
|
],
|
||||||
|
image_grid_thw=image_inputs.image_grid_thws,
|
||||||
|
vision_start_token_id=hf_config.vision_start_token_id,
|
||||||
|
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
|
||||||
|
context_len=0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
batch.image_inputs[i].mrope_position_delta = mrope_position_delta
|
||||||
|
mrope_positions_list[i] = mrope_positions
|
||||||
|
self.mrope_positions = torch.concat(
|
||||||
|
[torch.tensor(pos, device=device) for pos in mrope_positions_list],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
self.mrope_positions = self.mrope_positions.to(torch.int64)
|
||||||
|
|
||||||
|
|
||||||
def compute_position_triton(
|
def compute_position_triton(
|
||||||
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
|
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
|||||||
MLATokenToKVPool,
|
MLATokenToKVPool,
|
||||||
ReqToTokenPool,
|
ReqToTokenPool,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader import get_model
|
from sglang.srt.model_loader import get_model
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
@@ -714,8 +715,6 @@ class ModelRunner:
|
|||||||
|
|
||||||
def init_cuda_graphs(self):
|
def init_cuda_graphs(self):
|
||||||
"""Capture cuda graphs."""
|
"""Capture cuda graphs."""
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
|
||||||
|
|
||||||
self.cuda_graph_runner = None
|
self.cuda_graph_runner = None
|
||||||
|
|
||||||
if not self.is_generation:
|
if not self.is_generation:
|
||||||
|
|||||||
@@ -79,11 +79,13 @@ __global__ void build_tree(Tensor<long, 2> parent_list, Tensor<long, 2> selected
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_tree_kernel(parent_list, top_score_index, seq_lens, topk, depth, draft_token):
|
def build_tree_kernel(
|
||||||
|
parent_list, top_score_index, seq_lens, seq_lens_sum, topk, depth, draft_token
|
||||||
|
):
|
||||||
bs = seq_lens.numel()
|
bs = seq_lens.numel()
|
||||||
device = parent_list.device
|
device = parent_list.device
|
||||||
tree_mask = torch.full(
|
tree_mask = torch.full(
|
||||||
(torch.sum(seq_lens).item() * draft_token + draft_token * draft_token * bs,),
|
(seq_lens_sum * draft_token + draft_token * draft_token * bs,),
|
||||||
True,
|
True,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|||||||
213
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
Normal file
213
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import bisect
|
||||||
|
import time
|
||||||
|
from typing import TYPE_CHECKING, Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.model_executor.cuda_graph_runner import (
|
||||||
|
CudaGraphRunner,
|
||||||
|
get_batch_sizes_to_capture,
|
||||||
|
get_global_graph_memory_pool,
|
||||||
|
set_global_graph_memory_pool,
|
||||||
|
set_torch_compile_config,
|
||||||
|
)
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
|
CaptureHiddenMode,
|
||||||
|
ForwardBatch,
|
||||||
|
ForwardMode,
|
||||||
|
)
|
||||||
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
||||||
|
|
||||||
|
|
||||||
|
class EAGLEDraftCudaGraphRunner:
|
||||||
|
def __init__(self, eagle_worker: EAGLEWorker):
|
||||||
|
# Parse args
|
||||||
|
self.eagle_worker = eagle_worker
|
||||||
|
self.model_runner = model_runner = eagle_worker.model_runner
|
||||||
|
self.graphs = {}
|
||||||
|
self.output_buffers = {}
|
||||||
|
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
||||||
|
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||||
|
self.tp_size = self.model_runner.tp_size
|
||||||
|
self.dp_size = model_runner.server_args.dp_size
|
||||||
|
self.topk = model_runner.server_args.speculative_eagle_topk
|
||||||
|
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
|
||||||
|
server_args = model_runner.server_args
|
||||||
|
|
||||||
|
assert self.disable_padding
|
||||||
|
|
||||||
|
# Batch sizes to capture
|
||||||
|
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
|
||||||
|
self.num_tokens_per_bs = server_args.speculative_eagle_topk
|
||||||
|
|
||||||
|
# Attention backend
|
||||||
|
self.max_bs = max(self.capture_bs)
|
||||||
|
self.max_num_token = self.max_bs * self.num_tokens_per_bs
|
||||||
|
self.model_runner.draft_attn_backend.init_cuda_graph_state(self.max_num_token)
|
||||||
|
self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
|
||||||
|
0
|
||||||
|
].get_cuda_graph_seq_len_fill_value()
|
||||||
|
|
||||||
|
if self.enable_torch_compile:
|
||||||
|
set_torch_compile_config()
|
||||||
|
|
||||||
|
# Graph inputs
|
||||||
|
with torch.device("cuda"):
|
||||||
|
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
|
||||||
|
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
|
||||||
|
self.seq_lens = torch.full(
|
||||||
|
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
||||||
|
)
|
||||||
|
self.out_cache_loc = torch.zeros(
|
||||||
|
(self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
|
||||||
|
)
|
||||||
|
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
|
||||||
|
self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
|
||||||
|
self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
|
||||||
|
self.hidden_states = torch.zeros(
|
||||||
|
(self.max_bs, self.model_runner.model_config.hidden_size),
|
||||||
|
dtype=self.model_runner.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Capture
|
||||||
|
try:
|
||||||
|
self.capture()
|
||||||
|
except RuntimeError as e:
|
||||||
|
raise Exception(
|
||||||
|
f"Capture cuda graph failed: {e}\n"
|
||||||
|
"Possible solutions:\n"
|
||||||
|
"1. disable cuda graph by --disable-cuda-graph\n"
|
||||||
|
"2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
|
||||||
|
"3. disable torch compile by not using --enable-torch-compile\n"
|
||||||
|
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
|
||||||
|
)
|
||||||
|
|
||||||
|
def can_run(self, forward_batch: ForwardBatch):
|
||||||
|
is_bs_supported = (
|
||||||
|
forward_batch.batch_size in self.graphs
|
||||||
|
if self.disable_padding
|
||||||
|
else forward_batch.batch_size <= self.max_bs
|
||||||
|
)
|
||||||
|
return is_bs_supported
|
||||||
|
|
||||||
|
def capture(self):
|
||||||
|
CudaGraphRunner.capture(self)
|
||||||
|
|
||||||
|
def capture_one_batch_size(self, num_seqs: int, forward: Callable):
|
||||||
|
graph = torch.cuda.CUDAGraph()
|
||||||
|
stream = self.stream
|
||||||
|
num_tokens = num_seqs * self.num_tokens_per_bs
|
||||||
|
|
||||||
|
# Graph inputs
|
||||||
|
req_pool_indices = self.req_pool_indices[:num_seqs]
|
||||||
|
seq_lens = self.seq_lens[:num_seqs]
|
||||||
|
out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
|
||||||
|
positions = self.positions[:num_tokens]
|
||||||
|
topk_p = self.topk_p[:num_seqs]
|
||||||
|
topk_index = self.topk_index[:num_seqs]
|
||||||
|
hidden_states = self.hidden_states[:num_seqs]
|
||||||
|
|
||||||
|
spec_info = EagleDraftInput(
|
||||||
|
topk_p=topk_p,
|
||||||
|
topk_index=topk_index,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Forward batch
|
||||||
|
forward_batch = ForwardBatch(
|
||||||
|
forward_mode=ForwardMode.DECODE,
|
||||||
|
batch_size=num_seqs,
|
||||||
|
input_ids=None,
|
||||||
|
req_pool_indices=req_pool_indices,
|
||||||
|
seq_lens=seq_lens,
|
||||||
|
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||||
|
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||||
|
out_cache_loc=out_cache_loc,
|
||||||
|
seq_lens_sum=seq_lens.sum(),
|
||||||
|
return_logprob=False,
|
||||||
|
positions=positions,
|
||||||
|
spec_algorithm=self.model_runner.spec_algorithm,
|
||||||
|
spec_info=spec_info,
|
||||||
|
capture_hidden_mode=(
|
||||||
|
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Attention backend
|
||||||
|
self.model_runner.draft_attn_backend.init_forward_metadata_capture_cuda_graph(
|
||||||
|
forward_batch
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run and capture
|
||||||
|
def run_once():
|
||||||
|
# Backup two fileds, which will be modified in-place in `draft_forward`.
|
||||||
|
output_cache_loc_backup = forward_batch.out_cache_loc
|
||||||
|
hidden_states_backup = forward_batch.spec_info.hidden_states
|
||||||
|
|
||||||
|
ret = self.eagle_worker.draft_forward(forward_batch)
|
||||||
|
|
||||||
|
forward_batch.out_cache_loc = output_cache_loc_backup
|
||||||
|
forward_batch.spec_info.hidden_states = hidden_states_backup
|
||||||
|
return ret
|
||||||
|
|
||||||
|
for _ in range(2):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
self.model_runner.tp_group.barrier()
|
||||||
|
|
||||||
|
run_once()
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
self.model_runner.tp_group.barrier()
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
self.model_runner.tp_group.barrier()
|
||||||
|
|
||||||
|
with torch.cuda.graph(
|
||||||
|
graph, pool=get_global_graph_memory_pool(), stream=stream
|
||||||
|
):
|
||||||
|
out = run_once()
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
self.model_runner.tp_group.barrier()
|
||||||
|
|
||||||
|
set_global_graph_memory_pool(graph.pool())
|
||||||
|
return graph, out
|
||||||
|
|
||||||
|
def replay(self, forward_batch: ForwardBatch):
|
||||||
|
assert forward_batch.out_cache_loc is not None
|
||||||
|
raw_bs = forward_batch.batch_size
|
||||||
|
raw_num_token = raw_bs * self.num_tokens_per_bs
|
||||||
|
|
||||||
|
# Pad
|
||||||
|
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||||
|
bs = self.capture_bs[index]
|
||||||
|
if bs != raw_bs:
|
||||||
|
self.seq_lens.fill_(1)
|
||||||
|
self.out_cache_loc.zero_()
|
||||||
|
|
||||||
|
# Common inputs
|
||||||
|
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
||||||
|
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
||||||
|
self.out_cache_loc[: raw_num_token * self.speculative_num_steps].copy_(
|
||||||
|
forward_batch.out_cache_loc
|
||||||
|
)
|
||||||
|
self.positions[:raw_num_token].copy_(forward_batch.positions)
|
||||||
|
self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p)
|
||||||
|
self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
|
||||||
|
self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
|
||||||
|
|
||||||
|
# Attention backend
|
||||||
|
self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||||
|
forward_batch
|
||||||
|
)
|
||||||
|
|
||||||
|
# Replay
|
||||||
|
self.graphs[bs].replay()
|
||||||
|
|
||||||
|
return self.output_buffers[bs]
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
from typing import TYPE_CHECKING, List
|
from typing import TYPE_CHECKING, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -9,13 +10,360 @@ import triton.language as tl
|
|||||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||||
create_flashinfer_kv_indices_triton,
|
create_flashinfer_kv_indices_triton,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
||||||
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
|
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
|
||||||
from sglang.srt.speculative.spec_info import SpecInfo
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
from sglang.srt.server_args import ServerArgs
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class EagleDraftInput:
|
||||||
|
# The inputs for decode
|
||||||
|
# shape: (b, topk)
|
||||||
|
topk_p: torch.Tensor = None
|
||||||
|
topk_index: torch.Tensor = None
|
||||||
|
# shape: (b, hidden_size)
|
||||||
|
hidden_states: torch.Tensor = None
|
||||||
|
capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL
|
||||||
|
|
||||||
|
# Inputs for extend
|
||||||
|
# shape: (b,)
|
||||||
|
verified_id: torch.Tensor = None
|
||||||
|
accept_length: torch.Tensor = None
|
||||||
|
accept_length_cpu: List[int] = None
|
||||||
|
|
||||||
|
# Inputs for the attention backends
|
||||||
|
# shape: (b + 1,)
|
||||||
|
kv_indptr: torch.Tensor = None
|
||||||
|
kv_indices: torch.Tensor = None
|
||||||
|
|
||||||
|
def prepare_for_extend(self, batch: ScheduleBatch):
|
||||||
|
req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
|
||||||
|
out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
|
||||||
|
batch.out_cache_loc = out_cache_loc
|
||||||
|
|
||||||
|
pt = 0
|
||||||
|
for i, req in enumerate(batch.reqs):
|
||||||
|
req.req_pool_idx = req_pool_indices[i]
|
||||||
|
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
|
||||||
|
assert seq_len - pre_len == req.extend_input_len
|
||||||
|
|
||||||
|
if pre_len > 0:
|
||||||
|
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
|
||||||
|
:pre_len
|
||||||
|
] = req.prefix_indices
|
||||||
|
|
||||||
|
batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
|
||||||
|
out_cache_loc[pt : pt + req.extend_input_len]
|
||||||
|
)
|
||||||
|
|
||||||
|
pt += req.extend_input_len
|
||||||
|
|
||||||
|
# TODO: support batching inputs
|
||||||
|
assert len(batch.extend_lens) == 1
|
||||||
|
batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
|
||||||
|
|
||||||
|
def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
|
||||||
|
batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
|
||||||
|
accept_length_cpu = batch.spec_info.accept_length_cpu
|
||||||
|
batch.extend_lens = [x + 1 for x in accept_length_cpu]
|
||||||
|
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
|
||||||
|
seq_lens_cpu = batch.seq_lens.tolist()
|
||||||
|
|
||||||
|
pt = 0
|
||||||
|
i = 0
|
||||||
|
for req in batch.reqs:
|
||||||
|
if req.finished():
|
||||||
|
continue
|
||||||
|
# assert seq_len - pre_len == req.extend_input_len
|
||||||
|
input_len = batch.extend_lens[i]
|
||||||
|
seq_len = seq_lens_cpu[i]
|
||||||
|
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
|
||||||
|
seq_len - input_len : seq_len
|
||||||
|
] = batch.out_cache_loc[pt : pt + input_len]
|
||||||
|
pt += input_len
|
||||||
|
i += 1
|
||||||
|
assert pt == batch.out_cache_loc.shape[0]
|
||||||
|
|
||||||
|
self.positions = torch.empty_like(self.verified_id)
|
||||||
|
new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
|
||||||
|
self.accept_length.add_(1)
|
||||||
|
|
||||||
|
create_extend_spec_info[(self.accept_length.numel(),)](
|
||||||
|
self.verified_id,
|
||||||
|
batch.seq_lens,
|
||||||
|
self.accept_length,
|
||||||
|
torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
|
||||||
|
self.positions,
|
||||||
|
new_verified_id,
|
||||||
|
triton.next_power_of_2(speculative_num_steps + 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
batch.seq_lens_sum = sum(seq_lens_cpu)
|
||||||
|
batch.input_ids = self.verified_id
|
||||||
|
self.verified_id = new_verified_id
|
||||||
|
|
||||||
|
def generate_attn_arg_prefill(
|
||||||
|
self,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
paged_kernel_lens: torch.Tensor,
|
||||||
|
req_to_token: torch.Tensor,
|
||||||
|
):
|
||||||
|
bs = self.accept_length.numel()
|
||||||
|
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
||||||
|
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
|
||||||
|
|
||||||
|
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
||||||
|
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
|
kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
|
req_to_token,
|
||||||
|
req_pool_indices,
|
||||||
|
paged_kernel_lens,
|
||||||
|
cum_kv_seq_len,
|
||||||
|
None,
|
||||||
|
kv_indices,
|
||||||
|
req_to_token.size(1),
|
||||||
|
)
|
||||||
|
|
||||||
|
return kv_indices, cum_kv_seq_len, qo_indptr, None
|
||||||
|
|
||||||
|
def filter_batch(self, new_indices: torch.Tensor):
|
||||||
|
self.topk_p = self.topk_p[: len(new_indices)]
|
||||||
|
self.topk_index = self.topk_index[: len(new_indices)]
|
||||||
|
self.hidden_states = self.hidden_states[: len(new_indices)]
|
||||||
|
self.verified_id = self.verified_id[: len(new_indices)]
|
||||||
|
|
||||||
|
def merge_batch(self, spec_info: EagleDraftInput):
|
||||||
|
if self.hidden_states is None:
|
||||||
|
self.hidden_states = spec_info.hidden_states
|
||||||
|
self.verified_id = spec_info.verified_id
|
||||||
|
self.topk_p = spec_info.topk_p
|
||||||
|
self.topk_index = spec_info.topk_index
|
||||||
|
return
|
||||||
|
if spec_info.hidden_states is None:
|
||||||
|
return
|
||||||
|
self.hidden_states = torch.cat(
|
||||||
|
[self.hidden_states, spec_info.hidden_states], axis=0
|
||||||
|
)
|
||||||
|
self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
|
||||||
|
self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
|
||||||
|
self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class EagleVerifyInput:
|
||||||
|
draft_token: torch.Tensor
|
||||||
|
custom_mask: torch.Tensor
|
||||||
|
positions: torch.Tensor
|
||||||
|
retrive_index: torch.Tensor
|
||||||
|
retrive_cum_len: torch.Tensor
|
||||||
|
draft_token_num: int
|
||||||
|
capture_hidden_mode: CaptureHiddenMode
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
verified_id: torch.Tensor,
|
||||||
|
score_list: List[torch.Tensor],
|
||||||
|
token_list: List[torch.Tensor],
|
||||||
|
parents_list: List[torch.Tensor],
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
seq_lens_sum: int,
|
||||||
|
topk: int,
|
||||||
|
spec_steps: int,
|
||||||
|
num_verify_token: int,
|
||||||
|
):
|
||||||
|
score_list = torch.cat(score_list, dim=1).flatten(
|
||||||
|
1
|
||||||
|
) # b, n, topk; n= 1 + (num_steps-1) * self.topk
|
||||||
|
ss_token_list = torch.cat(
|
||||||
|
token_list, dim=1
|
||||||
|
) # b, (self.topk + (num_steps-1) * self.topk)
|
||||||
|
top_scores = torch.topk(score_list, num_verify_token - 1, dim=-1)
|
||||||
|
top_scores_index = top_scores.indices
|
||||||
|
top_scores_index = torch.sort(top_scores_index).values
|
||||||
|
draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
|
||||||
|
draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1)
|
||||||
|
parent_list = torch.cat(parents_list[:-1], dim=1)
|
||||||
|
tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel(
|
||||||
|
parent_list,
|
||||||
|
top_scores_index,
|
||||||
|
seq_lens,
|
||||||
|
seq_lens_sum,
|
||||||
|
topk,
|
||||||
|
spec_steps,
|
||||||
|
num_verify_token,
|
||||||
|
)
|
||||||
|
return cls(
|
||||||
|
draft_tokens.flatten(),
|
||||||
|
tree_mask,
|
||||||
|
position,
|
||||||
|
retrive_index,
|
||||||
|
retrive_cum_len,
|
||||||
|
num_verify_token,
|
||||||
|
CaptureHiddenMode.FULL,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_for_verify(self, batch: ScheduleBatch):
|
||||||
|
batch.input_ids = self.draft_token
|
||||||
|
batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
|
||||||
|
bs = batch.batch_size()
|
||||||
|
assign_req_to_token_pool[(bs,)](
|
||||||
|
batch.req_pool_indices,
|
||||||
|
batch.req_to_token_pool.req_to_token,
|
||||||
|
batch.seq_lens,
|
||||||
|
batch.seq_lens + self.draft_token_num,
|
||||||
|
batch.out_cache_loc,
|
||||||
|
batch.req_to_token_pool.req_to_token.shape[1],
|
||||||
|
triton.next_power_of_2(bs),
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_attn_arg_prefill(
|
||||||
|
self,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
paged_kernel_lens: torch.Tensor,
|
||||||
|
req_to_token: torch.Tensor,
|
||||||
|
):
|
||||||
|
batch_size = len(req_pool_indices)
|
||||||
|
qo_indptr = torch.arange(
|
||||||
|
0,
|
||||||
|
(1 + batch_size) * self.draft_token_num,
|
||||||
|
step=self.draft_token_num,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
|
||||||
|
cum_kv_seq_len = torch.zeros(
|
||||||
|
(batch_size + 1,), dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
|
||||||
|
paged_kernel_lens = paged_kernel_lens + self.draft_token_num
|
||||||
|
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
|
|
||||||
|
kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
create_flashinfer_kv_indices_triton[(batch_size,)](
|
||||||
|
req_to_token,
|
||||||
|
req_pool_indices,
|
||||||
|
paged_kernel_lens,
|
||||||
|
cum_kv_seq_len,
|
||||||
|
None,
|
||||||
|
kv_indices,
|
||||||
|
req_to_token.size(1),
|
||||||
|
)
|
||||||
|
return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask
|
||||||
|
|
||||||
|
def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
|
||||||
|
predict = torch.argmax(logits_output.next_token_logits, dim=-1)
|
||||||
|
predict = torch.cat(
|
||||||
|
[predict, torch.full([1], -1, dtype=torch.long, device="cuda")], dim=-1
|
||||||
|
)
|
||||||
|
draft_token = torch.cat(
|
||||||
|
[self.draft_token, torch.full([1], -1, dtype=torch.long, device="cuda")],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
target_predict = predict[self.retrive_index]
|
||||||
|
candidates = draft_token[self.retrive_index]
|
||||||
|
# logits = logits_output.next_token_logits[self.retrive_index]
|
||||||
|
# target_predict = torch.argmax(logits[:, :-1], dim=-1)
|
||||||
|
accept_mask = candidates[:, 1:] == target_predict[:, :-1]
|
||||||
|
accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
|
||||||
|
bs = self.retrive_cum_len.numel() - 1
|
||||||
|
|
||||||
|
max_draft_len = self.retrive_index.shape[-1]
|
||||||
|
accept_index = torch.full(
|
||||||
|
(bs, max_draft_len), -1, dtype=torch.long, device="cuda"
|
||||||
|
)
|
||||||
|
accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
|
||||||
|
extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")
|
||||||
|
eagle_verify_retrive[(bs,)](
|
||||||
|
self.retrive_index.contiguous(),
|
||||||
|
accept_mask.contiguous(),
|
||||||
|
self.retrive_cum_len,
|
||||||
|
accept_index,
|
||||||
|
accept_length,
|
||||||
|
extract_index,
|
||||||
|
max_draft_len,
|
||||||
|
self.draft_token_num,
|
||||||
|
triton.next_power_of_2(max_draft_len),
|
||||||
|
)
|
||||||
|
|
||||||
|
new_accept_index = []
|
||||||
|
unfinished_index = []
|
||||||
|
finished_extend_len = {} # {rid:accept_length + 1}
|
||||||
|
accept_index_cpu = accept_index.tolist()
|
||||||
|
predict_cpu = predict.tolist()
|
||||||
|
has_finished = False
|
||||||
|
|
||||||
|
# iterate every accepted token and check if req has finished after append the token
|
||||||
|
# should be checked BEFORE free kv cache slots
|
||||||
|
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
|
||||||
|
new_accept_index_ = []
|
||||||
|
for j, idx in enumerate(accept_index_row):
|
||||||
|
if idx == -1:
|
||||||
|
break
|
||||||
|
id = predict_cpu[idx]
|
||||||
|
# if not found_finished:
|
||||||
|
req.output_ids.append(id)
|
||||||
|
finished_extend_len[req.rid] = j + 1
|
||||||
|
req.check_finished()
|
||||||
|
if req.finished():
|
||||||
|
has_finished = True
|
||||||
|
# set all tokens after finished token to -1 and break
|
||||||
|
accept_index[i, j + 1 :] = -1
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
new_accept_index_.append(idx)
|
||||||
|
if not req.finished():
|
||||||
|
new_accept_index.extend(new_accept_index_)
|
||||||
|
unfinished_index.append(i)
|
||||||
|
req.spec_verify_ct += 1
|
||||||
|
accept_length = (accept_index != -1).sum(dim=1) - 1
|
||||||
|
|
||||||
|
accept_index = accept_index[accept_index != -1]
|
||||||
|
accept_length_cpu = accept_length.tolist()
|
||||||
|
verified_id = predict[accept_index]
|
||||||
|
|
||||||
|
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
|
||||||
|
evict_mask[accept_index] = False
|
||||||
|
mem_need_free_idx = batch.out_cache_loc[evict_mask]
|
||||||
|
batch.token_to_kv_pool.free(mem_need_free_idx)
|
||||||
|
assign_req_to_token_pool[(bs,)](
|
||||||
|
batch.req_pool_indices,
|
||||||
|
batch.req_to_token_pool.req_to_token,
|
||||||
|
batch.seq_lens,
|
||||||
|
batch.seq_lens + accept_length + 1,
|
||||||
|
batch.out_cache_loc[accept_index],
|
||||||
|
batch.req_to_token_pool.req_to_token.shape[1],
|
||||||
|
triton.next_power_of_2(bs),
|
||||||
|
)
|
||||||
|
batch.seq_lens.add_(accept_length + 1)
|
||||||
|
|
||||||
|
draft_input = EagleDraftInput()
|
||||||
|
if len(new_accept_index) > 0:
|
||||||
|
new_accept_index = torch.tensor(new_accept_index, device="cuda")
|
||||||
|
draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
|
||||||
|
draft_input.verified_id = predict[new_accept_index]
|
||||||
|
draft_input.accept_length = accept_length[unfinished_index]
|
||||||
|
draft_input.accept_length_cpu = [
|
||||||
|
accept_length_cpu[i] for i in unfinished_index
|
||||||
|
]
|
||||||
|
if has_finished:
|
||||||
|
draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
|
||||||
|
else:
|
||||||
|
draft_input.seq_lens_for_draft_extend = batch.seq_lens
|
||||||
|
|
||||||
|
logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
|
||||||
|
return (
|
||||||
|
draft_input,
|
||||||
|
logits_output,
|
||||||
|
verified_id,
|
||||||
|
finished_extend_len,
|
||||||
|
accept_length_cpu,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -136,21 +484,57 @@ def assign_req_to_token_pool(
|
|||||||
load_offset += BLOCK_SIZE
|
load_offset += BLOCK_SIZE
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def assign_draft_cache_locs(
|
||||||
|
req_pool_indices,
|
||||||
|
req_to_token,
|
||||||
|
seq_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
pool_len: tl.constexpr,
|
||||||
|
topk: tl.constexpr,
|
||||||
|
speculative_num_steps: tl.constexpr,
|
||||||
|
):
|
||||||
|
BLOCK_SIZE: tl.constexpr = 32
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
kv_start = tl.load(seq_lens + pid)
|
||||||
|
kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
|
||||||
|
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
|
||||||
|
out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
|
||||||
|
|
||||||
|
num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
|
||||||
|
for i in range(num_loop):
|
||||||
|
save_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + kv_start
|
||||||
|
load_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
||||||
|
mask = save_offset < kv_end
|
||||||
|
data = tl.load(out_cache_ptr + load_offset, mask=mask)
|
||||||
|
tl.store(token_pool + save_offset, data, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def generate_draft_decode_kv_indices(
|
def generate_draft_decode_kv_indices(
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
req_to_token,
|
req_to_token,
|
||||||
paged_kernel_lens,
|
paged_kernel_lens,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
iters: tl.constexpr,
|
kv_indptr,
|
||||||
|
positions,
|
||||||
|
num_seqs: tl.constexpr,
|
||||||
topk: tl.constexpr,
|
topk: tl.constexpr,
|
||||||
pool_len: tl.constexpr,
|
pool_len: tl.constexpr,
|
||||||
|
kv_indices_stride: tl.constexpr,
|
||||||
|
kv_indptr_stride: tl.constexpr,
|
||||||
bs_upper: tl.constexpr,
|
bs_upper: tl.constexpr,
|
||||||
iter_upper: tl.constexpr,
|
iter_upper: tl.constexpr,
|
||||||
|
num_tokens_upper: tl.constexpr,
|
||||||
):
|
):
|
||||||
BLOCK_SIZE: tl.constexpr = 128
|
BLOCK_SIZE: tl.constexpr = 128
|
||||||
bid = tl.program_id(axis=0)
|
iters = tl.program_id(axis=0)
|
||||||
topk_id = tl.program_id(axis=1)
|
bid = tl.program_id(axis=1)
|
||||||
|
topk_id = tl.program_id(axis=2)
|
||||||
|
|
||||||
|
kv_indices += kv_indices_stride * iters
|
||||||
|
kv_indptr += kv_indptr_stride * iters
|
||||||
|
iters += 1
|
||||||
|
|
||||||
load_offset = tl.arange(0, bs_upper)
|
load_offset = tl.arange(0, bs_upper)
|
||||||
seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid)
|
seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid)
|
||||||
@@ -176,473 +560,73 @@ def generate_draft_decode_kv_indices(
|
|||||||
)
|
)
|
||||||
tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
|
tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
|
||||||
|
|
||||||
|
# Update kv_indptr
|
||||||
|
bs_offset = tl.arange(0, num_tokens_upper)
|
||||||
|
|
||||||
class EAGLEDraftInput(SpecInfo):
|
zid = bid * topk + topk_id
|
||||||
def __init__(self):
|
if zid == 0:
|
||||||
self.prev_mode = ForwardMode.DECODE
|
zid = num_seqs * topk
|
||||||
|
positions = tl.load(positions + bs_offset, mask=bs_offset < zid)
|
||||||
|
base = tl.sum(positions)
|
||||||
|
tl.store(kv_indptr + zid, base + zid * iters)
|
||||||
|
|
||||||
self.scores: torch.Tensor = None
|
|
||||||
self.score_list: List[torch.Tensor] = []
|
|
||||||
self.token_list: List[torch.Tensor] = []
|
|
||||||
self.origin_score_list: List[torch.Tensor] = [] # used for sampling
|
|
||||||
self.parents_list: List[torch.Tensor] = []
|
|
||||||
self.cache_list: List[torch.Tenor] = []
|
|
||||||
self.iter = 0
|
|
||||||
|
|
||||||
# shape: (b, hidden_size)
|
@torch.compile
|
||||||
self.hidden_states: torch.Tensor = None
|
def select_top_k_tokens(
|
||||||
# shape: (b,)
|
i: int,
|
||||||
self.verified_id: torch.Tensor = None
|
topk_p: torch.Tensor,
|
||||||
# shape: (b, vocab_size)
|
topk_index: torch.Tensor,
|
||||||
self.sample_output: torch.Tensor = None
|
hidden_states: torch.Tensor,
|
||||||
|
scores: torch.Tensor,
|
||||||
|
topk: int,
|
||||||
|
):
|
||||||
|
if i == 0:
|
||||||
|
# The first step after extend
|
||||||
|
input_ids = topk_index.flatten()
|
||||||
|
hidden_states = hidden_states.repeat_interleave(topk, dim=0)
|
||||||
|
scores = topk_p # shape: (b, topk)
|
||||||
|
|
||||||
self.positions: torch.Tensor = None
|
tree_info = (
|
||||||
self.accept_length: torch.Tensor = None
|
topk_p.unsqueeze(1), # shape: (b, 1, topk)
|
||||||
self.accept_length_cpu: List[int] = None
|
topk_index, # shape: (b, topk)
|
||||||
|
torch.arange(-1, topk, dtype=torch.long, device="cuda")
|
||||||
def load_server_args(self, server_args: ServerArgs):
|
.unsqueeze(0)
|
||||||
self.topk: int = server_args.speculative_eagle_topk
|
.repeat(topk_p.shape[0], 1), # shape: (b, topk + 1)
|
||||||
self.num_verify_token: int = server_args.speculative_num_draft_tokens
|
|
||||||
self.spec_steps = server_args.speculative_num_steps
|
|
||||||
|
|
||||||
def prepare_for_extend(self, batch: ScheduleBatch):
|
|
||||||
req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
|
|
||||||
out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
|
|
||||||
batch.out_cache_loc = out_cache_loc
|
|
||||||
|
|
||||||
pt = 0
|
|
||||||
for i, req in enumerate(batch.reqs):
|
|
||||||
req.req_pool_idx = req_pool_indices[i]
|
|
||||||
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
|
|
||||||
assert seq_len - pre_len == req.extend_input_len
|
|
||||||
|
|
||||||
if pre_len > 0:
|
|
||||||
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
|
|
||||||
:pre_len
|
|
||||||
] = req.prefix_indices
|
|
||||||
|
|
||||||
batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
|
|
||||||
out_cache_loc[pt : pt + req.extend_input_len]
|
|
||||||
)
|
|
||||||
|
|
||||||
pt += req.extend_input_len
|
|
||||||
|
|
||||||
# TODO: support batching inputs
|
|
||||||
assert len(batch.extend_lens) == 1
|
|
||||||
batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
|
|
||||||
|
|
||||||
def filter_batch(
|
|
||||||
self,
|
|
||||||
new_indices: torch.Tensor,
|
|
||||||
):
|
|
||||||
self.sample_output = self.sample_output[: len(new_indices)]
|
|
||||||
self.hidden_states = self.hidden_states[: len(new_indices)]
|
|
||||||
self.verified_id = self.verified_id[: len(new_indices)]
|
|
||||||
|
|
||||||
def prepare_for_decode(self, batch: ScheduleBatch):
|
|
||||||
prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab)
|
|
||||||
top = torch.topk(prob, self.topk, dim=-1)
|
|
||||||
topk_index, topk_p = (
|
|
||||||
top.indices,
|
|
||||||
top.values,
|
|
||||||
) # shape: (b * top_k, top_k) or (b, top_k)
|
|
||||||
|
|
||||||
if self.prev_mode.is_decode():
|
|
||||||
scores = torch.mul(
|
|
||||||
self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk)
|
|
||||||
) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
|
|
||||||
topk_cs = torch.topk(
|
|
||||||
scores.flatten(start_dim=1), self.topk, dim=-1
|
|
||||||
) # (b, topk)
|
|
||||||
topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
|
|
||||||
|
|
||||||
selected_input_index = topk_cs_index.flatten() // self.topk + torch.arange(
|
|
||||||
0, batch.batch_size() * self.topk, step=self.topk, device="cuda"
|
|
||||||
).repeat_interleave(self.topk)
|
|
||||||
|
|
||||||
batch.spec_info.hidden_states = batch.spec_info.hidden_states[
|
|
||||||
selected_input_index, :
|
|
||||||
]
|
|
||||||
|
|
||||||
topk_index = topk_index.reshape(-1, self.topk**2)
|
|
||||||
batch.input_ids = torch.gather(
|
|
||||||
topk_index, index=topk_cs_index, dim=1
|
|
||||||
).flatten()
|
|
||||||
batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
|
|
||||||
|
|
||||||
self.scores = topk_cs_p
|
|
||||||
self.score_list.append(scores) # (b, topk, topk)
|
|
||||||
self.token_list.append(topk_index) # (b, topk * topk)
|
|
||||||
self.origin_score_list.append(topk_p.reshape(topk_index.shape))
|
|
||||||
self.parents_list.append(
|
|
||||||
topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk)
|
|
||||||
) # shape: (b, topk)
|
|
||||||
else:
|
|
||||||
# ForwardMode.EXTEND or ForwardMode.DRAFT_EXTEND
|
|
||||||
batch.spec_info.hidden_states = (
|
|
||||||
batch.spec_info.hidden_states.repeat_interleave(self.topk, dim=0)
|
|
||||||
)
|
|
||||||
|
|
||||||
batch.input_ids = topk_index.flatten()
|
|
||||||
batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel())
|
|
||||||
|
|
||||||
self.scores = topk_p # shape: (b, topk)
|
|
||||||
self.score_list.append(topk_p.unsqueeze(1)) # shape: (b, 1, topk)
|
|
||||||
self.token_list.append(topk_index) # shape: (b, topk)
|
|
||||||
self.origin_score_list.append(topk_p)
|
|
||||||
self.parents_list.append(
|
|
||||||
torch.arange(-1, self.topk, dtype=torch.long, device="cuda")
|
|
||||||
.unsqueeze(0)
|
|
||||||
.repeat(self.scores.shape[0], 1)
|
|
||||||
) # shape: (b, topk + 1)
|
|
||||||
self.cache_list.append(batch.out_cache_loc)
|
|
||||||
self.positions = (
|
|
||||||
batch.seq_lens[:, None]
|
|
||||||
+ torch.full(
|
|
||||||
[1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long
|
|
||||||
)
|
|
||||||
).flatten()
|
|
||||||
|
|
||||||
bs = len(batch.seq_lens)
|
|
||||||
assign_req_to_token_pool[(bs,)](
|
|
||||||
batch.req_pool_indices,
|
|
||||||
batch.req_to_token_pool.req_to_token,
|
|
||||||
batch.seq_lens + self.topk * self.iter,
|
|
||||||
batch.seq_lens + self.topk * (self.iter + 1),
|
|
||||||
batch.out_cache_loc,
|
|
||||||
batch.req_to_token_pool.req_to_token.shape[1],
|
|
||||||
triton.next_power_of_2(bs),
|
|
||||||
)
|
|
||||||
self.iter += 1
|
|
||||||
|
|
||||||
def prepare_extend_after_decode(self, batch: ScheduleBatch):
|
|
||||||
batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
|
|
||||||
accept_length_cpu = batch.spec_info.accept_length_cpu
|
|
||||||
batch.extend_lens = [x + 1 for x in accept_length_cpu]
|
|
||||||
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
|
|
||||||
seq_lens_cpu = batch.seq_lens.tolist()
|
|
||||||
|
|
||||||
pt = 0
|
|
||||||
i = 0
|
|
||||||
for req in batch.reqs:
|
|
||||||
if req.finished():
|
|
||||||
continue
|
|
||||||
# assert seq_len - pre_len == req.extend_input_len
|
|
||||||
input_len = batch.extend_lens[i]
|
|
||||||
seq_len = seq_lens_cpu[i]
|
|
||||||
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
|
|
||||||
seq_len - input_len : seq_len
|
|
||||||
] = batch.out_cache_loc[pt : pt + input_len]
|
|
||||||
pt += input_len
|
|
||||||
i += 1
|
|
||||||
assert pt == batch.out_cache_loc.shape[0]
|
|
||||||
|
|
||||||
self.positions = torch.empty_like(self.verified_id)
|
|
||||||
new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
|
|
||||||
self.accept_length.add_(1)
|
|
||||||
|
|
||||||
create_extend_spec_info[(self.accept_length.numel(),)](
|
|
||||||
self.verified_id,
|
|
||||||
batch.seq_lens,
|
|
||||||
self.accept_length,
|
|
||||||
torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
|
|
||||||
self.positions,
|
|
||||||
new_verified_id,
|
|
||||||
triton.next_power_of_2(self.spec_steps + 1),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
batch.seq_lens_sum = sum(seq_lens_cpu)
|
else:
|
||||||
batch.input_ids = self.verified_id
|
# The later decode steps
|
||||||
self.verified_id = new_verified_id
|
expand_scores = torch.mul(
|
||||||
|
scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
|
||||||
|
) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
|
||||||
|
|
||||||
def prepare_for_verify(self, batch: ScheduleBatch):
|
topk_cs_p, topk_cs_index = fast_topk(
|
||||||
score_list = torch.cat(self.score_list, dim=1).flatten(
|
expand_scores.flatten(start_dim=1), topk, dim=-1
|
||||||
1
|
) # (b, topk)
|
||||||
) # b, n, topk; n= 1+(self.iter-1)*self.topk
|
scores = topk_cs_p # shape: (b, topk)
|
||||||
ss_token_list = torch.cat(
|
|
||||||
self.token_list, dim=1
|
|
||||||
) # b, (self.topk+(self.iter-1)*self.topk)
|
|
||||||
origin_token_list = torch.cat(self.origin_score_list, dim=1)
|
|
||||||
top_scores = torch.topk(score_list, self.num_verify_token - 1, dim=-1)
|
|
||||||
top_scores_index = top_scores.indices
|
|
||||||
top_scores_index = torch.sort(top_scores_index).values
|
|
||||||
|
|
||||||
draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
|
topk_index = topk_index.reshape(-1, topk**2)
|
||||||
scores = torch.gather(origin_token_list, index=top_scores_index, dim=1)
|
input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
|
||||||
draft_tokens = torch.cat((self.verified_id.unsqueeze(1), draft_tokens), dim=1)
|
|
||||||
parent_list = torch.cat(self.parents_list[:-1], dim=1)
|
|
||||||
|
|
||||||
tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel(
|
selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
|
||||||
parent_list,
|
0, hidden_states.shape[0], step=topk, device="cuda"
|
||||||
top_scores_index,
|
).repeat_interleave(topk)
|
||||||
batch.seq_lens,
|
hidden_states = hidden_states[selected_input_index, :]
|
||||||
self.topk,
|
|
||||||
self.iter - 1,
|
tree_info = (
|
||||||
self.num_verify_token,
|
expand_scores, # shape: (b, topk, topk)
|
||||||
|
topk_index, # shape: (b, topk * topk)
|
||||||
|
topk_cs_index + (topk**2 * (i - 1) + topk), # shape: (b, topk)
|
||||||
)
|
)
|
||||||
|
|
||||||
return EagleVerifyInput(
|
return input_ids, hidden_states, scores, tree_info
|
||||||
draft_tokens.flatten(),
|
|
||||||
scores.flatten(),
|
|
||||||
tree_mask,
|
|
||||||
position,
|
|
||||||
retrive_index,
|
|
||||||
retrive_cum_len,
|
|
||||||
self.num_verify_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
def generate_attn_arg_decode(
|
|
||||||
self,
|
|
||||||
req_pool_indices: torch.Tensor,
|
|
||||||
paged_kernel_lens: torch.Tensor,
|
|
||||||
req_to_token: torch.Tensor,
|
|
||||||
):
|
|
||||||
seq_num = req_pool_indices.numel()
|
|
||||||
bs = self.topk * req_pool_indices.numel()
|
|
||||||
seq_len = self.positions.reshape(-1).contiguous()
|
|
||||||
|
|
||||||
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
|
||||||
cum_kv_seq_len[1:] = torch.cumsum(seq_len + 1, dim=0)
|
|
||||||
total_len = torch.sum(paged_kernel_lens).item()
|
|
||||||
|
|
||||||
kv_indices = torch.empty(
|
|
||||||
(total_len * self.topk + seq_num * self.iter * self.topk,),
|
|
||||||
dtype=torch.int32,
|
|
||||||
device="cuda",
|
|
||||||
)
|
|
||||||
|
|
||||||
generate_draft_decode_kv_indices[(req_pool_indices.numel(), self.topk)](
|
|
||||||
req_pool_indices,
|
|
||||||
req_to_token,
|
|
||||||
paged_kernel_lens,
|
|
||||||
kv_indices,
|
|
||||||
self.iter,
|
|
||||||
self.topk,
|
|
||||||
req_to_token.shape[1],
|
|
||||||
triton.next_power_of_2(seq_num),
|
|
||||||
triton.next_power_of_2(self.spec_steps),
|
|
||||||
)
|
|
||||||
return bs, kv_indices, cum_kv_seq_len
|
|
||||||
|
|
||||||
def clear_draft_cache(self, batch):
|
|
||||||
draft_cache = torch.cat(self.cache_list, dim=0)
|
|
||||||
batch.token_to_kv_pool.free(draft_cache)
|
|
||||||
|
|
||||||
def generate_attn_arg_prefill(
|
|
||||||
self,
|
|
||||||
req_pool_indices: torch.Tensor,
|
|
||||||
paged_kernel_lens: torch.Tensor,
|
|
||||||
req_to_token: torch.Tensor,
|
|
||||||
):
|
|
||||||
bs = self.accept_length.numel()
|
|
||||||
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
|
||||||
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
|
|
||||||
|
|
||||||
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
|
||||||
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
|
||||||
kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
|
|
||||||
|
|
||||||
create_flashinfer_kv_indices_triton[(bs,)](
|
|
||||||
req_to_token,
|
|
||||||
req_pool_indices,
|
|
||||||
paged_kernel_lens,
|
|
||||||
cum_kv_seq_len,
|
|
||||||
None,
|
|
||||||
kv_indices,
|
|
||||||
req_to_token.size(1),
|
|
||||||
)
|
|
||||||
|
|
||||||
return kv_indices, cum_kv_seq_len, qo_indptr, None
|
|
||||||
|
|
||||||
def merge_batch(self, spec_info: EAGLEDraftInput):
|
|
||||||
if self.hidden_states is None:
|
|
||||||
self.hidden_states = spec_info.hidden_states
|
|
||||||
self.verified_id = spec_info.verified_id
|
|
||||||
self.sample_output = spec_info.sample_output
|
|
||||||
self.prev_mode = spec_info.prev_mode
|
|
||||||
return
|
|
||||||
if spec_info.hidden_states is None:
|
|
||||||
return
|
|
||||||
self.hidden_states = torch.cat(
|
|
||||||
[self.hidden_states, spec_info.hidden_states], axis=0
|
|
||||||
)
|
|
||||||
self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
|
|
||||||
self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])
|
|
||||||
|
|
||||||
|
|
||||||
class EagleVerifyInput(SpecInfo):
|
def fast_topk(values, topk, dim):
|
||||||
def __init__(
|
if topk == 1:
|
||||||
self,
|
# Use max along the specified dimension to get both value and index
|
||||||
draft_token: torch.Tensor,
|
max_value, max_index = torch.max(values, dim=dim)
|
||||||
draft_score: torch.Tensor,
|
return max_value.unsqueeze(1), max_index.unsqueeze(1)
|
||||||
tree_mask: torch.Tensor,
|
else:
|
||||||
positions: torch.Tensor,
|
# Use topk for efficiency with larger k values
|
||||||
retrive_index: torch.Tensor,
|
return torch.topk(values, topk, dim=dim)
|
||||||
retrive_cum_len: torch.Tensor,
|
|
||||||
draft_token_num: int,
|
|
||||||
):
|
|
||||||
self.draft_token = draft_token
|
|
||||||
self.draft_score = draft_score
|
|
||||||
self.custom_mask = tree_mask
|
|
||||||
self.positions = positions
|
|
||||||
self.retrive_index = retrive_index
|
|
||||||
self.retrive_cum_len = retrive_cum_len
|
|
||||||
self.draft_token_num = draft_token_num
|
|
||||||
|
|
||||||
def prepare_for_verify(self, batch: ScheduleBatch):
|
|
||||||
batch.input_ids = self.draft_token
|
|
||||||
batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
|
|
||||||
bs = batch.seq_lens.numel()
|
|
||||||
assign_req_to_token_pool[(bs,)](
|
|
||||||
batch.req_pool_indices,
|
|
||||||
batch.req_to_token_pool.req_to_token,
|
|
||||||
batch.seq_lens,
|
|
||||||
batch.seq_lens + self.draft_token_num,
|
|
||||||
batch.out_cache_loc,
|
|
||||||
batch.req_to_token_pool.req_to_token.shape[1],
|
|
||||||
triton.next_power_of_2(bs),
|
|
||||||
)
|
|
||||||
|
|
||||||
def generate_attn_arg_prefill(
|
|
||||||
self,
|
|
||||||
req_pool_indices: torch.Tensor,
|
|
||||||
paged_kernel_lens: torch.Tensor,
|
|
||||||
req_to_token: torch.Tensor,
|
|
||||||
):
|
|
||||||
batch_size = len(req_pool_indices)
|
|
||||||
qo_indptr = torch.arange(
|
|
||||||
0,
|
|
||||||
(1 + batch_size) * self.draft_token_num,
|
|
||||||
step=self.draft_token_num,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device="cuda",
|
|
||||||
)
|
|
||||||
|
|
||||||
cum_kv_seq_len = torch.zeros(
|
|
||||||
(batch_size + 1,), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
|
|
||||||
paged_kernel_lens = paged_kernel_lens + self.draft_token_num
|
|
||||||
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
|
||||||
|
|
||||||
kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
|
|
||||||
|
|
||||||
create_flashinfer_kv_indices_triton[(batch_size,)](
|
|
||||||
req_to_token,
|
|
||||||
req_pool_indices,
|
|
||||||
paged_kernel_lens,
|
|
||||||
cum_kv_seq_len,
|
|
||||||
None,
|
|
||||||
kv_indices,
|
|
||||||
req_to_token.size(1),
|
|
||||||
)
|
|
||||||
return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask
|
|
||||||
|
|
||||||
def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
|
|
||||||
predict = torch.argmax(logits_output.next_token_logits, dim=-1)
|
|
||||||
predict = torch.cat(
|
|
||||||
[predict, torch.full([1], -1, dtype=torch.long, device="cuda")], dim=-1
|
|
||||||
)
|
|
||||||
draft_token = torch.cat(
|
|
||||||
[self.draft_token, torch.full([1], -1, dtype=torch.long, device="cuda")],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
target_predict = predict[self.retrive_index]
|
|
||||||
candidates = draft_token[self.retrive_index]
|
|
||||||
# logits = logits_output.next_token_logits[self.retrive_index]
|
|
||||||
# target_predict = torch.argmax(logits[:, :-1], dim=-1)
|
|
||||||
accept_mask = candidates[:, 1:] == target_predict[:, :-1]
|
|
||||||
accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
|
|
||||||
bs = self.retrive_cum_len.numel() - 1
|
|
||||||
|
|
||||||
max_draft_len = self.retrive_index.shape[-1]
|
|
||||||
accept_index = torch.full(
|
|
||||||
(bs, max_draft_len), -1, dtype=torch.long, device="cuda"
|
|
||||||
)
|
|
||||||
accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
|
|
||||||
extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")
|
|
||||||
eagle_verify_retrive[(bs,)](
|
|
||||||
self.retrive_index.contiguous(),
|
|
||||||
accept_mask.contiguous(),
|
|
||||||
self.retrive_cum_len,
|
|
||||||
accept_index,
|
|
||||||
accept_length,
|
|
||||||
extract_index,
|
|
||||||
max_draft_len,
|
|
||||||
self.draft_token_num,
|
|
||||||
triton.next_power_of_2(max_draft_len),
|
|
||||||
)
|
|
||||||
|
|
||||||
draft_input = EAGLEDraftInput()
|
|
||||||
new_accept_index = []
|
|
||||||
unfinished_index = []
|
|
||||||
finished_extend_len = {} # {rid:accept_length + 1}
|
|
||||||
accept_index_cpu = accept_index.tolist()
|
|
||||||
predict_cpu = predict.tolist()
|
|
||||||
has_finished = False
|
|
||||||
|
|
||||||
# iterate every accepted token and check if req has finished after append the token
|
|
||||||
# should be checked BEFORE free kv cache slots
|
|
||||||
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
|
|
||||||
new_accept_index_ = []
|
|
||||||
for j, idx in enumerate(accept_index_row):
|
|
||||||
if idx == -1:
|
|
||||||
break
|
|
||||||
id = predict_cpu[idx]
|
|
||||||
# if not found_finished:
|
|
||||||
req.output_ids.append(id)
|
|
||||||
finished_extend_len[req.rid] = j + 1
|
|
||||||
req.check_finished()
|
|
||||||
if req.finished():
|
|
||||||
has_finished = True
|
|
||||||
# set all tokens after finished token to -1 and break
|
|
||||||
accept_index[i, j + 1 :] = -1
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
new_accept_index_.append(idx)
|
|
||||||
if not req.finished():
|
|
||||||
new_accept_index.extend(new_accept_index_)
|
|
||||||
unfinished_index.append(i)
|
|
||||||
req.spec_verify_ct += 1
|
|
||||||
accept_length = (accept_index != -1).sum(dim=1) - 1
|
|
||||||
|
|
||||||
accept_index = accept_index[accept_index != -1]
|
|
||||||
accept_length_cpu = accept_length.tolist()
|
|
||||||
verified_id = predict[accept_index]
|
|
||||||
|
|
||||||
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
|
|
||||||
evict_mask[accept_index] = False
|
|
||||||
mem_need_free_idx = batch.out_cache_loc[evict_mask]
|
|
||||||
batch.token_to_kv_pool.free(mem_need_free_idx)
|
|
||||||
assign_req_to_token_pool[(bs,)](
|
|
||||||
batch.req_pool_indices,
|
|
||||||
batch.req_to_token_pool.req_to_token,
|
|
||||||
batch.seq_lens,
|
|
||||||
batch.seq_lens + accept_length + 1,
|
|
||||||
batch.out_cache_loc[accept_index],
|
|
||||||
batch.req_to_token_pool.req_to_token.shape[1],
|
|
||||||
triton.next_power_of_2(bs),
|
|
||||||
)
|
|
||||||
batch.seq_lens.add_(accept_length + 1)
|
|
||||||
|
|
||||||
if len(new_accept_index) > 0:
|
|
||||||
new_accept_index = torch.tensor(new_accept_index, device="cuda")
|
|
||||||
draft_input.verified_id = predict[new_accept_index]
|
|
||||||
draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
|
|
||||||
draft_input.accept_length = accept_length[unfinished_index]
|
|
||||||
draft_input.accept_length_cpu = [
|
|
||||||
accept_length_cpu[i] for i in unfinished_index
|
|
||||||
]
|
|
||||||
if has_finished:
|
|
||||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
|
|
||||||
else:
|
|
||||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens
|
|
||||||
|
|
||||||
logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
|
|
||||||
return (
|
|
||||||
draft_input,
|
|
||||||
logits_output,
|
|
||||||
verified_id,
|
|
||||||
finished_extend_len,
|
|
||||||
accept_length_cpu,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import logging
|
||||||
|
import time
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -12,8 +14,18 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.speculative.eagle_utils import EAGLEDraftInput
|
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
|
||||||
from sglang.srt.utils import rank0_print
|
EAGLEDraftCudaGraphRunner,
|
||||||
|
)
|
||||||
|
from sglang.srt.speculative.eagle_utils import (
|
||||||
|
EagleDraftInput,
|
||||||
|
EagleVerifyInput,
|
||||||
|
assign_draft_cache_locs,
|
||||||
|
fast_topk,
|
||||||
|
select_top_k_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EAGLEWorker(TpModelWorker):
|
class EAGLEWorker(TpModelWorker):
|
||||||
@@ -40,41 +52,47 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
is_draft_worker=True,
|
is_draft_worker=True,
|
||||||
)
|
)
|
||||||
self.target_worker = target_worker
|
self.target_worker = target_worker
|
||||||
self.server_args = server_args
|
|
||||||
self.finish_extend_len = []
|
self.finish_extend_len = []
|
||||||
|
|
||||||
|
# Parse arguments
|
||||||
|
self.topk = server_args.speculative_eagle_topk
|
||||||
|
self.speculative_num_steps = server_args.speculative_num_steps
|
||||||
|
self.server_args = server_args
|
||||||
|
|
||||||
# Share the embedding and lm_head
|
# Share the embedding and lm_head
|
||||||
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
||||||
self.model_runner.model.set_embed_and_head(embed, head)
|
self.model_runner.model.set_embed_and_head(embed, head)
|
||||||
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
|
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
|
||||||
self.model_runner.init_cuda_graphs()
|
|
||||||
|
|
||||||
def forward_draft_decode(self, batch: ScheduleBatch):
|
# Create multi-step attn backends and cuda graph runners
|
||||||
batch.spec_info.prepare_for_decode(batch)
|
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
FlashInferMultiStepDraftBackend,
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
)
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
|
||||||
logits_output = self.model_runner.forward(forward_batch)
|
|
||||||
self.capture_for_decode(logits_output, forward_batch)
|
|
||||||
|
|
||||||
def forward_draft_extend(self, batch: ScheduleBatch):
|
self.draft_attn_backend = FlashInferMultiStepDraftBackend(
|
||||||
self._set_mem_pool(batch, self.model_runner)
|
self.model_runner,
|
||||||
batch.spec_info.prepare_for_extend(batch)
|
self.topk,
|
||||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
self.speculative_num_steps,
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
)
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
self.model_runner.draft_attn_backend = self.draft_attn_backend
|
||||||
logits_output = self.model_runner.forward(forward_batch)
|
self.init_cuda_graphs()
|
||||||
self.capture_for_decode(logits_output, forward_batch)
|
|
||||||
self._set_mem_pool(batch, self.target_worker.model_runner)
|
def init_cuda_graphs(self):
|
||||||
|
"""Capture cuda graphs."""
|
||||||
|
self.cuda_graph_runner = None
|
||||||
|
|
||||||
|
if self.server_args.disable_cuda_graph:
|
||||||
|
return
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
logger.info("Capture cuda graph begin. This can take up to several minutes.")
|
||||||
|
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
|
||||||
|
logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
|
||||||
|
|
||||||
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
|
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
|
||||||
if batch.forward_mode.is_decode():
|
if batch.forward_mode.is_decode():
|
||||||
# Draft
|
# Draft
|
||||||
self._set_mem_pool(batch, self.model_runner)
|
spec_info: EagleVerifyInput = self.draft(batch)
|
||||||
for i in range(self.server_args.speculative_num_steps):
|
|
||||||
self.forward_draft_decode(batch)
|
|
||||||
batch.spec_info.clear_draft_cache(batch)
|
|
||||||
self._set_mem_pool(batch, self.target_worker.model_runner)
|
|
||||||
|
|
||||||
# Verify
|
# Verify
|
||||||
(
|
(
|
||||||
@@ -84,8 +102,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
self.finish_extend_len,
|
self.finish_extend_len,
|
||||||
accept_length_cpu,
|
accept_length_cpu,
|
||||||
model_worker_batch,
|
model_worker_batch,
|
||||||
) = self.verify(batch)
|
) = self.verify(batch, spec_info)
|
||||||
next_draft_input.load_server_args(self.server_args)
|
|
||||||
batch.spec_info = next_draft_input
|
batch.spec_info = next_draft_input
|
||||||
# if it is None, means all requsets are finished
|
# if it is None, means all requsets are finished
|
||||||
if batch.spec_info.verified_id is not None:
|
if batch.spec_info.verified_id is not None:
|
||||||
@@ -107,29 +124,145 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Forward with the draft model.
|
# Forward with the draft model.
|
||||||
spec_info = EAGLEDraftInput()
|
batch.spec_info = EagleDraftInput(
|
||||||
spec_info.load_server_args(self.server_args)
|
hidden_states=logits_output.hidden_states,
|
||||||
spec_info.hidden_states = logits_output.hidden_states
|
verified_id=next_token_ids,
|
||||||
spec_info.verified_id = next_token_ids
|
)
|
||||||
batch.spec_info = spec_info
|
|
||||||
self.forward_draft_extend(batch)
|
self.forward_draft_extend(batch)
|
||||||
return logits_output, next_token_ids, model_worker_batch, 0
|
return logits_output, next_token_ids, model_worker_batch, 0
|
||||||
|
|
||||||
def verify(self, batch: ScheduleBatch):
|
def draft(self, batch: ScheduleBatch):
|
||||||
verify_input = batch.spec_info.prepare_for_verify(batch)
|
self._set_mem_pool(batch, self.model_runner)
|
||||||
verify_input.prepare_for_verify(batch)
|
|
||||||
|
# Parse args
|
||||||
|
num_seqs = batch.batch_size()
|
||||||
|
spec_info = batch.spec_info
|
||||||
|
|
||||||
|
# Allocate cache locations
|
||||||
|
out_cache_loc = batch.alloc_token_slots(
|
||||||
|
num_seqs * self.topk * self.speculative_num_steps
|
||||||
|
)
|
||||||
|
assign_draft_cache_locs[(num_seqs,)](
|
||||||
|
batch.req_pool_indices,
|
||||||
|
batch.req_to_token_pool.req_to_token,
|
||||||
|
batch.seq_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
batch.req_to_token_pool.req_to_token.shape[1],
|
||||||
|
self.topk,
|
||||||
|
self.speculative_num_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
batch.out_cache_loc = out_cache_loc
|
||||||
|
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
|
||||||
|
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
|
||||||
|
|
||||||
|
# Get forward batch
|
||||||
|
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||||
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
|
can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
|
||||||
|
forward_batch
|
||||||
|
)
|
||||||
|
|
||||||
|
if can_cuda_graph:
|
||||||
|
score_list, token_list, parents_list = self.cuda_graph_runner.replay(
|
||||||
|
forward_batch
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Initialize attention backend
|
||||||
|
self.draft_attn_backend.init_forward_metadata(forward_batch)
|
||||||
|
|
||||||
|
# Run forward steps
|
||||||
|
score_list, token_list, parents_list = self.draft_forward(forward_batch)
|
||||||
|
|
||||||
|
ret = EagleVerifyInput.create(
|
||||||
|
spec_info.verified_id,
|
||||||
|
score_list,
|
||||||
|
token_list,
|
||||||
|
parents_list,
|
||||||
|
batch.seq_lens,
|
||||||
|
batch.seq_lens_sum,
|
||||||
|
self.topk,
|
||||||
|
self.speculative_num_steps,
|
||||||
|
self.server_args.speculative_num_draft_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Free cache locations
|
||||||
|
batch.token_to_kv_pool.free(out_cache_loc)
|
||||||
|
self._set_mem_pool(batch, self.target_worker.model_runner)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def draft_forward(self, forward_batch: ForwardBatch):
|
||||||
|
# Parse args
|
||||||
|
spec_info = forward_batch.spec_info
|
||||||
|
out_cache_loc = forward_batch.out_cache_loc
|
||||||
|
topk_p, topk_index, hidden_states = (
|
||||||
|
spec_info.topk_p,
|
||||||
|
spec_info.topk_index,
|
||||||
|
spec_info.hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return values
|
||||||
|
score_list: List[torch.Tensor] = []
|
||||||
|
token_list: List[torch.Tensor] = []
|
||||||
|
parents_list: List[torch.Tensor] = []
|
||||||
|
|
||||||
|
# Forward multiple steps
|
||||||
|
scores = None
|
||||||
|
for i in range(self.speculative_num_steps):
|
||||||
|
input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
|
||||||
|
i, topk_p, topk_index, hidden_states, scores, self.topk
|
||||||
|
)
|
||||||
|
score_list.append(tree_info[0])
|
||||||
|
token_list.append(tree_info[1])
|
||||||
|
parents_list.append(tree_info[2])
|
||||||
|
|
||||||
|
# Set inputs
|
||||||
|
forward_batch.input_ids = input_ids
|
||||||
|
forward_batch.out_cache_loc = out_cache_loc[
|
||||||
|
forward_batch.batch_size
|
||||||
|
* self.topk
|
||||||
|
* i : forward_batch.batch_size
|
||||||
|
* self.topk
|
||||||
|
* (i + 1)
|
||||||
|
]
|
||||||
|
forward_batch.positions.add_(1)
|
||||||
|
forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
|
||||||
|
spec_info.hidden_states = hidden_states
|
||||||
|
|
||||||
|
# Run forward
|
||||||
|
logits_output = self.model_runner.model.forward(
|
||||||
|
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||||
|
)
|
||||||
|
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
|
||||||
|
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
|
||||||
|
hidden_states = logits_output.hidden_states
|
||||||
|
|
||||||
|
return score_list, token_list, parents_list
|
||||||
|
|
||||||
|
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
|
||||||
|
spec_info.prepare_for_verify(batch)
|
||||||
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
||||||
batch.spec_info = verify_input
|
batch.spec_info = spec_info
|
||||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
|
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
logits_output, _ = self.target_worker.forward_batch_generation(
|
logits_output, _ = self.target_worker.forward_batch_generation(
|
||||||
model_worker_batch, skip_sample=True
|
model_worker_batch, skip_sample=True
|
||||||
)
|
)
|
||||||
verify_input.hidden_states = logits_output.hidden_states
|
spec_info.hidden_states = logits_output.hidden_states
|
||||||
res = verify_input.verify(batch, logits_output)
|
res = spec_info.verify(batch, logits_output)
|
||||||
batch.forward_mode = ForwardMode.DECODE
|
batch.forward_mode = ForwardMode.DECODE
|
||||||
return res + (model_worker_batch,)
|
return res + (model_worker_batch,)
|
||||||
|
|
||||||
|
def forward_draft_extend(self, batch: ScheduleBatch):
|
||||||
|
self._set_mem_pool(batch, self.model_runner)
|
||||||
|
batch.spec_info.prepare_for_extend(batch)
|
||||||
|
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||||
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
|
logits_output = self.model_runner.forward(forward_batch)
|
||||||
|
self.capture_for_decode(logits_output, forward_batch)
|
||||||
|
self._set_mem_pool(batch, self.target_worker.model_runner)
|
||||||
|
|
||||||
def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
|
def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
|
||||||
batch.token_to_kv_pool = runner.token_to_kv_pool
|
batch.token_to_kv_pool = runner.token_to_kv_pool
|
||||||
batch.req_to_token_pool = runner.req_to_token_pool
|
batch.req_to_token_pool = runner.req_to_token_pool
|
||||||
@@ -139,7 +272,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
|
|
||||||
self._set_mem_pool(batch, self.model_runner)
|
self._set_mem_pool(batch, self.model_runner)
|
||||||
batch.forward_mode = ForwardMode.DRAFT_EXTEND
|
batch.forward_mode = ForwardMode.DRAFT_EXTEND
|
||||||
batch.spec_info.prepare_extend_after_decode(batch)
|
batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
|
||||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
@@ -155,13 +288,10 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
def capture_for_decode(
|
def capture_for_decode(
|
||||||
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
||||||
):
|
):
|
||||||
sample_output = torch.softmax(
|
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
|
||||||
logits_output.next_token_logits, dim=-1
|
|
||||||
) # TODO(kavioyu): Support more sampling methods
|
|
||||||
spec_info = forward_batch.spec_info
|
spec_info = forward_batch.spec_info
|
||||||
spec_info.sample_output = sample_output
|
spec_info.topk_p, spec_info.topk_index = fast_topk(probs, self.topk, dim=-1)
|
||||||
spec_info.hidden_states = logits_output.hidden_states
|
spec_info.hidden_states = logits_output.hidden_states
|
||||||
spec_info.prev_mode = forward_batch.forward_mode
|
|
||||||
|
|
||||||
# Don't support prefix share now.
|
# Don't support prefix share now.
|
||||||
def finish_request(self, reqs: Union[Req, List[Req]]):
|
def finish_request(self, reqs: Union[Req, List[Req]]):
|
||||||
|
|||||||
Reference in New Issue
Block a user