From 013021b6a1c3a95fb9569ff730d047c960c78380 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Mon, 3 Feb 2025 20:52:30 +0800 Subject: [PATCH] refactor EAGLE 2 (#3269) Co-authored-by: Ying Sheng Co-authored-by: merrymercy Co-authored-by: Ying1123 --- .../engine/EAGLE_offline_batch_inference.py | 1 + .../layers/attention/flashinfer_backend.py | 333 ++++++- .../srt/model_executor/cuda_graph_runner.py | 155 ++- .../srt/model_executor/forward_batch_info.py | 117 ++- .../sglang/srt/model_executor/model_runner.py | 3 +- .../srt/speculative/build_eagle_tree.py | 6 +- .../eagle_draft_cuda_graph_runner.py | 213 ++++ python/sglang/srt/speculative/eagle_utils.py | 910 +++++++++--------- python/sglang/srt/speculative/eagle_worker.py | 220 ++++- 9 files changed, 1271 insertions(+), 687 deletions(-) create mode 100644 python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py diff --git a/examples/runtime/engine/EAGLE_offline_batch_inference.py b/examples/runtime/engine/EAGLE_offline_batch_inference.py index 0885959b3..897d50ae2 100644 --- a/examples/runtime/engine/EAGLE_offline_batch_inference.py +++ b/examples/runtime/engine/EAGLE_offline_batch_inference.py @@ -21,6 +21,7 @@ def main(): speculative_num_steps=3, speculative_eagle_topk=4, speculative_num_draft_tokens=16, + cuda_graph_max_bs=8, ) outputs = llm.generate(prompts, sampling_params) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index cc6da781f..863cb031d 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -10,6 +10,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an import os from dataclasses import dataclass from enum import Enum, auto +from functools import partial from typing import TYPE_CHECKING, List, Optional, Union import torch @@ -34,6 +35,7 @@ if is_flashinfer_available(): BatchPrefillWithRaggedKVCacheWrapper, ) from flashinfer.cascade import merge_state + from flashinfer.decode import PosEncodingMode class WrapperDispatch(Enum): @@ -53,10 +55,19 @@ class PrefillMetadata: extend_no_prefix: bool +# Reuse this workspace buffer across all flashinfer wrappers +global_workspace_buffer = None + + class FlashInferAttnBackend(AttentionBackend): """Flashinfer attention kernels.""" - def __init__(self, model_runner: ModelRunner): + def __init__( + self, + model_runner: ModelRunner, + skip_prefill: bool = False, + kv_indptr_buf: Optional[torch.Tensor] = None, + ): super().__init__() # Parse constants @@ -69,6 +80,7 @@ class FlashInferAttnBackend(AttentionBackend): ), ) self.max_context_len = model_runner.model_config.context_len + self.skip_prefill = skip_prefill assert not ( model_runner.sliding_window_size is not None @@ -90,16 +102,26 @@ class FlashInferAttnBackend(AttentionBackend): global_config.flashinfer_workspace_size = 512 * 1024 * 1024 # Allocate buffers - self.workspace_buffer = torch.empty( - global_config.flashinfer_workspace_size, - dtype=torch.uint8, - device=model_runner.device, - ) + global global_workspace_buffer + if global_workspace_buffer is None: + global_workspace_buffer = torch.empty( + global_config.flashinfer_workspace_size, + dtype=torch.uint8, + device=model_runner.device, + ) + self.workspace_buffer = global_workspace_buffer max_bs = model_runner.req_to_token_pool.size - self.kv_indptr = [ - torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device) - for _ in range(self.num_wrappers) - ] + if 
kv_indptr_buf is None: + self.kv_indptr = [ + torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) + for _ in range(self.num_wrappers) + ] + else: + assert self.num_wrappers == 1 + self.kv_indptr = [kv_indptr_buf] + self.kv_last_page_len = torch.ones( (max_bs,), dtype=torch.int32, device=model_runner.device ) @@ -122,12 +144,16 @@ class FlashInferAttnBackend(AttentionBackend): self.prefill_wrappers_verify = [] self.decode_wrappers = [] for _ in range(self.num_wrappers): - self.prefill_wrappers_paged.append( - BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") - ) - self.prefill_wrappers_verify.append( - BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") - ) + if not skip_prefill: + self.prefill_wrappers_paged.append( + BatchPrefillWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + ) + ) + self.prefill_wrappers_verify.append( + BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") + ) self.decode_wrappers.append( BatchDecodeWithPagedKVCacheWrapper( self.workspace_buffer, @@ -137,10 +163,11 @@ class FlashInferAttnBackend(AttentionBackend): ) # Create indices updater + if not skip_prefill: + self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill( + model_runner, self + ) self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self) - self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill( - model_runner, self - ) # Other metadata self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None @@ -211,23 +238,30 @@ class FlashInferAttnBackend(AttentionBackend): self.prefill_wrappers_paged, use_ragged, extend_no_prefix ) - def init_cuda_graph_state(self, max_bs: int): - cuda_graph_kv_indices = torch.zeros( - (max_bs * self.max_context_len,), - dtype=torch.int32, - device="cuda", - ) + def init_cuda_graph_state( + self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None + ): + if kv_indices_buf is None: + cuda_graph_kv_indices = torch.zeros( + (max_bs * self.max_context_len,), + dtype=torch.int32, + device="cuda", + ) + else: + cuda_graph_kv_indices = kv_indices_buf + self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [ cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1) ] - self.cuda_graph_custom_mask = torch.zeros( - (max_bs * self.max_context_len), - dtype=torch.uint8, - device="cuda", - ) - self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr] - self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr] + if not self.skip_prefill: + self.cuda_graph_custom_mask = torch.zeros( + (max_bs * self.max_context_len), + dtype=torch.uint8, + device="cuda", + ) + self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr] + self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr] def init_forward_metadata_capture_cuda_graph( self, @@ -602,11 +636,8 @@ class FlashInferIndicesUpdaterDecode: self.req_to_token.shape[1], ) else: - bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode( - req_pool_indices, - paged_kernel_lens, - self.req_to_token, - ) + kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices + bs = kv_indptr.shape[0] - 1 wrapper.end_forward() wrapper.begin_forward( @@ -854,6 +885,132 @@ class FlashInferIndicesUpdaterPrefill: ) +class FlashInferMultiStepDraftBackend: + """ + Wrap multiple flashinfer attention backends as one for multiple consecutive + draft decoding steps. 
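+    Each draft step owns a FlashInferAttnBackend created with skip_prefill=True. All of
+    them share the global flashinfer workspace buffer, and each one is handed a row of a
+    common (speculative_num_steps, max_bs + 1) kv_indptr buffer, so the decode metadata
+    for every speculative step can be filled by a single launch of
+    generate_draft_decode_kv_indices.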
+ """ + + def __init__( + self, + model_runner: ModelRunner, + topk: int, + speculative_num_steps: int, + ): + from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices + + self.topk = topk + self.speculative_num_steps = speculative_num_steps + self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices + max_bs = model_runner.req_to_token_pool.size + self.kv_indptr = torch.zeros( + ( + self.speculative_num_steps, + max_bs + 1, + ), + dtype=torch.int32, + device=model_runner.device, + ) + self.attn_backends = [] + for i in range(self.speculative_num_steps): + self.attn_backends.append( + FlashInferAttnBackend( + model_runner, + skip_prefill=True, + kv_indptr_buf=self.kv_indptr[i], + ) + ) + self.max_context_len = self.attn_backends[0].max_context_len + # Cached variables for generate_draft_decode_kv_indices + self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1] + self.kv_indptr_stride = self.kv_indptr.shape[1] + + def common_template(self, forward_batch: ForwardBatch, call_fn: int): + num_seqs = forward_batch.batch_size + bs = self.topk * num_seqs + seq_lens_sum = forward_batch.seq_lens_sum + self.generate_draft_decode_kv_indices[ + (self.speculative_num_steps, num_seqs, self.topk) + ]( + forward_batch.req_pool_indices, + forward_batch.req_to_token_pool.req_to_token, + forward_batch.seq_lens, + self.cuda_graph_kv_indices, + self.kv_indptr, + forward_batch.positions, + num_seqs, + self.topk, + self.pool_len, + self.kv_indptr_stride, + self.kv_indptr.shape[1], + triton.next_power_of_2(num_seqs), + triton.next_power_of_2(self.speculative_num_steps), + triton.next_power_of_2(bs), + ) + for i in range(self.speculative_num_steps): + forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1] + forward_batch.spec_info.kv_indices = self.cuda_graph_kv_indices[i][ + : seq_lens_sum * self.topk + bs * (i + 1) + ] + call_fn(i, forward_batch) + + def init_forward_metadata(self, forward_batch: ForwardBatch): + def call_fn(i, forward_batch): + forward_batch.spec_info.kv_indptr = ( + forward_batch.spec_info.kv_indptr.clone() + ) + forward_batch.spec_info.kv_indices = ( + forward_batch.spec_info.kv_indices.clone() + ) + self.attn_backends[i].init_forward_metadata(forward_batch) + + self.common_template(forward_batch, call_fn) + + def init_cuda_graph_state(self, max_bs: int): + self.cuda_graph_kv_indices = torch.zeros( + (self.speculative_num_steps, max_bs * self.max_context_len), + dtype=torch.int32, + device="cuda", + ) + self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1] + for i in range(self.speculative_num_steps): + self.attn_backends[i].init_cuda_graph_state( + max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i] + ) + + def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch): + def call_fn(i, forward_batch): + self.attn_backends[i].init_forward_metadata_capture_cuda_graph( + forward_batch.batch_size, + forward_batch.batch_size * self.topk, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + ) + decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[ + forward_batch.batch_size + ][0] + decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper) + + self.common_template(forward_batch, call_fn) + + def init_forward_metadata_replay_cuda_graph(self, forward_batch): + def call_fn(i, forward_batch): + self.attn_backends[i].init_forward_metadata_replay_cuda_graph( + forward_batch.batch_size, + 
forward_batch.req_pool_indices, + forward_batch.seq_lens, + seq_lens_sum=-1, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + ) + + self.common_template(forward_batch, call_fn) + + @triton.jit def create_flashinfer_kv_indices_triton( req_to_token_ptr, # [max_batch, max_context_len] @@ -937,3 +1094,105 @@ def should_use_tensor_core( return gqa_group_size > 4 else: return False + + +def fast_decode_plan( + self, + indptr: torch.Tensor, + indices: torch.Tensor, + last_page_len: torch.Tensor, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + page_size: int, + pos_encoding_mode: str = "NONE", + window_left: int = -1, + logits_soft_cap: Optional[float] = None, + data_type: Union[str, torch.dtype] = "float16", + q_data_type: Optional[Union[str, torch.dtype]] = None, + sm_scale: Optional[float] = None, + rope_scale: Optional[float] = None, + rope_theta: Optional[float] = None, +) -> None: + """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.""" + batch_size = len(last_page_len) + if logits_soft_cap is None: + logits_soft_cap = 0.0 + if self.is_cuda_graph_enabled: + if batch_size != self._fixed_batch_size: + raise ValueError( + "The batch size should be fixed in cudagraph mode, the runtime batch size {} " + " mismatches the batch size set during initialization {}".format( + batch_size, self._fixed_batch_size + ) + ) + if len(indices) > len(self._paged_kv_indices_buf): + raise ValueError( + "The size of indices should be less than or equal to the allocated buffer" + ) + else: + self._paged_kv_indptr_buf = indptr + self._paged_kv_indices_buf = indices + self._paged_kv_last_page_len_buf = last_page_len + # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info + if not q_data_type: + q_data_type = data_type + if not hasattr(self, "empty_q_data"): + self.empty_q_data = torch.empty( + 0, + dtype=( + getattr(torch, q_data_type) + if isinstance(q_data_type, str) + else q_data_type + ), + ) + self.empty_kv_cache = torch.empty( + 0, + dtype=( + getattr(torch, data_type) if isinstance(data_type, str) else data_type + ), + ) + self.last_page_len = torch.ones(32768, dtype=torch.int32) + empty_q_data = self.empty_q_data + empty_kv_cache = self.empty_kv_cache + if self.use_tensor_cores: + if not self.is_cuda_graph_enabled: + # when not using cudagraph, we need to create the indptr buffer, otherwise + # the buffer is already created during initialization + self._qo_indptr_buf = torch.arange( + batch_size + 1, dtype=torch.int32, device=indptr.device + ) + self._wrapper.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._qo_indptr_buf, + indptr, + batch_size, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + empty_q_data, + ) + else: + self._wrapper.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + indptr, + self.last_page_len, + batch_size, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + PosEncodingMode[pos_encoding_mode].value, + logits_soft_cap, + empty_q_data, + empty_kv_cache, + ) + self._pos_encoding_mode = pos_encoding_mode + self._window_left = window_left + self._logits_soft_cap = logits_soft_cap + self._sm_scale = sm_scale + self._rope_scale = rope_scale + self._rope_theta = rope_theta diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 69615b8ff..1f5e8e851 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ 
b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -103,69 +103,75 @@ def set_torch_compile_config(): torch._dynamo.config.cache_size_limit = 1024 +def get_batch_sizes_to_capture(model_runner: ModelRunner): + server_args = model_runner.server_args + capture_bs = server_args.cuda_graph_bs + if capture_bs is None: + if server_args.disable_cuda_graph_padding: + capture_bs = list(range(1, 33)) + [64, 128] + else: + capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] + if max(capture_bs) > model_runner.req_to_token_pool.size: + # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests + # is very samll. We add more values here to make sure we capture the maximum bs. + capture_bs = list( + sorted( + set( + capture_bs + + [model_runner.req_to_token_pool.size - 1] + + [model_runner.req_to_token_pool.size] + ) + ) + ) + capture_bs = [ + bs + for bs in capture_bs + if bs <= model_runner.req_to_token_pool.size + and bs <= server_args.cuda_graph_max_bs + ] + compile_bs = ( + [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs] + if server_args.enable_torch_compile + else [] + ) + return capture_bs, compile_bs + + +# Reuse this memory pool across all cuda graph runners. +global_graph_memory_pool = None + + +def get_global_graph_memory_pool(): + return global_graph_memory_pool + + +def set_global_graph_memory_pool(val): + global global_graph_memory_pool + global_graph_memory_pool = val + + class CudaGraphRunner: """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile.""" - def __init__(self, model_runner: "ModelRunner"): + def __init__(self, model_runner: ModelRunner): # Parse args self.model_runner = model_runner self.graphs = {} - self.input_buffers = {} self.output_buffers = {} - self.flashinfer_handlers = {} - self.graph_memory_pool = None - self.use_torch_compile = model_runner.server_args.enable_torch_compile + self.enable_torch_compile = model_runner.server_args.enable_torch_compile self.disable_padding = model_runner.server_args.disable_cuda_graph_padding - self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder - self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention - self.tp_size = self.model_runner.tp_size - self.dp_size = self.model_runner.server_args.dp_size + self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder + self.enable_dp_attention = model_runner.server_args.enable_dp_attention + self.tp_size = model_runner.server_args.tp_size + self.dp_size = model_runner.server_args.dp_size # Batch sizes to capture - self.capture_bs = self.model_runner.server_args.cuda_graph_bs - if self.capture_bs is None: - if model_runner.server_args.disable_cuda_graph_padding: - self.capture_bs = list(range(1, 33)) + [64, 128] - else: - self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] - - if max(self.capture_bs) > model_runner.req_to_token_pool.size: - # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests - # is very samll. We add more values here to make sure we capture the maximum bs. 
- self.capture_bs = list( - sorted( - set( - self.capture_bs - + [model_runner.req_to_token_pool.size - 1] - + [model_runner.req_to_token_pool.size] - ) - ) - ) - - self.capture_bs = [ - bs - for bs in self.capture_bs - if bs <= model_runner.req_to_token_pool.size - and bs <= model_runner.server_args.cuda_graph_max_bs - ] - - self.compile_bs = ( - [ - bs - for bs in self.capture_bs - if bs <= self.model_runner.server_args.torch_compile_max_bs - ] - if self.use_torch_compile - else [] - ) - + self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) self.capture_forward_mode = ForwardMode.DECODE self.num_tokens_per_bs = 1 if model_runner.spec_algorithm.is_eagle(): if self.model_runner.is_draft_worker: - self.num_tokens_per_bs = ( - self.model_runner.server_args.speculative_eagle_topk - ) + raise RuntimeError("This should not happen") else: self.capture_forward_mode = ForwardMode.TARGET_VERIFY self.num_tokens_per_bs = ( @@ -182,10 +188,10 @@ class CudaGraphRunner: # FIXME(lsyin): leave it here for now, I don't know whether it is necessary self.encoder_len_fill_value = 0 - if self.use_torch_compile: + if self.enable_torch_compile: set_torch_compile_config() - # Common inputs + # Graph inputs with torch.device("cuda"): self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) @@ -301,7 +307,7 @@ class CudaGraphRunner: stream = self.stream num_tokens = bs * self.num_tokens_per_bs - # Common inputs + # Graph inputs input_ids = self.input_ids[:num_tokens] req_pool_indices = self.req_pool_indices[:bs] seq_lens = self.seq_lens[:bs] @@ -320,7 +326,7 @@ class CudaGraphRunner: global_num_tokens = None gathered_buffer = None - spec_info = self.get_spec_info(num_tokens, positions) + spec_info = self.get_spec_info(num_tokens) forward_batch = ForwardBatch( forward_mode=self.capture_forward_mode, @@ -335,7 +341,6 @@ class CudaGraphRunner: seq_lens_sum=seq_lens.sum(), encoder_lens=encoder_lens, return_logprob=False, - top_logprobs_nums=[0] * bs, positions=positions, global_num_tokens=global_num_tokens, gathered_buffer=gathered_buffer, @@ -375,13 +380,14 @@ class CudaGraphRunner: torch.cuda.synchronize() self.model_runner.tp_group.barrier() - with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream): + global global_graph_memory_pool + with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream): out = run_once() torch.cuda.synchronize() self.model_runner.tp_group.barrier() - self.graph_memory_pool = graph.pool() + global_graph_memory_pool = graph.pool() return graph, out def replay(self, forward_batch: ForwardBatch): @@ -439,35 +445,26 @@ class CudaGraphRunner: ) return logits_output - def get_spec_info(self, num_tokens: int, positions: torch.Tensor): + def get_spec_info(self, num_tokens: int): spec_info = None if self.model_runner.spec_algorithm.is_eagle(): - from sglang.srt.speculative.eagle_utils import ( - EAGLEDraftInput, - EagleVerifyInput, - ) + from sglang.srt.speculative.eagle_utils import EagleVerifyInput if self.model_runner.is_draft_worker: - spec_info = EAGLEDraftInput() - spec_info.load_server_args(self.model_runner.server_args) - spec_info.hidden_states = self.hidden_states[:num_tokens] - spec_info.positions = positions - spec_info.capture_hidden_mode = CaptureHiddenMode.FULL + raise RuntimeError("This should not happen.") else: spec_info = EagleVerifyInput( - None, - None, - None, - None, - None, - None, - 
self.model_runner.server_args.speculative_num_draft_tokens, + draft_token=None, + custom_mask=torch.zeros( + (num_tokens * self.model_runner.model_config.context_len), + dtype=torch.bool, + device="cuda", + ), + positions=None, + retrive_index=None, + retrive_cum_len=None, + draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, + capture_hidden_mode=CaptureHiddenMode.FULL, ) - spec_info.custom_mask = torch.zeros( - (num_tokens * self.model_runner.model_config.context_len), - dtype=torch.bool, - device="cuda", - ) - spec_info.capture_hidden_mode = CaptureHiddenMode.FULL return spec_info diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 8bd105275..b36dedc9f 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -197,64 +197,6 @@ class ForwardBatch: # For Qwen2-VL mrope_positions: torch.Tensor = None - def compute_mrope_positions( - self, model_runner: ModelRunner, batch: ModelWorkerBatch - ): - device = model_runner.device - hf_config = model_runner.model_config.hf_config - mrope_positions_list = [None] * self.seq_lens.shape[0] - if self.forward_mode.is_decode(): - for i, _ in enumerate(mrope_positions_list): - mrope_position_delta = ( - 0 - if batch.image_inputs[i] is None - else batch.image_inputs[i].mrope_position_delta - ) - mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions( - mrope_position_delta, - int(self.seq_lens[i]) - 1, - int(self.seq_lens[i]), - ) - elif self.forward_mode.is_extend(): - extend_start_loc_cpu = self.extend_start_loc.cpu().numpy() - for i, image_inputs in enumerate(batch.image_inputs): - extend_start_loc, extend_seq_len, extend_prefix_len = ( - extend_start_loc_cpu[i], - batch.extend_seq_lens[i], - batch.extend_prefix_lens[i], - ) - if image_inputs is None: - # text only - mrope_positions = [ - [ - pos - for pos in range( - extend_prefix_len, extend_prefix_len + extend_seq_len - ) - ] - ] * 3 - else: - # TODO: current qwen2-vl do not support radix cache since mrope position calculation - mrope_positions, mrope_position_delta = ( - MRotaryEmbedding.get_input_positions( - input_tokens=self.input_ids[ - extend_start_loc : extend_start_loc + extend_seq_len - ], - image_grid_thw=image_inputs.image_grid_thws, - vision_start_token_id=hf_config.vision_start_token_id, - spatial_merge_size=hf_config.vision_config.spatial_merge_size, - context_len=0, - ) - ) - batch.image_inputs[i].mrope_position_delta = mrope_position_delta - mrope_positions_list[i] = mrope_positions - - self.mrope_positions = torch.concat( - [torch.tensor(pos, device=device) for pos in mrope_positions_list], - axis=1, - ) - self.mrope_positions = self.mrope_positions.to(torch.int64) - @classmethod def init_new( cls, @@ -337,7 +279,7 @@ class ForwardBatch: ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens if model_runner.model_is_mrope: - ret.compute_mrope_positions(model_runner, batch) + ret._compute_mrope_positions(model_runner, batch) # Init lora information if model_runner.server_args.lora_paths is not None: @@ -345,6 +287,63 @@ class ForwardBatch: return ret + def _compute_mrope_positions( + self, model_runner: ModelRunner, batch: ModelWorkerBatch + ): + device = model_runner.device + hf_config = model_runner.model_config.hf_config + mrope_positions_list = [None] * self.seq_lens.shape[0] + if self.forward_mode.is_decode(): + for i, _ in enumerate(mrope_positions_list): + mrope_position_delta 
= ( + 0 + if batch.image_inputs[i] is None + else batch.image_inputs[i].mrope_position_delta + ) + mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions( + mrope_position_delta, + int(self.seq_lens[i]) - 1, + int(self.seq_lens[i]), + ) + elif self.forward_mode.is_extend(): + extend_start_loc_cpu = self.extend_start_loc.cpu().numpy() + for i, image_inputs in enumerate(batch.image_inputs): + extend_start_loc, extend_seq_len, extend_prefix_len = ( + extend_start_loc_cpu[i], + batch.extend_seq_lens[i], + batch.extend_prefix_lens[i], + ) + if image_inputs is None: + # text only + mrope_positions = [ + [ + pos + for pos in range( + extend_prefix_len, extend_prefix_len + extend_seq_len + ) + ] + ] * 3 + else: + # TODO: current qwen2-vl do not support radix cache since mrope position calculation + mrope_positions, mrope_position_delta = ( + MRotaryEmbedding.get_input_positions( + input_tokens=self.input_ids[ + extend_start_loc : extend_start_loc + extend_seq_len + ], + image_grid_thw=image_inputs.image_grid_thws, + vision_start_token_id=hf_config.vision_start_token_id, + spatial_merge_size=hf_config.vision_config.spatial_merge_size, + context_len=0, + ) + ) + batch.image_inputs[i].mrope_position_delta = mrope_position_delta + mrope_positions_list[i] = mrope_positions + self.mrope_positions = torch.concat( + [torch.tensor(pos, device=device) for pos in mrope_positions_list], + axis=1, + ) + self.mrope_positions = self.mrope_positions.to(torch.int64) + def compute_position_triton( extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 6fa1429dc..5b19c77e2 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -52,6 +52,7 @@ from sglang.srt.mem_cache.memory_pool import ( MLATokenToKVPool, ReqToTokenPool, ) +from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader import get_model from sglang.srt.server_args import ServerArgs @@ -714,8 +715,6 @@ class ModelRunner: def init_cuda_graphs(self): """Capture cuda graphs.""" - from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner - self.cuda_graph_runner = None if not self.is_generation: diff --git a/python/sglang/srt/speculative/build_eagle_tree.py b/python/sglang/srt/speculative/build_eagle_tree.py index 6412825ed..e0ac9fe0b 100644 --- a/python/sglang/srt/speculative/build_eagle_tree.py +++ b/python/sglang/srt/speculative/build_eagle_tree.py @@ -79,11 +79,13 @@ __global__ void build_tree(Tensor parent_list, Tensor selected ) -def build_tree_kernel(parent_list, top_score_index, seq_lens, topk, depth, draft_token): +def build_tree_kernel( + parent_list, top_score_index, seq_lens, seq_lens_sum, topk, depth, draft_token +): bs = seq_lens.numel() device = parent_list.device tree_mask = torch.full( - (torch.sum(seq_lens).item() * draft_token + draft_token * draft_token * bs,), + (seq_lens_sum * draft_token + draft_token * draft_token * bs,), True, device=device, ) diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py new file mode 100644 index 000000000..41ff5c19e --- /dev/null +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +import bisect +import 
time +from typing import TYPE_CHECKING, Callable + +import torch + +from sglang.srt.model_executor.cuda_graph_runner import ( + CudaGraphRunner, + get_batch_sizes_to_capture, + get_global_graph_memory_pool, + set_global_graph_memory_pool, + set_torch_compile_config, +) +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) +from sglang.srt.speculative.eagle_utils import EagleDraftInput + +if TYPE_CHECKING: + from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.eagle_worker import EAGLEWorker + + +class EAGLEDraftCudaGraphRunner: + def __init__(self, eagle_worker: EAGLEWorker): + # Parse args + self.eagle_worker = eagle_worker + self.model_runner = model_runner = eagle_worker.model_runner + self.graphs = {} + self.output_buffers = {} + self.enable_torch_compile = model_runner.server_args.enable_torch_compile + self.disable_padding = model_runner.server_args.disable_cuda_graph_padding + self.tp_size = self.model_runner.tp_size + self.dp_size = model_runner.server_args.dp_size + self.topk = model_runner.server_args.speculative_eagle_topk + self.speculative_num_steps = model_runner.server_args.speculative_num_steps + server_args = model_runner.server_args + + assert self.disable_padding + + # Batch sizes to capture + self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) + self.num_tokens_per_bs = server_args.speculative_eagle_topk + + # Attention backend + self.max_bs = max(self.capture_bs) + self.max_num_token = self.max_bs * self.num_tokens_per_bs + self.model_runner.draft_attn_backend.init_cuda_graph_state(self.max_num_token) + self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[ + 0 + ].get_cuda_graph_seq_len_fill_value() + + if self.enable_torch_compile: + set_torch_compile_config() + + # Graph inputs + with torch.device("cuda"): + self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) + self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) + self.seq_lens = torch.full( + (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 + ) + self.out_cache_loc = torch.zeros( + (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64 + ) + self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) + self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32) + self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64) + self.hidden_states = torch.zeros( + (self.max_bs, self.model_runner.model_config.hidden_size), + dtype=self.model_runner.dtype, + ) + + # Capture + try: + self.capture() + except RuntimeError as e: + raise Exception( + f"Capture cuda graph failed: {e}\n" + "Possible solutions:\n" + "1. disable cuda graph by --disable-cuda-graph\n" + "2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n" + "3. 
disable torch compile by not using --enable-torch-compile\n" + "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" + ) + + def can_run(self, forward_batch: ForwardBatch): + is_bs_supported = ( + forward_batch.batch_size in self.graphs + if self.disable_padding + else forward_batch.batch_size <= self.max_bs + ) + return is_bs_supported + + def capture(self): + CudaGraphRunner.capture(self) + + def capture_one_batch_size(self, num_seqs: int, forward: Callable): + graph = torch.cuda.CUDAGraph() + stream = self.stream + num_tokens = num_seqs * self.num_tokens_per_bs + + # Graph inputs + req_pool_indices = self.req_pool_indices[:num_seqs] + seq_lens = self.seq_lens[:num_seqs] + out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps] + positions = self.positions[:num_tokens] + topk_p = self.topk_p[:num_seqs] + topk_index = self.topk_index[:num_seqs] + hidden_states = self.hidden_states[:num_seqs] + + spec_info = EagleDraftInput( + topk_p=topk_p, + topk_index=topk_index, + hidden_states=hidden_states, + ) + + # Forward batch + forward_batch = ForwardBatch( + forward_mode=ForwardMode.DECODE, + batch_size=num_seqs, + input_ids=None, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + out_cache_loc=out_cache_loc, + seq_lens_sum=seq_lens.sum(), + return_logprob=False, + positions=positions, + spec_algorithm=self.model_runner.spec_algorithm, + spec_info=spec_info, + capture_hidden_mode=( + spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL + ), + ) + + # Attention backend + self.model_runner.draft_attn_backend.init_forward_metadata_capture_cuda_graph( + forward_batch + ) + + # Run and capture + def run_once(): + # Backup two fileds, which will be modified in-place in `draft_forward`. 
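+            # draft_forward modifies these two tensors in place, so they are restored after
+            # each call; this keeps the CUDA graph's static input buffers (out_cache_loc,
+            # spec_info.hidden_states) unchanged across the warm-up runs and the capture itself.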
+ output_cache_loc_backup = forward_batch.out_cache_loc + hidden_states_backup = forward_batch.spec_info.hidden_states + + ret = self.eagle_worker.draft_forward(forward_batch) + + forward_batch.out_cache_loc = output_cache_loc_backup + forward_batch.spec_info.hidden_states = hidden_states_backup + return ret + + for _ in range(2): + torch.cuda.synchronize() + self.model_runner.tp_group.barrier() + + run_once() + + torch.cuda.synchronize() + self.model_runner.tp_group.barrier() + + torch.cuda.synchronize() + self.model_runner.tp_group.barrier() + + with torch.cuda.graph( + graph, pool=get_global_graph_memory_pool(), stream=stream + ): + out = run_once() + + torch.cuda.synchronize() + self.model_runner.tp_group.barrier() + + set_global_graph_memory_pool(graph.pool()) + return graph, out + + def replay(self, forward_batch: ForwardBatch): + assert forward_batch.out_cache_loc is not None + raw_bs = forward_batch.batch_size + raw_num_token = raw_bs * self.num_tokens_per_bs + + # Pad + index = bisect.bisect_left(self.capture_bs, raw_bs) + bs = self.capture_bs[index] + if bs != raw_bs: + self.seq_lens.fill_(1) + self.out_cache_loc.zero_() + + # Common inputs + self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) + self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) + self.out_cache_loc[: raw_num_token * self.speculative_num_steps].copy_( + forward_batch.out_cache_loc + ) + self.positions[:raw_num_token].copy_(forward_batch.positions) + self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p) + self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index) + self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states) + + # Attention backend + self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph( + forward_batch + ) + + # Replay + self.graphs[bs].replay() + + return self.output_buffers[bs] diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 97cdb2640..0b8c99f04 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import dataclasses from typing import TYPE_CHECKING, List import torch @@ -9,13 +10,360 @@ import triton.language as tl from sglang.srt.layers.attention.flashinfer_backend import ( create_flashinfer_kv_indices_triton, ) -from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.speculative.build_eagle_tree import build_tree_kernel -from sglang.srt.speculative.spec_info import SpecInfo if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import ScheduleBatch - from sglang.srt.server_args import ServerArgs + + +@dataclasses.dataclass +class EagleDraftInput: + # The inputs for decode + # shape: (b, topk) + topk_p: torch.Tensor = None + topk_index: torch.Tensor = None + # shape: (b, hidden_size) + hidden_states: torch.Tensor = None + capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL + + # Inputs for extend + # shape: (b,) + verified_id: torch.Tensor = None + accept_length: torch.Tensor = None + accept_length_cpu: List[int] = None + + # Inputs for the attention backends + # shape: (b + 1,) + kv_indptr: torch.Tensor = None + kv_indices: torch.Tensor = None + + def prepare_for_extend(self, batch: ScheduleBatch): + req_pool_indices = batch.alloc_req_slots(len(batch.reqs)) + out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) + 
batch.out_cache_loc = out_cache_loc + + pt = 0 + for i, req in enumerate(batch.reqs): + req.req_pool_idx = req_pool_indices[i] + pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids) + assert seq_len - pre_len == req.extend_input_len + + if pre_len > 0: + batch.req_to_token_pool.req_to_token[req.req_pool_idx][ + :pre_len + ] = req.prefix_indices + + batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = ( + out_cache_loc[pt : pt + req.extend_input_len] + ) + + pt += req.extend_input_len + + # TODO: support batching inputs + assert len(batch.extend_lens) == 1 + batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id)) + + def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps): + batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel()) + accept_length_cpu = batch.spec_info.accept_length_cpu + batch.extend_lens = [x + 1 for x in accept_length_cpu] + batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend + seq_lens_cpu = batch.seq_lens.tolist() + + pt = 0 + i = 0 + for req in batch.reqs: + if req.finished(): + continue + # assert seq_len - pre_len == req.extend_input_len + input_len = batch.extend_lens[i] + seq_len = seq_lens_cpu[i] + batch.req_to_token_pool.req_to_token[req.req_pool_idx][ + seq_len - input_len : seq_len + ] = batch.out_cache_loc[pt : pt + input_len] + pt += input_len + i += 1 + assert pt == batch.out_cache_loc.shape[0] + + self.positions = torch.empty_like(self.verified_id) + new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long) + self.accept_length.add_(1) + + create_extend_spec_info[(self.accept_length.numel(),)]( + self.verified_id, + batch.seq_lens, + self.accept_length, + torch.cumsum(self.accept_length, axis=0, dtype=torch.int), + self.positions, + new_verified_id, + triton.next_power_of_2(speculative_num_steps + 1), + ) + + batch.seq_lens_sum = sum(seq_lens_cpu) + batch.input_ids = self.verified_id + self.verified_id = new_verified_id + + def generate_attn_arg_prefill( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + req_to_token: torch.Tensor, + ): + bs = self.accept_length.numel() + qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") + qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0) + + cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") + cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) + kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda") + + create_flashinfer_kv_indices_triton[(bs,)]( + req_to_token, + req_pool_indices, + paged_kernel_lens, + cum_kv_seq_len, + None, + kv_indices, + req_to_token.size(1), + ) + + return kv_indices, cum_kv_seq_len, qo_indptr, None + + def filter_batch(self, new_indices: torch.Tensor): + self.topk_p = self.topk_p[: len(new_indices)] + self.topk_index = self.topk_index[: len(new_indices)] + self.hidden_states = self.hidden_states[: len(new_indices)] + self.verified_id = self.verified_id[: len(new_indices)] + + def merge_batch(self, spec_info: EagleDraftInput): + if self.hidden_states is None: + self.hidden_states = spec_info.hidden_states + self.verified_id = spec_info.verified_id + self.topk_p = spec_info.topk_p + self.topk_index = spec_info.topk_index + return + if spec_info.hidden_states is None: + return + self.hidden_states = torch.cat( + [self.hidden_states, spec_info.hidden_states], axis=0 + ) + self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0) + self.topk_p = 
torch.cat([self.topk_p, spec_info.topk_p]) + self.topk_index = torch.cat([self.topk_index, spec_info.topk_index]) + + +@dataclasses.dataclass +class EagleVerifyInput: + draft_token: torch.Tensor + custom_mask: torch.Tensor + positions: torch.Tensor + retrive_index: torch.Tensor + retrive_cum_len: torch.Tensor + draft_token_num: int + capture_hidden_mode: CaptureHiddenMode + + @classmethod + def create( + cls, + verified_id: torch.Tensor, + score_list: List[torch.Tensor], + token_list: List[torch.Tensor], + parents_list: List[torch.Tensor], + seq_lens: torch.Tensor, + seq_lens_sum: int, + topk: int, + spec_steps: int, + num_verify_token: int, + ): + score_list = torch.cat(score_list, dim=1).flatten( + 1 + ) # b, n, topk; n= 1 + (num_steps-1) * self.topk + ss_token_list = torch.cat( + token_list, dim=1 + ) # b, (self.topk + (num_steps-1) * self.topk) + top_scores = torch.topk(score_list, num_verify_token - 1, dim=-1) + top_scores_index = top_scores.indices + top_scores_index = torch.sort(top_scores_index).values + draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1) + draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1) + parent_list = torch.cat(parents_list[:-1], dim=1) + tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel( + parent_list, + top_scores_index, + seq_lens, + seq_lens_sum, + topk, + spec_steps, + num_verify_token, + ) + return cls( + draft_tokens.flatten(), + tree_mask, + position, + retrive_index, + retrive_cum_len, + num_verify_token, + CaptureHiddenMode.FULL, + ) + + def prepare_for_verify(self, batch: ScheduleBatch): + batch.input_ids = self.draft_token + batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) + bs = batch.batch_size() + assign_req_to_token_pool[(bs,)]( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + batch.seq_lens + self.draft_token_num, + batch.out_cache_loc, + batch.req_to_token_pool.req_to_token.shape[1], + triton.next_power_of_2(bs), + ) + + def generate_attn_arg_prefill( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + req_to_token: torch.Tensor, + ): + batch_size = len(req_pool_indices) + qo_indptr = torch.arange( + 0, + (1 + batch_size) * self.draft_token_num, + step=self.draft_token_num, + dtype=torch.int32, + device="cuda", + ) + + cum_kv_seq_len = torch.zeros( + (batch_size + 1,), dtype=torch.int32, device="cuda" + ) + + paged_kernel_lens = paged_kernel_lens + self.draft_token_num + cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) + + kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda") + + create_flashinfer_kv_indices_triton[(batch_size,)]( + req_to_token, + req_pool_indices, + paged_kernel_lens, + cum_kv_seq_len, + None, + kv_indices, + req_to_token.size(1), + ) + return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask + + def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor: + predict = torch.argmax(logits_output.next_token_logits, dim=-1) + predict = torch.cat( + [predict, torch.full([1], -1, dtype=torch.long, device="cuda")], dim=-1 + ) + draft_token = torch.cat( + [self.draft_token, torch.full([1], -1, dtype=torch.long, device="cuda")], + dim=-1, + ) + target_predict = predict[self.retrive_index] + candidates = draft_token[self.retrive_index] + # logits = logits_output.next_token_logits[self.retrive_index] + # target_predict = torch.argmax(logits[:, :-1], dim=-1) + accept_mask = candidates[:, 1:] == target_predict[:, :-1] + 
accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1) + bs = self.retrive_cum_len.numel() - 1 + + max_draft_len = self.retrive_index.shape[-1] + accept_index = torch.full( + (bs, max_draft_len), -1, dtype=torch.long, device="cuda" + ) + accept_length = torch.empty((bs,), dtype=torch.int, device="cuda") + extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda") + eagle_verify_retrive[(bs,)]( + self.retrive_index.contiguous(), + accept_mask.contiguous(), + self.retrive_cum_len, + accept_index, + accept_length, + extract_index, + max_draft_len, + self.draft_token_num, + triton.next_power_of_2(max_draft_len), + ) + + new_accept_index = [] + unfinished_index = [] + finished_extend_len = {} # {rid:accept_length + 1} + accept_index_cpu = accept_index.tolist() + predict_cpu = predict.tolist() + has_finished = False + + # iterate every accepted token and check if req has finished after append the token + # should be checked BEFORE free kv cache slots + for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)): + new_accept_index_ = [] + for j, idx in enumerate(accept_index_row): + if idx == -1: + break + id = predict_cpu[idx] + # if not found_finished: + req.output_ids.append(id) + finished_extend_len[req.rid] = j + 1 + req.check_finished() + if req.finished(): + has_finished = True + # set all tokens after finished token to -1 and break + accept_index[i, j + 1 :] = -1 + break + else: + new_accept_index_.append(idx) + if not req.finished(): + new_accept_index.extend(new_accept_index_) + unfinished_index.append(i) + req.spec_verify_ct += 1 + accept_length = (accept_index != -1).sum(dim=1) - 1 + + accept_index = accept_index[accept_index != -1] + accept_length_cpu = accept_length.tolist() + verified_id = predict[accept_index] + + evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) + evict_mask[accept_index] = False + mem_need_free_idx = batch.out_cache_loc[evict_mask] + batch.token_to_kv_pool.free(mem_need_free_idx) + assign_req_to_token_pool[(bs,)]( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + batch.seq_lens + accept_length + 1, + batch.out_cache_loc[accept_index], + batch.req_to_token_pool.req_to_token.shape[1], + triton.next_power_of_2(bs), + ) + batch.seq_lens.add_(accept_length + 1) + + draft_input = EagleDraftInput() + if len(new_accept_index) > 0: + new_accept_index = torch.tensor(new_accept_index, device="cuda") + draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index] + draft_input.verified_id = predict[new_accept_index] + draft_input.accept_length = accept_length[unfinished_index] + draft_input.accept_length_cpu = [ + accept_length_cpu[i] for i in unfinished_index + ] + if has_finished: + draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index] + else: + draft_input.seq_lens_for_draft_extend = batch.seq_lens + + logits_output.next_token_logits = logits_output.next_token_logits[accept_index] + return ( + draft_input, + logits_output, + verified_id, + finished_extend_len, + accept_length_cpu, + ) @triton.jit @@ -136,21 +484,57 @@ def assign_req_to_token_pool( load_offset += BLOCK_SIZE +@triton.jit +def assign_draft_cache_locs( + req_pool_indices, + req_to_token, + seq_lens, + out_cache_loc, + pool_len: tl.constexpr, + topk: tl.constexpr, + speculative_num_steps: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 32 + pid = tl.program_id(axis=0) + kv_start = tl.load(seq_lens + pid) + kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps + token_pool 
= req_to_token + tl.load(req_pool_indices + pid) * pool_len + out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps + + num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE) + for i in range(num_loop): + save_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + kv_start + load_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + mask = save_offset < kv_end + data = tl.load(out_cache_ptr + load_offset, mask=mask) + tl.store(token_pool + save_offset, data, mask=mask) + + @triton.jit def generate_draft_decode_kv_indices( req_pool_indices, req_to_token, paged_kernel_lens, kv_indices, - iters: tl.constexpr, + kv_indptr, + positions, + num_seqs: tl.constexpr, topk: tl.constexpr, pool_len: tl.constexpr, + kv_indices_stride: tl.constexpr, + kv_indptr_stride: tl.constexpr, bs_upper: tl.constexpr, iter_upper: tl.constexpr, + num_tokens_upper: tl.constexpr, ): BLOCK_SIZE: tl.constexpr = 128 - bid = tl.program_id(axis=0) - topk_id = tl.program_id(axis=1) + iters = tl.program_id(axis=0) + bid = tl.program_id(axis=1) + topk_id = tl.program_id(axis=2) + + kv_indices += kv_indices_stride * iters + kv_indptr += kv_indptr_stride * iters + iters += 1 load_offset = tl.arange(0, bs_upper) seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid) @@ -176,473 +560,73 @@ def generate_draft_decode_kv_indices( ) tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters) + # Update kv_indptr + bs_offset = tl.arange(0, num_tokens_upper) -class EAGLEDraftInput(SpecInfo): - def __init__(self): - self.prev_mode = ForwardMode.DECODE + zid = bid * topk + topk_id + if zid == 0: + zid = num_seqs * topk + positions = tl.load(positions + bs_offset, mask=bs_offset < zid) + base = tl.sum(positions) + tl.store(kv_indptr + zid, base + zid * iters) - self.scores: torch.Tensor = None - self.score_list: List[torch.Tensor] = [] - self.token_list: List[torch.Tensor] = [] - self.origin_score_list: List[torch.Tensor] = [] # used for sampling - self.parents_list: List[torch.Tensor] = [] - self.cache_list: List[torch.Tenor] = [] - self.iter = 0 - # shape: (b, hidden_size) - self.hidden_states: torch.Tensor = None - # shape: (b,) - self.verified_id: torch.Tensor = None - # shape: (b, vocab_size) - self.sample_output: torch.Tensor = None +@torch.compile +def select_top_k_tokens( + i: int, + topk_p: torch.Tensor, + topk_index: torch.Tensor, + hidden_states: torch.Tensor, + scores: torch.Tensor, + topk: int, +): + if i == 0: + # The first step after extend + input_ids = topk_index.flatten() + hidden_states = hidden_states.repeat_interleave(topk, dim=0) + scores = topk_p # shape: (b, topk) - self.positions: torch.Tensor = None - self.accept_length: torch.Tensor = None - self.accept_length_cpu: List[int] = None - - def load_server_args(self, server_args: ServerArgs): - self.topk: int = server_args.speculative_eagle_topk - self.num_verify_token: int = server_args.speculative_num_draft_tokens - self.spec_steps = server_args.speculative_num_steps - - def prepare_for_extend(self, batch: ScheduleBatch): - req_pool_indices = batch.alloc_req_slots(len(batch.reqs)) - out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) - batch.out_cache_loc = out_cache_loc - - pt = 0 - for i, req in enumerate(batch.reqs): - req.req_pool_idx = req_pool_indices[i] - pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids) - assert seq_len - pre_len == req.extend_input_len - - if pre_len > 0: - batch.req_to_token_pool.req_to_token[req.req_pool_idx][ - :pre_len - ] = req.prefix_indices - - 
batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = ( - out_cache_loc[pt : pt + req.extend_input_len] - ) - - pt += req.extend_input_len - - # TODO: support batching inputs - assert len(batch.extend_lens) == 1 - batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id)) - - def filter_batch( - self, - new_indices: torch.Tensor, - ): - self.sample_output = self.sample_output[: len(new_indices)] - self.hidden_states = self.hidden_states[: len(new_indices)] - self.verified_id = self.verified_id[: len(new_indices)] - - def prepare_for_decode(self, batch: ScheduleBatch): - prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab) - top = torch.topk(prob, self.topk, dim=-1) - topk_index, topk_p = ( - top.indices, - top.values, - ) # shape: (b * top_k, top_k) or (b, top_k) - - if self.prev_mode.is_decode(): - scores = torch.mul( - self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk) - ) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk) - topk_cs = torch.topk( - scores.flatten(start_dim=1), self.topk, dim=-1 - ) # (b, topk) - topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values - - selected_input_index = topk_cs_index.flatten() // self.topk + torch.arange( - 0, batch.batch_size() * self.topk, step=self.topk, device="cuda" - ).repeat_interleave(self.topk) - - batch.spec_info.hidden_states = batch.spec_info.hidden_states[ - selected_input_index, : - ] - - topk_index = topk_index.reshape(-1, self.topk**2) - batch.input_ids = torch.gather( - topk_index, index=topk_cs_index, dim=1 - ).flatten() - batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids)) - - self.scores = topk_cs_p - self.score_list.append(scores) # (b, topk, topk) - self.token_list.append(topk_index) # (b, topk * topk) - self.origin_score_list.append(topk_p.reshape(topk_index.shape)) - self.parents_list.append( - topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk) - ) # shape: (b, topk) - else: - # ForwardMode.EXTEND or ForwardMode.DRAFT_EXTEND - batch.spec_info.hidden_states = ( - batch.spec_info.hidden_states.repeat_interleave(self.topk, dim=0) - ) - - batch.input_ids = topk_index.flatten() - batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel()) - - self.scores = topk_p # shape: (b, topk) - self.score_list.append(topk_p.unsqueeze(1)) # shape: (b, 1, topk) - self.token_list.append(topk_index) # shape: (b, topk) - self.origin_score_list.append(topk_p) - self.parents_list.append( - torch.arange(-1, self.topk, dtype=torch.long, device="cuda") - .unsqueeze(0) - .repeat(self.scores.shape[0], 1) - ) # shape: (b, topk + 1) - self.cache_list.append(batch.out_cache_loc) - self.positions = ( - batch.seq_lens[:, None] - + torch.full( - [1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long - ) - ).flatten() - - bs = len(batch.seq_lens) - assign_req_to_token_pool[(bs,)]( - batch.req_pool_indices, - batch.req_to_token_pool.req_to_token, - batch.seq_lens + self.topk * self.iter, - batch.seq_lens + self.topk * (self.iter + 1), - batch.out_cache_loc, - batch.req_to_token_pool.req_to_token.shape[1], - triton.next_power_of_2(bs), - ) - self.iter += 1 - - def prepare_extend_after_decode(self, batch: ScheduleBatch): - batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel()) - accept_length_cpu = batch.spec_info.accept_length_cpu - batch.extend_lens = [x + 1 for x in accept_length_cpu] - batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend - seq_lens_cpu = batch.seq_lens.tolist() - - pt = 0 - i = 0 - for req in batch.reqs: 
- if req.finished(): - continue - # assert seq_len - pre_len == req.extend_input_len - input_len = batch.extend_lens[i] - seq_len = seq_lens_cpu[i] - batch.req_to_token_pool.req_to_token[req.req_pool_idx][ - seq_len - input_len : seq_len - ] = batch.out_cache_loc[pt : pt + input_len] - pt += input_len - i += 1 - assert pt == batch.out_cache_loc.shape[0] - - self.positions = torch.empty_like(self.verified_id) - new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long) - self.accept_length.add_(1) - - create_extend_spec_info[(self.accept_length.numel(),)]( - self.verified_id, - batch.seq_lens, - self.accept_length, - torch.cumsum(self.accept_length, axis=0, dtype=torch.int), - self.positions, - new_verified_id, - triton.next_power_of_2(self.spec_steps + 1), + tree_info = ( + topk_p.unsqueeze(1), # shape: (b, 1, topk) + topk_index, # shape: (b, topk) + torch.arange(-1, topk, dtype=torch.long, device="cuda") + .unsqueeze(0) + .repeat(topk_p.shape[0], 1), # shape: (b, topk + 1) ) - batch.seq_lens_sum = sum(seq_lens_cpu) - batch.input_ids = self.verified_id - self.verified_id = new_verified_id + else: + # The later decode steps + expand_scores = torch.mul( + scores.unsqueeze(2), topk_p.reshape(-1, topk, topk) + ) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk) - def prepare_for_verify(self, batch: ScheduleBatch): - score_list = torch.cat(self.score_list, dim=1).flatten( - 1 - ) # b, n, topk; n= 1+(self.iter-1)*self.topk - ss_token_list = torch.cat( - self.token_list, dim=1 - ) # b, (self.topk+(self.iter-1)*self.topk) - origin_token_list = torch.cat(self.origin_score_list, dim=1) - top_scores = torch.topk(score_list, self.num_verify_token - 1, dim=-1) - top_scores_index = top_scores.indices - top_scores_index = torch.sort(top_scores_index).values + topk_cs_p, topk_cs_index = fast_topk( + expand_scores.flatten(start_dim=1), topk, dim=-1 + ) # (b, topk) + scores = topk_cs_p # shape: (b, topk) - draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1) - scores = torch.gather(origin_token_list, index=top_scores_index, dim=1) - draft_tokens = torch.cat((self.verified_id.unsqueeze(1), draft_tokens), dim=1) - parent_list = torch.cat(self.parents_list[:-1], dim=1) + topk_index = topk_index.reshape(-1, topk**2) + input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten() - tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel( - parent_list, - top_scores_index, - batch.seq_lens, - self.topk, - self.iter - 1, - self.num_verify_token, + selected_input_index = topk_cs_index.flatten() // topk + torch.arange( + 0, hidden_states.shape[0], step=topk, device="cuda" + ).repeat_interleave(topk) + hidden_states = hidden_states[selected_input_index, :] + + tree_info = ( + expand_scores, # shape: (b, topk, topk) + topk_index, # shape: (b, topk * topk) + topk_cs_index + (topk**2 * (i - 1) + topk), # shape: (b, topk) ) - return EagleVerifyInput( - draft_tokens.flatten(), - scores.flatten(), - tree_mask, - position, - retrive_index, - retrive_cum_len, - self.num_verify_token, - ) - - def generate_attn_arg_decode( - self, - req_pool_indices: torch.Tensor, - paged_kernel_lens: torch.Tensor, - req_to_token: torch.Tensor, - ): - seq_num = req_pool_indices.numel() - bs = self.topk * req_pool_indices.numel() - seq_len = self.positions.reshape(-1).contiguous() - - cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") - cum_kv_seq_len[1:] = torch.cumsum(seq_len + 1, dim=0) - total_len = torch.sum(paged_kernel_lens).item() - - kv_indices = 
-            (total_len * self.topk + seq_num * self.iter * self.topk,),
-            dtype=torch.int32,
-            device="cuda",
-        )
-
-        generate_draft_decode_kv_indices[(req_pool_indices.numel(), self.topk)](
-            req_pool_indices,
-            req_to_token,
-            paged_kernel_lens,
-            kv_indices,
-            self.iter,
-            self.topk,
-            req_to_token.shape[1],
-            triton.next_power_of_2(seq_num),
-            triton.next_power_of_2(self.spec_steps),
-        )
-        return bs, kv_indices, cum_kv_seq_len
-
-    def clear_draft_cache(self, batch):
-        draft_cache = torch.cat(self.cache_list, dim=0)
-        batch.token_to_kv_pool.free(draft_cache)
-
-    def generate_attn_arg_prefill(
-        self,
-        req_pool_indices: torch.Tensor,
-        paged_kernel_lens: torch.Tensor,
-        req_to_token: torch.Tensor,
-    ):
-        bs = self.accept_length.numel()
-        qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
-        qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
-
-        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
-        cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
-
-        create_flashinfer_kv_indices_triton[(bs,)](
-            req_to_token,
-            req_pool_indices,
-            paged_kernel_lens,
-            cum_kv_seq_len,
-            None,
-            kv_indices,
-            req_to_token.size(1),
-        )
-
-        return kv_indices, cum_kv_seq_len, qo_indptr, None
-
-    def merge_batch(self, spec_info: EAGLEDraftInput):
-        if self.hidden_states is None:
-            self.hidden_states = spec_info.hidden_states
-            self.verified_id = spec_info.verified_id
-            self.sample_output = spec_info.sample_output
-            self.prev_mode = spec_info.prev_mode
-            return
-        if spec_info.hidden_states is None:
-            return
-        self.hidden_states = torch.cat(
-            [self.hidden_states, spec_info.hidden_states], axis=0
-        )
-        self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
-        self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])
+    return input_ids, hidden_states, scores, tree_info
-class EagleVerifyInput(SpecInfo):
-    def __init__(
-        self,
-        draft_token: torch.Tensor,
-        draft_score: torch.Tensor,
-        tree_mask: torch.Tensor,
-        positions: torch.Tensor,
-        retrive_index: torch.Tensor,
-        retrive_cum_len: torch.Tensor,
-        draft_token_num: int,
-    ):
-        self.draft_token = draft_token
-        self.draft_score = draft_score
-        self.custom_mask = tree_mask
-        self.positions = positions
-        self.retrive_index = retrive_index
-        self.retrive_cum_len = retrive_cum_len
-        self.draft_token_num = draft_token_num
-
-    def prepare_for_verify(self, batch: ScheduleBatch):
-        batch.input_ids = self.draft_token
-        batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
-        bs = batch.seq_lens.numel()
-        assign_req_to_token_pool[(bs,)](
-            batch.req_pool_indices,
-            batch.req_to_token_pool.req_to_token,
-            batch.seq_lens,
-            batch.seq_lens + self.draft_token_num,
-            batch.out_cache_loc,
-            batch.req_to_token_pool.req_to_token.shape[1],
-            triton.next_power_of_2(bs),
-        )
-
-    def generate_attn_arg_prefill(
-        self,
-        req_pool_indices: torch.Tensor,
-        paged_kernel_lens: torch.Tensor,
-        req_to_token: torch.Tensor,
-    ):
-        batch_size = len(req_pool_indices)
-        qo_indptr = torch.arange(
-            0,
-            (1 + batch_size) * self.draft_token_num,
-            step=self.draft_token_num,
-            dtype=torch.int32,
-            device="cuda",
-        )
-
-        cum_kv_seq_len = torch.zeros(
-            (batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-
-        paged_kernel_lens = paged_kernel_lens + self.draft_token_num
-        cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
device="cuda") - - create_flashinfer_kv_indices_triton[(batch_size,)]( - req_to_token, - req_pool_indices, - paged_kernel_lens, - cum_kv_seq_len, - None, - kv_indices, - req_to_token.size(1), - ) - return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask - - def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor: - predict = torch.argmax(logits_output.next_token_logits, dim=-1) - predict = torch.cat( - [predict, torch.full([1], -1, dtype=torch.long, device="cuda")], dim=-1 - ) - draft_token = torch.cat( - [self.draft_token, torch.full([1], -1, dtype=torch.long, device="cuda")], - dim=-1, - ) - target_predict = predict[self.retrive_index] - candidates = draft_token[self.retrive_index] - # logits = logits_output.next_token_logits[self.retrive_index] - # target_predict = torch.argmax(logits[:, :-1], dim=-1) - accept_mask = candidates[:, 1:] == target_predict[:, :-1] - accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1) - bs = self.retrive_cum_len.numel() - 1 - - max_draft_len = self.retrive_index.shape[-1] - accept_index = torch.full( - (bs, max_draft_len), -1, dtype=torch.long, device="cuda" - ) - accept_length = torch.empty((bs,), dtype=torch.int, device="cuda") - extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda") - eagle_verify_retrive[(bs,)]( - self.retrive_index.contiguous(), - accept_mask.contiguous(), - self.retrive_cum_len, - accept_index, - accept_length, - extract_index, - max_draft_len, - self.draft_token_num, - triton.next_power_of_2(max_draft_len), - ) - - draft_input = EAGLEDraftInput() - new_accept_index = [] - unfinished_index = [] - finished_extend_len = {} # {rid:accept_length + 1} - accept_index_cpu = accept_index.tolist() - predict_cpu = predict.tolist() - has_finished = False - - # iterate every accepted token and check if req has finished after append the token - # should be checked BEFORE free kv cache slots - for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)): - new_accept_index_ = [] - for j, idx in enumerate(accept_index_row): - if idx == -1: - break - id = predict_cpu[idx] - # if not found_finished: - req.output_ids.append(id) - finished_extend_len[req.rid] = j + 1 - req.check_finished() - if req.finished(): - has_finished = True - # set all tokens after finished token to -1 and break - accept_index[i, j + 1 :] = -1 - break - else: - new_accept_index_.append(idx) - if not req.finished(): - new_accept_index.extend(new_accept_index_) - unfinished_index.append(i) - req.spec_verify_ct += 1 - accept_length = (accept_index != -1).sum(dim=1) - 1 - - accept_index = accept_index[accept_index != -1] - accept_length_cpu = accept_length.tolist() - verified_id = predict[accept_index] - - evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) - evict_mask[accept_index] = False - mem_need_free_idx = batch.out_cache_loc[evict_mask] - batch.token_to_kv_pool.free(mem_need_free_idx) - assign_req_to_token_pool[(bs,)]( - batch.req_pool_indices, - batch.req_to_token_pool.req_to_token, - batch.seq_lens, - batch.seq_lens + accept_length + 1, - batch.out_cache_loc[accept_index], - batch.req_to_token_pool.req_to_token.shape[1], - triton.next_power_of_2(bs), - ) - batch.seq_lens.add_(accept_length + 1) - - if len(new_accept_index) > 0: - new_accept_index = torch.tensor(new_accept_index, device="cuda") - draft_input.verified_id = predict[new_accept_index] - draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index] - draft_input.accept_length = 
-            draft_input.accept_length_cpu = [
-                accept_length_cpu[i] for i in unfinished_index
-            ]
-            if has_finished:
-                draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
-            else:
-                draft_input.seq_lens_for_draft_extend = batch.seq_lens
-
-        logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
-        return (
-            draft_input,
-            logits_output,
-            verified_id,
-            finished_extend_len,
-            accept_length_cpu,
-        )
+def fast_topk(values, topk, dim):
+    if topk == 1:
+        # Use max along the specified dimension to get both value and index
+        max_value, max_index = torch.max(values, dim=dim)
+        return max_value.unsqueeze(1), max_index.unsqueeze(1)
+    else:
+        # Use topk for efficiency with larger k values
+        return torch.topk(values, topk, dim=dim)
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 06a4372fc..b5a3de6ca 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -1,3 +1,5 @@
+import logging
+import time
 from typing import List, Optional, Union

 import torch
@@ -12,8 +14,18 @@ from sglang.srt.model_executor.forward_batch_info import (
 )
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.speculative.eagle_utils import EAGLEDraftInput
-from sglang.srt.utils import rank0_print
+from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
+    EAGLEDraftCudaGraphRunner,
+)
+from sglang.srt.speculative.eagle_utils import (
+    EagleDraftInput,
+    EagleVerifyInput,
+    assign_draft_cache_locs,
+    fast_topk,
+    select_top_k_tokens,
+)
+
+logger = logging.getLogger(__name__)


 class EAGLEWorker(TpModelWorker):
@@ -40,41 +52,47 @@ class EAGLEWorker(TpModelWorker):
             is_draft_worker=True,
         )
         self.target_worker = target_worker
-        self.server_args = server_args
         self.finish_extend_len = []

+        # Parse arguments
+        self.topk = server_args.speculative_eagle_topk
+        self.speculative_num_steps = server_args.speculative_num_steps
+        self.server_args = server_args
+
         # Share the embedding and lm_head
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()
         self.model_runner.model.set_embed_and_head(embed, head)
         self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
-        self.model_runner.init_cuda_graphs()

-    def forward_draft_decode(self, batch: ScheduleBatch):
-        batch.spec_info.prepare_for_decode(batch)
-        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
-        model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
-        self.capture_for_decode(logits_output, forward_batch)
+        # Create multi-step attn backends and cuda graph runners
+        from sglang.srt.layers.attention.flashinfer_backend import (
+            FlashInferMultiStepDraftBackend,
+        )

-    def forward_draft_extend(self, batch: ScheduleBatch):
-        self._set_mem_pool(batch, self.model_runner)
-        batch.spec_info.prepare_for_extend(batch)
-        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
-        model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
-        self.capture_for_decode(logits_output, forward_batch)
-        self._set_mem_pool(batch, self.target_worker.model_runner)
+        self.draft_attn_backend = FlashInferMultiStepDraftBackend(
+            self.model_runner,
+            self.topk,
+            self.speculative_num_steps,
+        )
+        self.model_runner.draft_attn_backend = self.draft_attn_backend
+        self.init_cuda_graphs()
+
+    def init_cuda_graphs(self):
+        """Capture cuda graphs."""
+        self.cuda_graph_runner = None
+
+        if self.server_args.disable_cuda_graph:
+            return
+
+        tic = time.time()
+        logger.info("Capture cuda graph begin. This can take up to several minutes.")
+        self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
+        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")

     def forward_batch_speculative_generation(self, batch: ScheduleBatch):
         if batch.forward_mode.is_decode():
             # Draft
-            self._set_mem_pool(batch, self.model_runner)
-            for i in range(self.server_args.speculative_num_steps):
-                self.forward_draft_decode(batch)
-            batch.spec_info.clear_draft_cache(batch)
-            self._set_mem_pool(batch, self.target_worker.model_runner)
+            spec_info: EagleVerifyInput = self.draft(batch)

             # Verify
             (
@@ -84,8 +102,7 @@ class EAGLEWorker(TpModelWorker):
                 self.finish_extend_len,
                 accept_length_cpu,
                 model_worker_batch,
-            ) = self.verify(batch)
-            next_draft_input.load_server_args(self.server_args)
+            ) = self.verify(batch, spec_info)
             batch.spec_info = next_draft_input
             # if it is None, means all requsets are finished
             if batch.spec_info.verified_id is not None:
@@ -107,29 +124,145 @@ class EAGLEWorker(TpModelWorker):
             )

             # Forward with the draft model.
-            spec_info = EAGLEDraftInput()
-            spec_info.load_server_args(self.server_args)
-            spec_info.hidden_states = logits_output.hidden_states
-            spec_info.verified_id = next_token_ids
-            batch.spec_info = spec_info
+            batch.spec_info = EagleDraftInput(
+                hidden_states=logits_output.hidden_states,
+                verified_id=next_token_ids,
+            )
             self.forward_draft_extend(batch)
             return logits_output, next_token_ids, model_worker_batch, 0

-    def verify(self, batch: ScheduleBatch):
-        verify_input = batch.spec_info.prepare_for_verify(batch)
-        verify_input.prepare_for_verify(batch)
+    def draft(self, batch: ScheduleBatch):
+        self._set_mem_pool(batch, self.model_runner)
+
+        # Parse args
+        num_seqs = batch.batch_size()
+        spec_info = batch.spec_info
+
+        # Allocate cache locations
+        out_cache_loc = batch.alloc_token_slots(
+            num_seqs * self.topk * self.speculative_num_steps
+        )
+        assign_draft_cache_locs[(num_seqs,)](
+            batch.req_pool_indices,
+            batch.req_to_token_pool.req_to_token,
+            batch.seq_lens,
+            out_cache_loc,
+            batch.req_to_token_pool.req_to_token.shape[1],
+            self.topk,
+            self.speculative_num_steps,
+        )
+
+        batch.out_cache_loc = out_cache_loc
+        batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
+        spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
+
+        # Get forward batch
+        spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        model_worker_batch = batch.get_model_worker_batch()
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
+            forward_batch
+        )
+
+        if can_cuda_graph:
+            score_list, token_list, parents_list = self.cuda_graph_runner.replay(
+                forward_batch
+            )
+        else:
+            # Initialize attention backend
+            self.draft_attn_backend.init_forward_metadata(forward_batch)
+
+            # Run forward steps
+            score_list, token_list, parents_list = self.draft_forward(forward_batch)
+
+        ret = EagleVerifyInput.create(
+            spec_info.verified_id,
+            score_list,
+            token_list,
+            parents_list,
+            batch.seq_lens,
+            batch.seq_lens_sum,
+            self.topk,
+            self.speculative_num_steps,
+            self.server_args.speculative_num_draft_tokens,
+        )
+
+        # Free cache locations
+        batch.token_to_kv_pool.free(out_cache_loc)
+        self._set_mem_pool(batch, self.target_worker.model_runner)
+        return ret
+
+    def draft_forward(self, forward_batch: ForwardBatch):
+        # Parse args
+        spec_info = forward_batch.spec_info
+        out_cache_loc = forward_batch.out_cache_loc
+        topk_p, topk_index, hidden_states = (
+            spec_info.topk_p,
+            spec_info.topk_index,
+            spec_info.hidden_states,
+        )
+
+        # Return values
+        score_list: List[torch.Tensor] = []
+        token_list: List[torch.Tensor] = []
+        parents_list: List[torch.Tensor] = []
+
+        # Forward multiple steps
+        scores = None
+        for i in range(self.speculative_num_steps):
+            input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
+                i, topk_p, topk_index, hidden_states, scores, self.topk
+            )
+            score_list.append(tree_info[0])
+            token_list.append(tree_info[1])
+            parents_list.append(tree_info[2])
+
+            # Set inputs
+            forward_batch.input_ids = input_ids
+            forward_batch.out_cache_loc = out_cache_loc[
+                forward_batch.batch_size
+                * self.topk
+                * i : forward_batch.batch_size
+                * self.topk
+                * (i + 1)
+            ]
+            forward_batch.positions.add_(1)
+            forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
+            spec_info.hidden_states = hidden_states
+
+            # Run forward
+            logits_output = self.model_runner.model.forward(
+                forward_batch.input_ids, forward_batch.positions, forward_batch
+            )
+            probs = torch.softmax(logits_output.next_token_logits, dim=-1)
+            topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
+            hidden_states = logits_output.hidden_states
+
+        return score_list, token_list, parents_list
+
+    def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
+        spec_info.prepare_for_verify(batch)
         batch.forward_mode = ForwardMode.TARGET_VERIFY
-        batch.spec_info = verify_input
-        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+        batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch()
         logits_output, _ = self.target_worker.forward_batch_generation(
             model_worker_batch, skip_sample=True
         )
-        verify_input.hidden_states = logits_output.hidden_states
-        res = verify_input.verify(batch, logits_output)
+        spec_info.hidden_states = logits_output.hidden_states
+        res = spec_info.verify(batch, logits_output)
         batch.forward_mode = ForwardMode.DECODE
         return res + (model_worker_batch,)

+    def forward_draft_extend(self, batch: ScheduleBatch):
+        self._set_mem_pool(batch, self.model_runner)
+        batch.spec_info.prepare_for_extend(batch)
+        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        model_worker_batch = batch.get_model_worker_batch()
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        logits_output = self.model_runner.forward(forward_batch)
+        self.capture_for_decode(logits_output, forward_batch)
+        self._set_mem_pool(batch, self.target_worker.model_runner)
+
     def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
         batch.token_to_kv_pool = runner.token_to_kv_pool
         batch.req_to_token_pool = runner.req_to_token_pool
@@ -139,7 +272,7 @@ class EAGLEWorker(TpModelWorker):
         self._set_mem_pool(batch, self.model_runner)

         batch.forward_mode = ForwardMode.DRAFT_EXTEND
-        batch.spec_info.prepare_extend_after_decode(batch)
+        batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
@@ -155,13 +288,10 @@ class EAGLEWorker(TpModelWorker):
     def capture_for_decode(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ):
-        sample_output = torch.softmax(
-            logits_output.next_token_logits, dim=-1
-        )  # TODO(kavioyu): Support more sampling methods
+        probs = torch.softmax(logits_output.next_token_logits, dim=-1)
         spec_info = forward_batch.spec_info
-        spec_info.sample_output = sample_output
+        spec_info.topk_p, spec_info.topk_index = fast_topk(probs, self.topk, dim=-1)
         spec_info.hidden_states = logits_output.hidden_states
-        spec_info.prev_mode = forward_batch.forward_mode

     # Don't support prefix share now.
     def finish_request(self, reqs: Union[Req, List[Req]]):