diff --git a/python/sglang/srt/layers/attention/base_attn_backend.py b/python/sglang/srt/layers/attention/base_attn_backend.py index fabdcb9e4..d0ab5ca82 100644 --- a/python/sglang/srt/layers/attention/base_attn_backend.py +++ b/python/sglang/srt/layers/attention/base_attn_backend.py @@ -55,6 +55,25 @@ class AttentionBackend(ABC): """Get the fill value for padded seq lens. Typically, it is 0 or 1.""" raise NotImplementedError() + def get_verify_buffers_to_fill_after_draft(self): + """ + Return buffers of verify attention kernels that need to be filled after draft. + + Typically, these are tree mask and position buffers. + """ + return [None, None] + + def update_verify_buffers_to_fill_after_draft( + self, spec_info: SpecInput, cuda_graph_bs: Optional[int] + ): + """ + Update the buffers returned by get_verify_buffers_to_fill_after_draft if needed. + + Here, we need to redo the computation of all metadata of the attention backend + that depends on the tree mask and position buffers. + """ + raise NotImplementedError() + def forward( self, q: torch.Tensor, diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 520792119..9f09a268a 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -29,7 +29,6 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.radix_attention import AttentionType from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.spec_info import SpecInput from sglang.srt.utils import ( get_int_env_var, diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 1ef9274e5..a483670db 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -162,6 +162,8 @@ class TritonAttnBackend(AttentionBackend): # Initialize forward metadata self.forward_metadata: ForwardMetadata = None + self.cuda_graph_custom_mask = None + def get_num_kv_splits( self, num_kv_splits: torch.Tensor, @@ -755,6 +757,19 @@ class TritonAttnBackend(AttentionBackend): def get_cuda_graph_seq_len_fill_value(self): return 1 + def get_verify_buffers_to_fill_after_draft(self): + """ + Return buffers for verify attention kernels that need to be filled after draft. + + Typically, these are tree mask and position buffers.
+ """ + return [self.cuda_graph_custom_mask, None] + + def update_verify_buffers_to_fill_after_draft( + self, spec_info: SpecInput, cuda_graph_bs: Optional[int] + ): + pass + def forward_extend( self, q: torch.Tensor, diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 6a7d330b5..dfacd858c 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -384,6 +384,7 @@ class LogitsProcessor(nn.Module): if ( logits_metadata.forward_mode.is_decode_or_idle() or logits_metadata.forward_mode.is_target_verify() + or logits_metadata.forward_mode.is_draft_extend_v2() ): pruned_states = hidden_states if aux_hidden_states is not None: diff --git a/python/sglang/srt/managers/overlap_utils.py b/python/sglang/srt/managers/overlap_utils.py index f91212b75..f73c064c5 100644 --- a/python/sglang/srt/managers/overlap_utils.py +++ b/python/sglang/srt/managers/overlap_utils.py @@ -1,11 +1,18 @@ +from __future__ import annotations + from dataclasses import dataclass -from typing import Optional +from typing import TYPE_CHECKING, Optional import torch -from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.utils import get_compiler_backend +if TYPE_CHECKING: + from sglang.srt.managers.schedule_batch import ModelWorkerBatch + from sglang.srt.managers.scheduler import GenerationBatchResult + from sglang.srt.speculative.eagle_info import EagleDraftInput + from sglang.srt.speculative.spec_info import SpeculativeAlgorithm + @torch.compile(dynamic=True, backend=get_compiler_backend()) def _resolve_future_token_ids(input_ids, future_token_ids_map): @@ -27,6 +34,7 @@ class FutureMap: self, max_running_requests: int, device: torch.device, + spec_algo: Optional[SpeculativeAlgorithm] = None, ): self.future_ct = 0 # A factor of 3 is used to avoid collision in the circular buffer. @@ -34,9 +42,51 @@ class FutureMap: # A factor of 5 is used to ensure the buffer is large enough. 
self.future_buffer_len = max_running_requests * 5 self.device = device + self.spec_algo = spec_algo + self.buf_initialized = False - self.token_ids_buf = torch.empty( - (self.future_buffer_len,), dtype=torch.int64, device=self.device + if self.spec_algo.is_none(): + self.token_ids_buf = torch.empty( + (self.future_buffer_len,), dtype=torch.int64, device=self.device + ) + + def _lazy_init_buf(self, draft_input: EagleDraftInput): + if self.buf_initialized or not self.spec_algo.is_eagle(): + return + + self.buf_initialized = True + + # get the template for each tensor + topk_p0 = draft_input.topk_p[0] + topk_index0 = draft_input.topk_index[0] + hidden_states0 = draft_input.hidden_states[0] + verified_id0 = draft_input.verified_id[0] + new_seq_lens0 = draft_input.new_seq_lens[0] + + self.topk_p_buf = torch.empty( + (self.future_buffer_len, *topk_p0.shape), + dtype=topk_p0.dtype, + device=self.device, + ) + self.topk_index_buf = torch.empty( + (self.future_buffer_len, *topk_index0.shape), + dtype=topk_index0.dtype, + device=self.device, + ) + self.hidden_states_buf = torch.empty( + (self.future_buffer_len, *hidden_states0.shape), + dtype=hidden_states0.dtype, + device=self.device, + ) + self.verified_id_buf = torch.empty( + (self.future_buffer_len, *verified_id0.shape), + dtype=verified_id0.dtype, + device=self.device, + ) + self.new_seq_lens_buf = torch.empty( + (self.future_buffer_len, *new_seq_lens0.shape), + dtype=new_seq_lens0.dtype, + device=self.device, ) def alloc_future_indices(self, bs: int) -> FutureIndices: @@ -49,7 +99,32 @@ class FutureMap: return FutureIndices(indices=indices, interval=slice(start, end)) def resolve_future(self, model_worker_batch: ModelWorkerBatch): - _resolve_future_token_ids(model_worker_batch.input_ids, self.token_ids_buf) + if self.spec_algo.is_eagle(): + # TODO(lsyin): write future indices into spec_info.future_indices + draft_input: EagleDraftInput = model_worker_batch.spec_info + if draft_input is None: + # FIXME(lsyin): No future exists, only for prefill batch, not compatible with mixed mode + return + indices = draft_input.future_indices.indices + draft_input.topk_p = self.topk_p_buf[indices] + draft_input.topk_index = self.topk_index_buf[indices] + draft_input.hidden_states = self.hidden_states_buf[indices] + draft_input.verified_id = self.verified_id_buf[indices] + draft_input.new_seq_lens = self.new_seq_lens_buf[indices] + else: + _resolve_future_token_ids(model_worker_batch.input_ids, self.token_ids_buf) - def store_to_map(self, future_indices: FutureIndices, next_token_ids: torch.Tensor): - self.token_ids_buf[future_indices.interval] = next_token_ids + def store_to_map( + self, future_indices: FutureIndices, batch_result: GenerationBatchResult + ): + intv = future_indices.interval + if self.spec_algo.is_eagle(): + draft_input: EagleDraftInput = batch_result.next_draft_input + self._lazy_init_buf(draft_input) + self.topk_p_buf[intv] = draft_input.topk_p + self.topk_index_buf[intv] = draft_input.topk_index + self.hidden_states_buf[intv] = draft_input.hidden_states + self.verified_id_buf[intv] = draft_input.verified_id + self.new_seq_lens_buf[intv] = draft_input.new_seq_lens + else: + self.token_ids_buf[intv] = batch_result.next_token_ids diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 1d42cd8f7..fec93f66f 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -61,8 +61,12 @@ from sglang.srt.mem_cache.allocator import ( ) from 
sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache -from sglang.srt.mem_cache.common import alloc_for_decode, alloc_for_extend -from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool +from sglang.srt.mem_cache.common import ( + alloc_for_decode, + alloc_for_extend, + alloc_token_slots, +) +from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.mem_cache.radix_cache import RadixKey from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats @@ -71,6 +75,7 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs from sglang.srt.utils import flatten_nested_list +from sglang.srt.utils.common import next_power_of_2 if TYPE_CHECKING: from sglang.srt.configs.model_config import ModelConfig @@ -1067,6 +1072,38 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): def is_empty(self): return len(self.reqs) == 0 + def allocate_for_eagle_v2(self): + from sglang.srt.speculative.eagle_info import EagleDraftInput + from sglang.srt.speculative.spec_utils import assign_req_to_token_pool + + bs = self.batch_size() + + assert self.spec_info.is_draft_input() + draft_input: EagleDraftInput = self.spec_info + + # FIXME(lsyin): the current implementation does not enable over-allocation + # At this point, seq_lens and allocate_lens are correct + self.maybe_wait_verify_done() + + new_allocate_lens = self.seq_lens + EagleDraftInput.ALLOC_LEN_PER_DECODE + num_needed_tokens = (new_allocate_lens - draft_input.allocate_lens).sum().item() + out_cache_loc = alloc_token_slots(self.tree_cache, num_needed_tokens) + + assign_req_to_token_pool[(bs,)]( + self.req_pool_indices, + self.req_to_token_pool.req_to_token, + draft_input.allocate_lens, + new_allocate_lens, + out_cache_loc, + self.req_to_token_pool.req_to_token.shape[1], + next_power_of_2(bs), + ) + draft_input.allocate_lens = new_allocate_lens + + # FIXME(lsyin): remove seq_lens_sum calculation + self.seq_lens_cpu = self.seq_lens.cpu() + self.seq_lens_sum = self.seq_lens_cpu.sum().item() + def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]): self.encoder_lens_cpu = [] self.encoder_cached = [] @@ -1507,15 +1544,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.model_config.vocab_size, ) + @property + def is_v2_eagle(self): + # FIXME: eventually deprecate is_v2_eagle + return self.enable_overlap and self.spec_algorithm.is_eagle() + def prepare_for_decode(self): self.forward_mode = ForwardMode.DECODE bs = len(self.reqs) - if ( - self.spec_algorithm.is_eagle() - or self.spec_algorithm.is_standalone() - or self.spec_algorithm.is_ngram() - ): + if self.is_v2_eagle: + # FIXME(lsyin): make this sync optional + self.allocate_for_eagle_v2() + + if not self.spec_algorithm.is_none(): # if spec decoding is used, the decode batch is prepared inside # `forward_batch_speculative_generation` after running draft models.
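The allocate_for_eagle_v2 hunk above reserves a fixed per-request KV budget before each overlapped decode step, because the real accepted length is not known until verify finishes. A small sketch of the arithmetic, using illustrative speculative settings (the ALLOC_LEN_PER_DECODE formula matches the one set in EAGLEWorkerV2.__init__ later in this patch):

```python
import torch


def plan_eagle_decode_allocation(
    seq_lens: torch.Tensor,       # verified lengths after the last verify
    allocate_lens: torch.Tensor,  # KV slots already reserved per request
    num_steps: int,
    topk: int,
    num_draft_tokens: int,
):
    """Sketch of the per-step KV budget behind allocate_for_eagle_v2."""
    # Each decode step must hold every draft position (num_steps * topk)
    # and the verify pass (num_draft_tokens), whichever is larger.
    alloc_len_per_decode = max(num_steps * topk, num_draft_tokens)
    new_allocate_lens = seq_lens + alloc_len_per_decode
    num_needed_tokens = int((new_allocate_lens - allocate_lens).sum().item())
    return new_allocate_lens, num_needed_tokens


seq_lens = torch.tensor([10, 25])
allocate_lens = torch.tensor([10, 25])  # nothing extra reserved yet
new_lens, needed = plan_eagle_decode_allocation(seq_lens, allocate_lens, 3, 4, 8)
print(new_lens, needed)  # tensor([22, 37]) 24
```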
return @@ -1566,11 +1608,23 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.orig_seq_lens.add_(1) self.seq_lens_sum += bs + def maybe_wait_verify_done(self): + if self.is_v2_eagle: + from sglang.srt.speculative.eagle_info import EagleDraftInput + + draft_input: EagleDraftInput = self.spec_info + if draft_input.verify_done is not None: + draft_input.verify_done.synchronize() + def filter_batch( self, chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None, keep_indices: Optional[List[int]] = None, ): + # FIXME(lsyin): used here to get the correct seq_lens + # The batch has already been launched, but we must wait for verify to finish to get the correct info for the next batch + self.maybe_wait_verify_done() + if keep_indices is None: if isinstance(chunked_req_to_exclude, Req): chunked_req_to_exclude = [chunked_req_to_exclude] @@ -1633,6 +1687,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ) def merge_batch(self, other: "ScheduleBatch"): + # NOTE: in v2 eagle mode, we do not need to wait for verify here because + # 1) the current batch is always prefill, whose seq_lens and allocate_lens are not futures + # 2) the other batch is always decode, which finished in the previous step + # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it # needs to be called with pre-merged Batch.reqs. @@ -1757,6 +1815,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): can_run_dp_cuda_graph=self.can_run_dp_cuda_graph, is_extend_in_batch=self.is_extend_in_batch, is_prefill_only=self.is_prefill_only, + seq_lens_cpu=self.seq_lens_cpu, + enable_overlap=self.enable_overlap, ) def _evict_tree_cache_if_needed(self, num_tokens: int): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index d4c8d5902..5a7555a86 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -148,13 +148,10 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.hiradix_cache import HiRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache -from sglang.srt.model_executor.forward_batch_info import ( - ForwardBatch, - ForwardMode, - PPProxyTensors, -) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.parser.reasoning_parser import ReasoningParser from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.speculative.eagle_info import EagleDraftInput from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.tracing.trace import ( process_tracing_init, @@ -219,6 +216,14 @@ class GenerationBatchResult: forward_batch: Optional[ForwardBatch] = None future_indices: Optional[FutureIndices] = None + # FIXME(lsyin): maybe move to ? + # sync path: forward stream -> output processor + accept_lens: Optional[torch.Tensor] = None + last_batch_allocate_lens: Optional[torch.Tensor] = None + + # relay path: forward stream -> next step forward + next_draft_input: Optional[EagleDraftInput] = None + def copy_to_cpu(self, return_logprob: bool = False): """Copy tensors to CPU in overlap scheduling.
Only the tensors which are needed for processing results are copied, @@ -238,6 +243,15 @@ class GenerationBatchResult: "cpu", non_blocking=True ) self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True) + + if self.accept_lens is not None: + self.accept_lens = self.accept_lens.to("cpu", non_blocking=True) + + if self.last_batch_allocate_lens is not None: + self.last_batch_allocate_lens = self.last_batch_allocate_lens.to( + "cpu", non_blocking=True + ) + self.copy_done.record() @classmethod @@ -273,48 +287,6 @@ class Scheduler( ): """A scheduler that manages a tensor parallel GPU worker.""" - def launch_draft_worker( - self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank - ): - if self.spec_algorithm.is_eagle(): - from sglang.srt.speculative.eagle_worker import EAGLEWorker - - self.draft_worker = EAGLEWorker( - gpu_id=gpu_id, - tp_rank=tp_rank, - moe_ep_rank=moe_ep_rank, - server_args=server_args, - nccl_port=port_args.nccl_port, - target_worker=self.tp_worker, - dp_rank=dp_rank, - ) - elif self.spec_algorithm.is_standalone(): - from sglang.srt.speculative.standalone_worker import StandaloneWorker - - self.draft_worker = StandaloneWorker( - gpu_id=gpu_id, - tp_rank=tp_rank, - moe_ep_rank=moe_ep_rank, - server_args=server_args, - nccl_port=port_args.nccl_port, - target_worker=self.tp_worker, - dp_rank=dp_rank, - ) - elif self.spec_algorithm.is_ngram(): - from sglang.srt.speculative.ngram_worker import NGRAMWorker - - self.draft_worker = NGRAMWorker( - gpu_id=gpu_id, - tp_rank=tp_rank, - moe_ep_rank=moe_ep_rank, - server_args=server_args, - nccl_port=port_args.nccl_port, - target_worker=self.tp_worker, - dp_rank=dp_rank, - ) - else: - self.draft_worker = None - def __init__( self, server_args: ServerArgs, @@ -454,6 +426,7 @@ class Scheduler( ) # Launch a draft worker for speculative decoding + self.launch_draft_worker( gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank ) @@ -683,6 +656,51 @@ class Scheduler( ] ) + def launch_draft_worker( + self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank + ): + if self.spec_algorithm.is_eagle(): + from sglang.srt.speculative.eagle_worker import EAGLEWorker + from sglang.srt.speculative.eagle_worker_v2 import EAGLEWorkerV2 + + WorkerClass = EAGLEWorkerV2 if self.enable_overlap else EAGLEWorker + + self.draft_worker = WorkerClass( + gpu_id=gpu_id, + tp_rank=tp_rank, + moe_ep_rank=moe_ep_rank, + server_args=server_args, + nccl_port=port_args.nccl_port, + target_worker=self.tp_worker, + dp_rank=dp_rank, + ) + elif self.spec_algorithm.is_standalone(): + from sglang.srt.speculative.standalone_worker import StandaloneWorker + + self.draft_worker = StandaloneWorker( + gpu_id=gpu_id, + tp_rank=tp_rank, + moe_ep_rank=moe_ep_rank, + server_args=server_args, + nccl_port=port_args.nccl_port, + target_worker=self.tp_worker, + dp_rank=dp_rank, + ) + elif self.spec_algorithm.is_ngram(): + from sglang.srt.speculative.ngram_worker import NGRAMWorker + + self.draft_worker = NGRAMWorker( + gpu_id=gpu_id, + tp_rank=tp_rank, + moe_ep_rank=moe_ep_rank, + server_args=server_args, + nccl_port=port_args.nccl_port, + target_worker=self.tp_worker, + dp_rank=dp_rank, + ) + else: + self.draft_worker = None + def init_deterministic_inference_config(self): """Initialize deterministic inference configuration for different attention backends.""" if not self.server_args.enable_deterministic_inference: @@ -965,7 +983,9 @@ class Scheduler( self.device ).stream(self.copy_stream) - self.future_map = FutureMap(self.max_running_requests, 
self.device) + self.future_map = FutureMap( + self.max_running_requests, self.device, self.spec_algorithm + ) self.batch_record_buf = [None] * 2 self.batch_record_ct = 0 @@ -2096,7 +2116,7 @@ class Scheduler( batch_or_worker_batch = batch - if self.spec_algorithm.is_none(): + if self.enable_overlap or self.spec_algorithm.is_none(): # FIXME(lsyin): remove this if and finally unify the abstraction batch_or_worker_batch = batch.get_model_worker_batch() @@ -2120,39 +2140,49 @@ class Scheduler( if batch.sampling_info.grammars is not None: model_worker_batch.delay_sample_launch = True batch_result = self.model_worker.forward_batch_generation( - batch_or_worker_batch + model_worker_batch ) # FIXME(lsyin): maybe move this to forward_batch_generation batch_result.copy_done = torch.get_device_module( self.device ).Event() if not model_worker_batch.delay_sample_launch: - self.future_map.store_to_map( - future_indices, batch_result.next_token_ids - ) + self.future_map.store_to_map(future_indices, batch_result) batch_result.copy_to_cpu() else: batch_result.future_indices = future_indices # FIXME(lsyin): move this assignment elsewhere - maybe_future_next_token_ids = -future_indices.indices + future_indices_or_next_token_ids = -future_indices.indices + + if batch.is_v2_eagle: + # FIXME(lsyin): tmp code for eagle v2 + # We only keep future indices for next draft input + + batch.spec_info = batch_result.next_draft_input + batch.spec_info.future_indices = future_indices + + # batch.spec_info = EagleDraftInput( + # future_indices=future_indices, + # verify_done=batch_result.next_draft_input.verify_done, + # # FIXME(lsyin): remove the allocate_lens in EagleDraftInput + # allocate_lens=batch_result.next_draft_input.allocate_lens, + # ) + + # The future value, usually for next batch preparation + # Current implementation strictly synchronizes the seq_lens + batch.seq_lens = batch_result.next_draft_input.new_seq_lens else: batch_result = self.model_worker.forward_batch_generation( batch_or_worker_batch ) - maybe_future_next_token_ids = batch_result.next_token_ids + future_indices_or_next_token_ids = batch_result.next_token_ids - if not self.spec_algorithm.is_none(): - # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing - self.update_spec_metrics( - batch.batch_size(), batch_result.num_accepted_tokens - ) - - # NOTE: maybe_future_next_token_ids is used in ScheduleBatch, + # NOTE: future_indices_or_next_token_ids is used in ScheduleBatch, # which can probably be replaced by future_indices later [TODO(lsyin)]. # we shall still keep the original outputs, e.g. next_token_ids # in the GenerationBatchOutput for processing after copy_done. - batch.output_ids = maybe_future_next_token_ids + batch.output_ids = future_indices_or_next_token_ids # These 2 values are needed for processing the output, but the values can be # modified by overlap schedule. 
So we have to copy them here so that @@ -2200,7 +2230,7 @@ class Scheduler( tmp_result.forward_batch, ) future_indices = tmp_result.future_indices - self.future_map.store_to_map(future_indices, tmp_result.next_token_ids) + self.future_map.store_to_map(future_indices, tmp_result) tmp_result.copy_to_cpu() self.result_queue.appendleft((tmp_batch, tmp_result)) diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py index 4fa4bfee1..dd92dfbd2 100644 --- a/python/sglang/srt/managers/scheduler_metrics_mixin.py +++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py @@ -69,7 +69,7 @@ class SchedulerMetricsMixin: kv_events_config, self.attn_dp_rank ) - def update_spec_metrics(self, bs: int, num_accepted_tokens: int): + def update_spec_metrics(self: Scheduler, bs: int, num_accepted_tokens: int): self.spec_num_total_accepted_tokens += num_accepted_tokens + bs self.spec_num_total_forward_ct += bs self.num_generated_tokens += num_accepted_tokens diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index a224bdc34..ef7a3f54e 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import threading import time from typing import TYPE_CHECKING, List, Optional, Tuple, Union @@ -200,6 +199,28 @@ class SchedulerOutputProcessorMixin: self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req) + def hacky_process_eagle_overlap_result( + self: Scheduler, result: GenerationBatchResult, batch: ScheduleBatch + ): + # TODO(lsyin): try use a copy stream to share SMs with forward + # FIXME(lsyin): better organize this token free logic in eagle-overlap + last_batch_allocate_lens_cpu = result.last_batch_allocate_lens.tolist() + accept_lens_cpu = result.accept_lens.tolist() + next_token_ids = result.next_token_ids.tolist() + + predict_tokens = [] + num_draft_tokens = self.draft_worker.speculative_num_draft_tokens + for i, req in enumerate(batch.reqs): + predict_tokens.append( + next_token_ids[ + i * num_draft_tokens : i * num_draft_tokens + accept_lens_cpu[i] + ] + ) + # FIXME(lsyin): move this update elsewhere + req.spec_verify_ct += 1 + + return last_batch_allocate_lens_cpu, accept_lens_cpu, predict_tokens + def process_batch_result_decode( self: Scheduler, batch: ScheduleBatch, @@ -220,6 +241,17 @@ class SchedulerOutputProcessorMixin: next_token_ids = next_token_ids.tolist() if batch.return_logprob: next_token_logprobs = logits_output.next_token_logprobs.tolist() + elif batch.is_v2_eagle: + ( + last_batch_allocate_lens_cpu, + accept_lens_cpu, + next_token_ids, + ) = self.hacky_process_eagle_overlap_result(result, batch) + result.num_accepted_tokens = sum(accept_lens_cpu) + + # FIXME(lsyin): we suppose we have already got the num_accepted_tokens in result + if not self.spec_algorithm.is_none(): + self.update_spec_metrics(batch.batch_size(), result.num_accepted_tokens) self.token_to_kv_pool_allocator.free_group_begin() @@ -227,29 +259,74 @@ class SchedulerOutputProcessorMixin: # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding. # We should ignore using next_token_ids for spec decoding cases. 
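The hacky_process_eagle_overlap_result hunk above recovers per-request outputs from one flat prediction tensor: the verify kernel writes bs * num_draft_tokens candidate tokens contiguously, and only the first accept_lens[i] entries of each request's window (bonus token included) are real. A self-contained sketch of that slicing (hypothetical helper name):

```python
def split_accepted_tokens(next_token_ids, accept_lens, num_draft_tokens):
    """Slice each request's accepted tokens out of the flat verify output."""
    out = []
    for i, n in enumerate(accept_lens):
        start = i * num_draft_tokens
        out.append(next_token_ids[start : start + n])
    return out


# Two requests with 4 draft slots each; request 0 accepted 2 tokens,
# request 1 accepted 3 (unaccepted slots hold stale values).
flat = [11, 12, 0, 0, 21, 22, 23, 0]
print(split_accepted_tokens(flat, [2, 3], 4))  # [[11, 12], [21, 22, 23]]
```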
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)): + req: Req if req.is_retracted: continue if self.enable_overlap and req.finished(): - # Free the one extra delayed token if self.page_size == 1: - self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1]) - else: - # Only free when the extra token is in a new page - if ( - len(req.origin_input_ids) + len(req.output_ids) - 1 - ) % self.page_size == 0: + if batch.spec_algorithm.is_eagle(): + from sglang.srt.speculative.eagle_worker_v2 import ( + free_spec_dec_tokens_page_size_1, + ) + + free_spec_dec_tokens_page_size_1( + self.req_to_token_pool, + self.token_to_kv_pool_allocator, + req, + last_batch_allocate_lens_cpu[i], + None, + ) + else: + # Free the one extra delayed token self.token_to_kv_pool_allocator.free( batch.out_cache_loc[i : i + 1] ) + else: + if batch.spec_algorithm.is_eagle(): + # TODO(lsyin): support eagle with page_size > 1 + raise NotImplementedError() + else: + if ( + len(req.origin_input_ids) + len(req.output_ids) - 1 + ) % self.page_size == 0: + # Only free when the extra token is in a new page + self.token_to_kv_pool_allocator.free( + batch.out_cache_loc[i : i + 1] + ) continue if batch.spec_algorithm.is_none(): - # speculative worker will solve the output_ids in speculative decoding req.output_ids.append(next_token_id) + elif batch.is_v2_eagle: + # FIXME(lsyin): non-overlap spec worker will solve the output_ids in speculative decoding + # !!!unify the logic here!!! + req.output_ids.extend(next_token_id) req.check_finished() if req.finished(): + if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend(): + # FIXME(lsyin): fix the messy logic here + # 1) when not overlap (v2 impl), we free the extra tokens in the req + # 2) when overlap and current batch is extend, we free the extra tokens in the req of the previous batch + from sglang.srt.speculative.eagle_worker_v2 import ( + free_spec_dec_tokens_page_size_1, + ) + + new_seq_len = len(req.origin_input_ids) + len(req.output_ids) - 1 + # FIXME(lsyin): remove this assert + assert new_seq_len == int( + batch.seq_lens_cpu[i] + accept_lens_cpu[i] + ), f"{new_seq_len=} vs {batch.seq_lens_cpu[i] + accept_lens_cpu[i]=}" + + free_spec_dec_tokens_page_size_1( + self.req_to_token_pool, + self.token_to_kv_pool_allocator, + req, + last_batch_allocate_lens_cpu[i], + new_seq_len, + ) + if self.server_args.disaggregation_decode_enable_offload_kvcache: # Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes if not self.decode_offload_manager.offload_kv_cache(req): diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 33ac661b9..52a40a371 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -231,12 +231,21 @@ class TpModelWorker: def forward_batch_generation( self, model_worker_batch: ModelWorkerBatch, + forward_batch: Optional[ForwardBatch] = None, is_verify: bool = False, + skip_attn_backend_init=False, ) -> GenerationBatchResult: - # update the consumer index of hicache to the running batch - self.set_hicache_consumer(model_worker_batch.hicache_consumer_index) + # FIXME(lsyin): maybe remove skip_attn_backend_init in forward_batch_generation, + # which requires preparing replay to always be in this function - forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + if model_worker_batch is not None: + # update the consumer index of hicache to the running batch + 
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index) + + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + else: + # FIXME(lsyin): unify the interface of forward_batch + assert forward_batch is not None pp_proxy_tensors = None if not self.pp_group.is_first_rank: @@ -248,7 +257,9 @@ class TpModelWorker: if self.pp_group.is_last_rank: logits_output, can_run_cuda_graph = self.model_runner.forward( - forward_batch, pp_proxy_tensors=pp_proxy_tensors + forward_batch, + pp_proxy_tensors=pp_proxy_tensors, + skip_attn_backend_init=skip_attn_backend_init, ) batch_result = GenerationBatchResult( logits_output=logits_output, @@ -290,6 +301,7 @@ class TpModelWorker: pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward( forward_batch, pp_proxy_tensors=pp_proxy_tensors, + skip_attn_backend_init=skip_attn_backend_init, ) return GenerationBatchResult( pp_hidden_states_proxy_tensors=pp_proxy_tensors, diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 8bfb077f9..d24ce8ae3 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -678,8 +678,9 @@ class CudaGraphRunner: capture_hidden_mode_required_by_forward_batch = ( forward_batch.capture_hidden_mode ) - capture_hidden_mode_required_by_spec_info = getattr( - forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL + capture_hidden_mode_required_by_spec_info = ( + getattr(forward_batch.spec_info, "capture_hidden_mode", None) + or CaptureHiddenMode.NULL ) capture_hidden_mode_required_for_returning_hidden_states = ( CaptureHiddenMode.FULL diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 297aef2d2..95239c2f9 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -75,6 +75,8 @@ class ForwardMode(IntEnum): # Used in speculative decoding: extend a batch in the draft model. DRAFT_EXTEND = auto() + DRAFT_EXTEND_V2 = auto() + # Split Prefill for PD multiplexing SPLIT_PREFILL = auto() @@ -107,6 +109,10 @@ class ForwardMode(IntEnum): def is_draft_extend(self): return self == ForwardMode.DRAFT_EXTEND + def is_draft_extend_v2(self): + # For fixed shape logits output in v2 eagle worker + return self == ForwardMode.DRAFT_EXTEND_V2 + def is_extend_or_draft_extend_or_mixed(self): return ( self == ForwardMode.EXTEND diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 67d0a1807..c08046130 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -312,6 +312,7 @@ class ServerArgs: nsa_decode: str = "fa3" # Speculative decoding + enable_beta_spec: bool = False speculative_algorithm: Optional[str] = None speculative_draft_model_path: Optional[str] = None speculative_draft_model_revision: Optional[str] = None @@ -1103,11 +1104,19 @@ class ServerArgs: ) if self.max_running_requests is None: self.max_running_requests = 48 - self.disable_overlap_schedule = True - logger.warning( - "Overlap scheduler is disabled because of using " - "eagle speculative decoding." - ) + + if self.speculative_algorithm == "EAGLE" and self.enable_beta_spec: + self.disable_overlap_schedule = False + logger.warning( + "Beta spec is enabled for eagle speculative decoding and overlap schedule is turned on." 
+ ) + + if not self.enable_beta_spec: + self.disable_overlap_schedule = True + logger.warning( + "Overlap scheduler is disabled because of using eagle3 and standalone speculative decoding." + ) + if self.enable_mixed_chunk: self.enable_mixed_chunk = False logger.warning( @@ -2127,6 +2136,7 @@ class ServerArgs: ) # Speculative decoding + parser.add_argument("--enable-beta-spec", action="store_true") parser.add_argument( "--speculative-algorithm", type=str, diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index 46ecc1b32..d230cf193 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -1,7 +1,7 @@ import logging from copy import copy from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import ClassVar, List, Optional, Tuple import torch import torch.nn.functional as F @@ -10,6 +10,7 @@ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import apply_custom_logit_processor +from sglang.srt.managers.overlap_utils import FutureIndices from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.common import ( @@ -18,16 +19,20 @@ from sglang.srt.mem_cache.common import ( get_last_loc, ) from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode +from sglang.srt.speculative.eagle_info_v2 import ( + EagleDraftInputV2Mixin, + EagleVerifyInputV2Mixin, +) from sglang.srt.speculative.spec_info import SpecInput, SpecInputType from sglang.srt.speculative.spec_utils import ( SIMULATE_ACC_LEN, TREE_SPEC_KERNEL_AVAILABLE, - _generate_simulated_accept_index, align_evict_mask_to_page_size, assign_req_to_token_pool, create_accept_length_filter, create_extend_after_decode_spec_info, filter_finished_cache_loc_kernel, + generate_simulated_accept_index, get_src_tgt_cache_loc, get_target_cache_loc, ) @@ -47,7 +52,7 @@ logger = logging.getLogger(__name__) @dataclass -class EagleVerifyInput(SpecInput): +class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): draft_token: torch.Tensor custom_mask: torch.Tensor positions: torch.Tensor @@ -338,7 +343,7 @@ class EagleVerifyInput(SpecInput): if SIMULATE_ACC_LEN > 0.0: # Do simulation - accept_index = _generate_simulated_accept_index( + accept_index = generate_simulated_accept_index( accept_index=accept_index, predict=predict, # mutable accept_length=accept_length, # mutable @@ -568,7 +573,7 @@ class EagleVerifyInput(SpecInput): @dataclass -class EagleDraftInput(SpecInput): +class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): # The inputs for decode # shape: (b, topk) topk_p: torch.Tensor = None @@ -598,6 +603,15 @@ class EagleDraftInput(SpecInput): seq_lens_for_draft_extend_cpu: torch.Tensor = None req_pool_indices_for_draft_extend: torch.Tensor = None + # Inputs for V2 overlap worker + future_indices: Optional[FutureIndices] = None + allocate_lens: Optional[torch.Tensor] = None + new_seq_lens: Optional[torch.Tensor] = None + verify_done: Optional[torch.cuda.Event] = None + + # FIXME(lsyin): remove this hack + ALLOC_LEN_PER_DECODE: ClassVar[int] = None + def __post_init__(self): super().__init__(SpecInputType.EAGLE_DRAFT) @@ -703,6 +717,11 @@ class EagleDraftInput(SpecInput): return 
kv_indices, cum_kv_seq_len, qo_indptr, None def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True): + if self.future_indices is not None: + self.future_indices.indices = self.future_indices.indices[new_indices] + self.allocate_lens = self.allocate_lens[new_indices] + return + if has_been_filtered: # in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index` # therefore, we don't need to filter the batch again in scheduler @@ -722,6 +741,18 @@ class EagleDraftInput(SpecInput): self.verified_id = self.verified_id[new_indices] def merge_batch(self, spec_info: "EagleDraftInput"): + if self.future_indices is not None: + assert spec_info.future_indices is not None + self.future_indices = FutureIndices( + indices=torch.cat( + [self.future_indices.indices, spec_info.future_indices.indices] + ) + ) + self.allocate_lens = torch.cat( + [self.allocate_lens, spec_info.allocate_lens] + ) + return + if self.hidden_states is None: self.hidden_states = spec_info.hidden_states self.verified_id = spec_info.verified_id diff --git a/python/sglang/srt/speculative/eagle_info_v2.py b/python/sglang/srt/speculative/eagle_info_v2.py new file mode 100644 index 000000000..23902a846 --- /dev/null +++ b/python/sglang/srt/speculative/eagle_info_v2.py @@ -0,0 +1,514 @@ +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, List, Optional + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl + +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.managers.schedule_batch import ModelWorkerBatch +from sglang.srt.managers.scheduler import global_server_args_dict +from sglang.srt.mem_cache.memory_pool import ReqToTokenPool +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) +from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.speculative.build_eagle_tree import TreeMaskMode +from sglang.srt.speculative.spec_utils import ( + SIMULATE_ACC_LEN, + generate_simulated_accept_index, +) +from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, next_power_of_2 + +if TYPE_CHECKING: + from sglang.srt.managers.tp_worker import TpModelWorker + from sglang.srt.speculative.eagle_draft_cuda_graph_runner import ( + EAGLEDraftCudaGraphRunner, + ) + from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput + +if is_cuda(): + from sgl_kernel import ( + top_k_renorm_prob, + top_p_renorm_prob, + tree_speculative_sampling_target_only, + verify_tree_greedy, + ) + from sgl_kernel.top_k import fast_topk +elif is_hip(): + from sgl_kernel import verify_tree_greedy + + +@triton.jit +def assign_draft_cache_locs_page_size_1( + req_pool_indices, + req_to_token, + seq_lens, + out_cache_loc, + pool_len: tl.constexpr, + topk: tl.constexpr, + speculative_num_steps: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 128 + pid = tl.program_id(axis=0) + + copy_len = topk * speculative_num_steps + out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps + + # Copy from req_to_token to out_cache_loc + kv_start = tl.load(seq_lens + pid) + token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len + num_loop = tl.cdiv(copy_len, BLOCK_SIZE) + for i in range(num_loop): + copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + mask = copy_offset < copy_len + data = tl.load(token_pool + kv_start + copy_offset, mask=mask) + 
tl.store(out_cache_ptr + copy_offset, data, mask=mask) + + +@dataclass +class EagleDraftInputV2Mixin: + def prepare_for_v2_draft( + self: EagleDraftInput, + req_to_token_pool: ReqToTokenPool, + batch: ModelWorkerBatch, + cuda_graph_runner: EAGLEDraftCudaGraphRunner, + draft_model_runner: ModelRunner, + topk: int, + num_steps: int, + ): + bs = len(batch.seq_lens) + + # Assign cache locations + batch.out_cache_loc = torch.empty( + (bs * topk * num_steps,), + dtype=torch.int64, + device=batch.input_ids.device, + ) + # FIXME(lsyin): align with the default code path + assign_draft_cache_locs_page_size_1[(bs,)]( + batch.req_pool_indices, + req_to_token_pool.req_to_token, + batch.seq_lens, + batch.out_cache_loc, + req_to_token_pool.req_to_token.shape[1], + topk, + num_steps, + ) + + # Get a forward batch + batch.capture_hidden_mode = CaptureHiddenMode.LAST + self.positions = batch.seq_lens.repeat_interleave(topk, dim=0) + forward_batch = ForwardBatch.init_new(batch, draft_model_runner) + can_cuda_graph = cuda_graph_runner and cuda_graph_runner.can_run(forward_batch) + return forward_batch, can_cuda_graph + + def prepare_for_extend_to_fill_draft_kvcache( + self, + batch: ModelWorkerBatch, + predict: torch.Tensor, + num_draft_tokens: int, + draft_model_runner: Any, + ): + seq_lens_cpu_backup = batch.seq_lens_cpu + extend_num_tokens = len(batch.seq_lens) * num_draft_tokens + + batch.spec_info = self + batch.input_ids = predict + batch.seq_lens = batch.seq_lens + num_draft_tokens + batch.seq_lens_cpu = batch.seq_lens_cpu + num_draft_tokens + batch.seq_lens_sum += extend_num_tokens + batch.extend_seq_lens = [num_draft_tokens for _ in range(len(batch.seq_lens))] + batch.extend_prefix_lens = seq_lens_cpu_backup.tolist() + batch.extend_prefix_lens_cpu = seq_lens_cpu_backup + batch.extend_num_tokens = extend_num_tokens + batch.capture_hidden_mode = CaptureHiddenMode.FULL + batch.forward_mode = ForwardMode.DRAFT_EXTEND_V2 + forward_batch = ForwardBatch.init_new(batch, draft_model_runner) + draft_model_runner.attn_backend.init_forward_metadata(forward_batch) + return forward_batch + + +@dataclass +class EagleVerifyInputV2Mixin: + def prepare_for_v2_verify( + self: EagleVerifyInput, + req_to_token_pool: ReqToTokenPool, + batch: ModelWorkerBatch, + target_worker: TpModelWorker, + ): + # Assign cache locations + bs = len(batch.req_pool_indices) + batch.input_ids = self.draft_token + device = batch.input_ids.device + batch.out_cache_loc = torch.empty( + (bs * self.draft_token_num,), + dtype=torch.int64, + device=device, + ) + + assign_extend_cache_locs[(bs,)]( + batch.req_pool_indices, + req_to_token_pool.req_to_token, + batch.seq_lens, + batch.seq_lens + self.draft_token_num, + batch.out_cache_loc, + req_to_token_pool.req_to_token.shape[1], + next_power_of_2(bs), + ) + + # Get a forward batch + batch.forward_mode = ForwardMode.TARGET_VERIFY + batch.capture_hidden_mode = CaptureHiddenMode.FULL + verify_forward_batch = ForwardBatch.init_new(batch, target_worker.model_runner) + + # Run attention backend plan and cuda graph preparation + can_run_cuda_graph = bool( + target_worker.model_runner.graph_runner + and target_worker.model_runner.graph_runner.can_run(verify_forward_batch) + ) + if can_run_cuda_graph: + target_worker.model_runner.graph_runner.replay_prepare(verify_forward_batch) + else: + target_worker.model_runner.attn_backend.init_forward_metadata( + verify_forward_batch + ) + + return verify_forward_batch, can_run_cuda_graph + + def sample( + self: EagleVerifyInput, + batch: ModelWorkerBatch, + 
logits_output: LogitsProcessorOutput, + ): + """ + Verify and find accepted tokens based on logits output and batch + (which contains spec decoding information). + """ + bs = len(batch.seq_lens) + sampling_info = batch.sampling_info + next_token_logits = logits_output.next_token_logits + device = batch.input_ids.device + + candidates = self.draft_token.reshape(bs, self.draft_token_num) + predict = torch.zeros( + (bs * (self.spec_steps + 1),), dtype=torch.int32, device=device + ) + accept_index = torch.full( + (bs, self.spec_steps + 1), -1, dtype=torch.int32, device=device + ) + accept_length = torch.empty((bs,), dtype=torch.int32, device=device) + + # Sample tokens + if sampling_info.is_all_greedy: + target_predict = torch.argmax(next_token_logits, dim=-1) + target_predict = target_predict.reshape(bs, self.draft_token_num) + + verify_tree_greedy( + predicts=predict, # mutable + accept_index=accept_index, # mutable + accept_token_num=accept_length, # mutable + candidates=candidates, + retrive_index=self.retrive_index, + retrive_next_token=self.retrive_next_token, + retrive_next_sibling=self.retrive_next_sibling, + target_predict=target_predict, + ) + else: + # Apply temperature and get target probs + expanded_temperature = torch.repeat_interleave( + sampling_info.temperatures, self.draft_token_num, dim=0 + ) # (bs * num_draft_tokens, 1) + + target_probs = F.softmax( + next_token_logits / expanded_temperature, dim=-1 + ) # (bs * num_draft_tokens, vocab_size) + target_probs = top_k_renorm_prob( + target_probs, + torch.repeat_interleave( + sampling_info.top_ks, self.draft_token_num, dim=0 + ), + ) # (bs * num_draft_tokens, vocab_size) + target_probs = top_p_renorm_prob( + target_probs, + torch.repeat_interleave( + sampling_info.top_ps, self.draft_token_num, dim=0 + ), + ) + target_probs = target_probs.reshape(bs, self.draft_token_num, -1) + + # This is currently not used + draft_probs = torch.empty_like(target_probs) + + # coins for rejection sampling + coins = torch.rand_like(candidates, dtype=torch.float32, device=device) + # coins for final sampling + coins_for_final_sampling = torch.rand( + (bs,), dtype=torch.float32, device=device + ) + + tree_speculative_sampling_target_only( + predicts=predict, # mutable + accept_index=accept_index, # mutable + accept_token_num=accept_length, # mutable + candidates=candidates, + retrive_index=self.retrive_index, + retrive_next_token=self.retrive_next_token, + retrive_next_sibling=self.retrive_next_sibling, + uniform_samples=coins, + uniform_samples_for_final_sampling=coins_for_final_sampling, + target_probs=target_probs, + draft_probs=draft_probs, + threshold_single=global_server_args_dict[ + "speculative_accept_threshold_single" + ], + threshold_acc=global_server_args_dict[ + "speculative_accept_threshold_acc" + ], + deterministic=True, + ) + + if SIMULATE_ACC_LEN > 0: + # Do simulation + accept_index = generate_simulated_accept_index( + accept_index=accept_index, + predict=predict, # mutable + accept_length=accept_length, # mutable + simulate_acc_len=SIMULATE_ACC_LEN, + bs=bs, + spec_steps=self.draft_token_num, + ) + + # Include the bonus token + accept_length.add_(1) + return predict, accept_length, accept_index + + +def build_tree_kernel_efficient_tmp( + verified_id: torch.Tensor, + parent_list: List[torch.Tensor], + top_scores_index: torch.Tensor, + draft_tokens: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + topk: int, + spec_steps: int, + num_verify_tokens: int, + tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK, + 
tree_mask_buf: Optional[torch.Tensor] = None, + position_buf: Optional[torch.Tensor] = None, +): + # TODO(lsyin): make it compatible with default code path + # TODO(lsyin): support cuda graph graph padding for eagle + draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten() + + # seq_lens_sum == sum(seq_lens); seq_lens: sequence length without draft tokens + bs = seq_lens.numel() + device = seq_lens.device + # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened) + # where each row indicates the attending pattern of each draft token + # if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed) + if tree_mask_buf is not None: + tree_mask = tree_mask_buf + if tree_mask_mode == TreeMaskMode.QLEN_ONLY: + tree_mask.fill_(True) + elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING: + tree_mask.fill_(0) + elif tree_mask_mode == TreeMaskMode.FULL_MASK: + tree_mask.fill_(True) + else: + raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}") + elif tree_mask_mode == TreeMaskMode.QLEN_ONLY: + tree_mask = torch.full( + (num_verify_tokens * bs * num_verify_tokens,), + True, + dtype=torch.bool, + device=device, + ) + elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING: + packed_dtypes = [torch.uint8, torch.uint16, torch.uint32] + packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8))) + tree_mask = torch.zeros( + (num_verify_tokens * bs,), + dtype=packed_dtypes[packed_dtype_idx], + device=device, + ) + elif tree_mask_mode == TreeMaskMode.FULL_MASK: + tree_mask = torch.full( + ( + seq_lens_sum * num_verify_tokens + + num_verify_tokens * num_verify_tokens * bs, + ), + True, + device=device, + ) + else: + raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}") + + # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel` + retrive_buf = torch.full( + (3, bs, num_verify_tokens), -1, device=device, dtype=torch.long + ) + retrive_index, retrive_next_token, retrive_next_sibling = retrive_buf + # position: where each token belongs to + # e.g. 
if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7 + # then, positions = [7, 8, 8, 9] + if position_buf is not None: + positions = position_buf + else: + positions = torch.empty( + (bs * num_verify_tokens,), device=device, dtype=torch.long + ) + + from sgl_kernel import ( + build_tree_kernel_efficient as sgl_build_tree_kernel_efficient, + ) + + sgl_build_tree_kernel_efficient( + parent_list, + top_scores_index, + seq_lens, + tree_mask, + positions, + retrive_index, + retrive_next_token, + retrive_next_sibling, + topk, + spec_steps, + num_verify_tokens, + tree_mask_mode, + ) + return ( + tree_mask, + positions, + retrive_index, + retrive_next_token, + retrive_next_sibling, + draft_tokens, + ) + + +@torch.compile(dynamic=True) +def select_top_k_tokens_tmp( + i: int, + topk_p: torch.Tensor, + topk_index: torch.Tensor, + hidden_states: torch.Tensor, + scores: torch.Tensor, + topk: int, +): + # FIXME(lsyin): remove this duplicate code + if i == 0: + # The first step after extend + input_ids = topk_index.flatten() + hidden_states = hidden_states.repeat_interleave(topk, dim=0) + scores = topk_p # shape: (b, topk) + + tree_info = ( + topk_p.unsqueeze(1), # shape: (b, 1, topk) + topk_index, # shape: (b, topk) + torch.arange(-1, topk, dtype=torch.long, device=hidden_states.device) + .unsqueeze(0) + .repeat(topk_p.shape[0], 1), # shape: (b, topk + 1) + ) + else: + # The later decode steps + expand_scores = torch.mul( + scores.unsqueeze(2), topk_p.reshape(-1, topk, topk) + ) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk) + topk_cs_p, topk_cs_index = fast_topk( + expand_scores.flatten(start_dim=1), topk, dim=-1 + ) # (b, topk) + scores = topk_cs_p # shape: (b, topk) + + topk_index = topk_index.reshape(-1, topk**2) + input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten() + + selected_input_index = topk_cs_index.flatten() // topk + torch.arange( + 0, hidden_states.shape[0], step=topk, device=hidden_states.device + ).repeat_interleave(topk) + hidden_states = hidden_states[selected_input_index, :] + + tree_info = ( + expand_scores, # shape: (b, topk, topk) + topk_index, # shape: (b, topk * topk) + topk_cs_index + (topk**2 * (i - 1) + topk), # shape: (b, topk) + ) + + return input_ids, hidden_states, scores, tree_info + + +@triton.jit +def fill_new_verified_id( + verified_id, + accept_lens, + new_verified_id, + num_draft_tokens: tl.constexpr, +): + # NOTE: we cannot fuse any in-place operations of `accept_lens` inside this kernel + # because this kernel reads accept_lens + pid = tl.program_id(axis=0) + accept_length = tl.load(accept_lens + pid) + + verified_id_idx = num_draft_tokens * pid + accept_length - 1 + verified_id_data = tl.load(verified_id + verified_id_idx) + tl.store(new_verified_id + pid, verified_id_data) + + +@triton.jit +def fill_accepted_out_cache_loc( + accept_index, + out_cache_loc, + accepted_out_cache_loc, + size_upper: tl.constexpr, +): + pid = tl.program_id(axis=0) + offset = tl.arange(0, size_upper) + + masks = (tl.load(accept_index + offset, offset < pid, other=-1) != -1).to(tl.int64) + dst = tl.sum(masks) + src = tl.load(accept_index + pid) + if src > -1: + value = tl.load(out_cache_loc + src) + tl.store(accepted_out_cache_loc + dst, value) + + +@triton.jit +def assign_extend_cache_locs( + req_pool_indices, + req_to_token, + start_offset, + end_offset, + out_cache_loc, + pool_len: tl.constexpr, + bs_upper: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 32 + pid = tl.program_id(axis=0) + kv_start = tl.load(start_offset + pid) + kv_end = 
tl.load(end_offset + pid) + token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len + + length_offset = tl.arange(0, bs_upper) + start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0) + end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0) + out_offset = tl.sum(end - start, axis=0) + + out_cache_ptr = out_cache_loc + out_offset + + load_offset = tl.arange(0, BLOCK_SIZE) + kv_start + save_offset = tl.arange(0, BLOCK_SIZE) + + num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) + for _ in range(num_loop): + mask = load_offset < kv_end + data = tl.load(token_pool + load_offset, mask=mask) + tl.store(out_cache_ptr + save_offset, data, mask=mask) + load_offset += BLOCK_SIZE + save_offset += BLOCK_SIZE diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py new file mode 100644 index 000000000..fb01eba53 --- /dev/null +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -0,0 +1,482 @@ +import logging +from typing import List, Optional + +import torch +from torch.cuda import Stream as CudaStream + +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.managers.schedule_batch import ModelWorkerBatch, Req +from sglang.srt.managers.scheduler import GenerationBatchResult +from sglang.srt.managers.tp_worker import TpModelWorker +from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator +from sglang.srt.mem_cache.memory_pool import ReqToTokenPool +from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch +from sglang.srt.server_args import ServerArgs +from sglang.srt.speculative.build_eagle_tree import TreeMaskMode +from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput +from sglang.srt.speculative.eagle_info_v2 import ( + assign_extend_cache_locs, + build_tree_kernel_efficient_tmp, + fill_accepted_out_cache_loc, + fill_new_verified_id, + select_top_k_tokens_tmp, +) +from sglang.srt.speculative.eagle_worker import EAGLEWorker +from sglang.srt.utils.common import fast_topk, next_power_of_2 + +logger = logging.getLogger(__name__) + + +class EAGLEWorkerV2(EAGLEWorker): + def __init__( + self, + server_args: ServerArgs, + gpu_id: int, + tp_rank: int, + dp_rank: Optional[int], + moe_ep_rank: int, + nccl_port: int, + target_worker: TpModelWorker, + ): + super().__init__( + server_args, + gpu_id, + tp_rank, + dp_rank, + moe_ep_rank, + nccl_port, + target_worker, + ) + EagleDraftInput.ALLOC_LEN_PER_DECODE = max( + self.speculative_num_steps * self.topk, self.speculative_num_draft_tokens + ) + self.tree_mask_mode = TreeMaskMode.FULL_MASK + self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream() + # TODO(lsyin): potential bugs with a separate plan stream + self.plan_stream_ctx = torch.cuda.stream(self.plan_stream) + + def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): + if model_worker_batch.forward_mode.is_decode(): + # FIXME(lsyin): why shall we use spec_info for both draft and verify? 
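To make the decode-path data flow easier to follow, here is a condensed, hypothetical restatement of what this branch does (not the actual method body): spec_info is reused, arriving as the previous step's EagleDraftInput and being swapped to an EagleVerifyInput between the draft and verify phases.

```python
def one_overlap_decode_step(worker, batch):
    """Illustrative control flow of EAGLEWorkerV2's decode path."""
    draft_input = batch.spec_info        # produced by the previous verify
    verify_input = worker.draft(batch)   # draft forward + tree build
    batch.spec_info = verify_input
    # verify() runs target verify plus draft extend and returns a result
    # whose next_draft_input seeds spec_info for the next step.
    return worker.verify(batch, draft_input.allocate_lens)
```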
+ draft_input: EagleDraftInput = model_worker_batch.spec_info + assert draft_input.is_draft_input() + verify_input: EagleVerifyInput = self.draft(model_worker_batch) + assert verify_input.is_verify_input() + model_worker_batch.spec_info = verify_input + batch_output = self.verify(model_worker_batch, draft_input.allocate_lens) + return batch_output + else: + # Target prefill + model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL + batch_output = self.target_worker.forward_batch_generation( + model_worker_batch + ) + + # Draft prefill + model_worker_batch.capture_hidden_mode = CaptureHiddenMode.LAST + batch_output.next_draft_input = self.forward_draft_extend( + model_worker_batch, + batch_output.logits_output.hidden_states, + batch_output.next_token_ids, + ) + return batch_output + + def draft(self, model_worker_batch: ModelWorkerBatch): + draft_input: EagleDraftInput = model_worker_batch.spec_info + forward_batch, can_cuda_graph = draft_input.prepare_for_v2_draft( + self.req_to_token_pool, + model_worker_batch, + self.cuda_graph_runner, + self.draft_model_runner, + self.topk, + self.speculative_num_steps, + ) + + # Run draft + if can_cuda_graph: + parent_list, top_scores_index, draft_tokens = self.cuda_graph_runner.replay( + forward_batch, + ) + else: + self.draft_attn_backend.init_forward_metadata(forward_batch) + parent_list, top_scores_index, draft_tokens = self.draft_forward( + forward_batch + ) + + # Build tree mask + # Directly write to cuda graph buffers for verify attn + tree_mask_buf, position_buf = ( + self.target_worker.model_runner.attn_backend.get_verify_buffers_to_fill_after_draft() + ) + + ( + tree_mask, + position, + retrive_index, + retrive_next_token, + retrive_next_sibling, + draft_tokens, + ) = build_tree_kernel_efficient_tmp( + draft_input.verified_id, + parent_list, + top_scores_index, + draft_tokens, + model_worker_batch.seq_lens, + model_worker_batch.seq_lens_sum, + self.topk, + self.speculative_num_steps, + self.speculative_num_draft_tokens, + self.tree_mask_mode, + tree_mask_buf, + position_buf, + ) + + return EagleVerifyInput( + draft_token=draft_tokens, + custom_mask=tree_mask, + positions=position, + retrive_index=retrive_index, + retrive_next_token=retrive_next_token, + retrive_next_sibling=retrive_next_sibling, + retrive_cum_len=None, + spec_steps=self.speculative_num_steps, + topk=self.topk, + draft_token_num=self.speculative_num_draft_tokens, + capture_hidden_mode=None, + seq_lens_sum=None, + seq_lens_cpu=None, + ) + + def draft_forward(self, forward_batch: ForwardBatch): + # Parse args + spec_info: EagleDraftInput = forward_batch.spec_info + out_cache_loc = forward_batch.out_cache_loc + topk_p, topk_index, hidden_states = ( + spec_info.topk_p, + spec_info.topk_index, + spec_info.hidden_states, + ) + if self.hot_token_id is not None: + topk_index = self.hot_token_id[topk_index] + + out_cache_loc = out_cache_loc.reshape( + forward_batch.batch_size, self.topk, self.speculative_num_steps + ) + out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape( + self.speculative_num_steps, -1 + ) + + # Return values + score_list: List[torch.Tensor] = [] + token_list: List[torch.Tensor] = [] + parents_list: List[torch.Tensor] = [] + + # Forward multiple steps + scores = None + for i in range(self.speculative_num_steps): + input_ids, hidden_states, scores, tree_info = select_top_k_tokens_tmp( + i, topk_p, topk_index, hidden_states, scores, self.topk + ) + score_list.append(tree_info[0]) + token_list.append(tree_info[1]) + parents_list.append(tree_info[2]) + + # 
We don't need to run the last forward; we get 1 token from the draft prefill and (#spec steps - 1) tokens here.
+            if i == self.speculative_num_steps - 1:
+                break
+
+            # Set inputs
+            forward_batch.input_ids = input_ids
+            forward_batch.out_cache_loc = out_cache_loc[i]
+            forward_batch.positions.add_(1)
+            forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
+            spec_info.hidden_states = hidden_states
+
+            # Run forward
+            logits_output = self.draft_model_runner.model.forward(
+                forward_batch.input_ids, forward_batch.positions, forward_batch
+            )
+            self._detect_nan_if_needed(logits_output)
+            probs = torch.softmax(logits_output.next_token_logits, dim=-1)
+            topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
+            if self.hot_token_id is not None:
+                topk_index = self.hot_token_id[topk_index]
+            hidden_states = logits_output.hidden_states
+
+        # Organize the results
+        score_list = torch.cat(score_list, dim=1).flatten(
+            1
+        )  # b, n, topk; n = 1 + (num_steps - 1) * self.topk
+        ss_token_list = torch.cat(
+            token_list, dim=1
+        )  # b, (self.topk + (num_steps - 1) * self.topk)
+        top_scores = torch.topk(
+            score_list, self.speculative_num_draft_tokens - 1, dim=-1
+        )
+        top_scores_index = top_scores.indices
+        top_scores_index = torch.sort(top_scores_index).values
+        draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
+
+        if len(parents_list) > 1:
+            parent_list = torch.cat(parents_list[:-1], dim=1)
+        else:
+            batch_size = parents_list[0].shape[0]
+            parent_list = torch.empty(batch_size, 0, device=parents_list[0].device)
+
+        return parent_list, top_scores_index, draft_tokens
+
+    def verify(
+        self,
+        batch: ModelWorkerBatch,
+        pre_draft_allocate_lens: torch.Tensor,
+    ):
+        # Parse args
+        verify_input: EagleVerifyInput = batch.spec_info
+        seq_lens_backup = batch.seq_lens
+        bs = len(batch.seq_lens)
+
+        # Batch 1: Target verify
+        # Prepare for target verify in a separate stream
+        with self.plan_stream_ctx:
+            verify_forward_batch, can_run_cuda_graph = (
+                verify_input.prepare_for_v2_verify(
+                    self.req_to_token_pool,
+                    batch,
+                    self.target_worker,
+                )
+            )
+
+        # Correct some buffers due to the overlap plan
+        if self.plan_stream:
+            torch.cuda.current_stream().wait_stream(self.plan_stream)
+
+        # Some values such as custom_mask and position depend on the output of the
+        # draft, so the previous plan step used stale values. Here, we rerun the
+        # related computation to update them to the correct values.
+        self.target_worker.model_runner.attn_backend.update_verify_buffers_to_fill_after_draft(
+            verify_input,
+            (
+                self.target_worker.model_runner.graph_runner.bs
+                if can_run_cuda_graph
+                else None
+            ),
+        )
+
+        # Run the target verify batch in the main compute stream
+        forward_batch_output = self.target_worker.forward_batch_generation(
+            model_worker_batch=None,
+            forward_batch=verify_forward_batch,
+            is_verify=True,
+            skip_attn_backend_init=True,
+        )
+        logits_output = forward_batch_output.logits_output
+
+        # Sample
+        self._detect_nan_if_needed(logits_output)
+        (
+            predict,
+            accept_length,
+            accept_index,
+        ) = verify_input.sample(batch, logits_output)
+        new_seq_lens = seq_lens_backup + accept_length
+        verify_done = torch.cuda.Event()
+
+        # Move the accepted tokens to the target KV cache locations
+        batch.seq_lens = seq_lens_backup
+        self.move_accepted_tokens_to_target_kvcache(
+            batch,
+            accept_index,
+            accept_length,
+        )
+
+        verify_done.record()
+
+        all_verified_id = predict[accept_index]
+        verified_id = torch.empty_like(accept_length, dtype=torch.int32)
+        fill_new_verified_id[(bs,)](
+            all_verified_id,
+            accept_length,
+            verified_id,
+            self.speculative_num_draft_tokens,
+        )
+
+        # Batch 2: Draft extend
+        draft_input = EagleDraftInput(
+            hidden_states=logits_output.hidden_states,
+        )
+        select_index = (
+            torch.arange(len(batch.seq_lens), device=self.device)
+            * self.speculative_num_draft_tokens
+            + accept_length
+            - 1
+        )
+
+        # Prepare for draft extend in a separate stream
+        with self.plan_stream_ctx:
+            forward_batch = draft_input.prepare_for_extend_to_fill_draft_kvcache(
+                batch,
+                predict,
+                self.speculative_num_draft_tokens,
+                self.draft_model_runner,
+            )
+
+        if self.plan_stream:
+            torch.cuda.current_stream().wait_stream(self.plan_stream)
+
+        # Run the draft extend batch in the main compute stream
+        draft_logits_output = self.draft_model_runner.model.forward(
+            forward_batch.input_ids, forward_batch.positions, forward_batch
+        )
+
+        # Reorganize the spec info for the next batch
+        draft_logits_output.next_token_logits = draft_logits_output.next_token_logits[
+            select_index
+        ]
+        draft_logits_output.hidden_states = draft_logits_output.hidden_states[
+            select_index
+        ]
+        probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1)
+        ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
+        ret_hidden_states = draft_logits_output.hidden_states
+
+        # Since seq_lens_backup is allocated on another stream, we need
+        # record_stream() to prevent PyTorch from freeing and reusing its GPU
+        # memory while the forward stream is still running.
+        seq_lens_backup.record_stream(torch.cuda.current_stream())
+
+        # Construct the return values
+        next_draft_input = EagleDraftInput(
+            topk_p=ret_topk_p,
+            topk_index=ret_topk_index,
+            hidden_states=ret_hidden_states,
+            verified_id=verified_id,
+            new_seq_lens=new_seq_lens,
+            allocate_lens=pre_draft_allocate_lens,
+            verify_done=verify_done,
+        )
+
+        return GenerationBatchResult(
+            logits_output=logits_output,
+            next_token_ids=predict,
+            can_run_cuda_graph=can_run_cuda_graph,
+            next_draft_input=next_draft_input,
+            accept_lens=accept_length,
+            last_batch_allocate_lens=pre_draft_allocate_lens,
+        )
+
+    def forward_draft_extend(
+        self,
+        batch: ModelWorkerBatch,
+        target_hidden_states: torch.Tensor,
+        next_token_ids: torch.Tensor,
+    ):
+        """
+        Run a draft model extend pass to correctly fill the draft KV cache.
+
+        Args:
+            batch: The batch to run.
+            target_hidden_states: Hidden states from the target model's forward pass.
+            next_token_ids: Next token ids generated from the target forward pass.
+        """
+        # Construct input_ids
+        pt = 0
+        for i, extend_len in enumerate(batch.extend_seq_lens):
+            input_ids = batch.input_ids[pt : pt + extend_len]
+            batch.input_ids[pt : pt + extend_len] = torch.cat(
+                (input_ids[1:], next_token_ids[i].reshape(1))
+            )
+            pt += extend_len
+
+        # Construct spec_info
+        next_draft_input = EagleDraftInput(
+            hidden_states=target_hidden_states,
+            verified_id=next_token_ids,
+            new_seq_lens=batch.seq_lens,
+            allocate_lens=batch.seq_lens,
+        )
+        batch.spec_info = next_draft_input
+
+        # Run forward
+        forward_batch = ForwardBatch.init_new(batch, self.draft_model_runner)
+        logits_output, _ = self.draft_model_runner.forward(forward_batch)
+
+        # Update spec_info for the next draft step
+        probs = torch.softmax(logits_output.next_token_logits, dim=-1)
+        next_draft_input.topk_p, next_draft_input.topk_index = fast_topk(
+            probs, self.topk, dim=-1
+        )
+        next_draft_input.hidden_states = logits_output.hidden_states
+        return next_draft_input
+
+    def move_accepted_tokens_to_target_kvcache(
+        self,
+        batch: ModelWorkerBatch,
+        accept_index: torch.Tensor,
+        accept_length: torch.Tensor,
+    ):
+        """
+        Move accepted tokens to the target KV cache.
+
+        Args:
+            batch: The batch to run.
+            accept_index: The indices of the accepted tokens.
+            accept_length: The number of accepted tokens for each request.
+        """
+        bs = len(batch.seq_lens)
+        size = bs * self.speculative_num_draft_tokens
+
+        tgt_cache_loc = torch.zeros(
+            size,
+            dtype=torch.int64,
+            device=self.device,
+        )
+        accepted_out_cache_loc = torch.zeros(
+            size, dtype=torch.int64, device=self.device
+        )
+        assign_extend_cache_locs[(bs,)](
+            batch.req_pool_indices,
+            self.req_to_token_pool.req_to_token,
+            batch.seq_lens,
+            batch.seq_lens + accept_length,
+            tgt_cache_loc,
+            self.req_to_token_pool.req_to_token.shape[1],
+            next_power_of_2(bs),
+        )
+        fill_accepted_out_cache_loc[(size,)](
+            accept_index,
+            batch.out_cache_loc,
+            accepted_out_cache_loc,
+            next_power_of_2(size),
+        )
+        self.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
+            tgt_cache_loc, accepted_out_cache_loc
+        )
+
+    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
+        if self.enable_nan_detection:
+            logits = logits_output.next_token_logits
+            if torch.any(torch.isnan(logits)):
+                logger.error("Detected errors during sampling! NaN in the logits.")
+                raise ValueError("Detected errors during sampling! NaN in the logits.")
+
+
+def free_spec_dec_tokens_page_size_1(
+    req_to_token_pool: ReqToTokenPool,
+    token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+    req: Req,
+    allocate_len: int,
+    new_seq_len: int,
+):
+    # FIXME(lsyin): move this function elsewhere
+
+    # Free the extra allocated tokens
+    if new_seq_len is None:
+        # Only true for overlap EAGLE when the current batch is a decode batch. This
+        # seq will stay in the decode batch, so the final iteration's allocation is
+        # not used (i.e., this case).
+        start_len = allocate_len - EagleDraftInput.ALLOC_LEN_PER_DECODE
+    else:
+        # True for 1) non-overlap mode; 2) overlap EAGLE when the current batch is a
+        # prefill batch. This seq will not run an extra iteration, so the new
+        # sequence length is passed in.
+ start_len = new_seq_len + indices_to_free = req_to_token_pool.req_to_token[req.req_pool_idx][ + start_len:allocate_len + ] + token_to_kv_pool_allocator.free(indices_to_free) diff --git a/python/sglang/srt/speculative/spec_utils.py b/python/sglang/srt/speculative/spec_utils.py index 8478ac14c..4c3c8a070 100644 --- a/python/sglang/srt/speculative/spec_utils.py +++ b/python/sglang/srt/speculative/spec_utils.py @@ -435,7 +435,7 @@ def select_top_k_tokens( return input_ids, hidden_states, scores, tree_info -def _generate_simulated_accept_index( +def generate_simulated_accept_index( accept_index, predict, accept_length, diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index b62ad4136..61e45440b 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ b/python/sglang/srt/two_batch_overlap.py @@ -4,7 +4,7 @@ import copy import dataclasses import logging from dataclasses import replace -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence import torch @@ -30,12 +30,12 @@ from sglang.srt.model_executor.forward_batch_info import ( ) from sglang.srt.operations import execute_operations, execute_overlapped_operations from sglang.srt.operations_strategy import OperationsStrategy -from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.spec_info import SpecInput from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip if TYPE_CHECKING: from sglang.srt.layers.moe.token_dispatcher import DispatchOutput + from sglang.srt.speculative.eagle_info import EagleVerifyInput _is_hip = is_hip() diff --git a/scripts/sort_testcases_alphabetically.py b/scripts/sort_testcases_alphabetically.py index efe3020d0..1d13a0cb4 100644 --- a/scripts/sort_testcases_alphabetically.py +++ b/scripts/sort_testcases_alphabetically.py @@ -67,6 +67,7 @@ suites = { TestFile("test_deterministic.py", 300), TestFile("test_eagle_infer_a.py", 370), TestFile("test_eagle_infer_b.py", 700), + TestFile("test_eagle_infer_beta.py", 300), TestFile("test_ebnf_constrained.py", 108), TestFile("test_eval_fp8_accuracy.py", 303), TestFile("test_fa3.py", 376), diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 995a8dc98..fddc5543a 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -69,6 +69,7 @@ suites = { TestFile("test_deterministic.py", 300), TestFile("test_eagle_infer_a.py", 370), TestFile("test_eagle_infer_b.py", 700), + TestFile("test_eagle_infer_beta.py", 300), TestFile("test_ebnf_constrained.py", 108), TestFile("test_eval_fp8_accuracy.py", 303), TestFile("test_fa3.py", 376), diff --git a/test/srt/test_eagle_infer_beta.py b/test/srt/test_eagle_infer_beta.py new file mode 100644 index 000000000..fe7f18010 --- /dev/null +++ b/test/srt/test_eagle_infer_beta.py @@ -0,0 +1,125 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class TestEagleBS1(CustomTestCase): + num_questions = 60 + + @classmethod + def setUpClass(cls): + cls.model = "meta-llama/Llama-2-7b-chat-hf" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + 
"--attention-backend", + "triton", + "--enable-beta-spec", + "--speculative-algorithm", + "EAGLE", + "--speculative-draft-model", + "lmzheng/sglang-EAGLE-llama2-chat-7B", + "--speculative-num-steps", + "5", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "6", + "--max-running-requests", + "1", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=self.num_questions, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"TestEagleBS1 -- {metrics=}") + self.assertGreater( + metrics["accuracy"], 0.33 + ) # 0.3333 for 60 questions; 0.234 for 1319 questions + + +class TestEagleLargeBS(CustomTestCase): + num_questions = 10000 + max_running_requests = 64 + other_args = [ + "--trust-remote-code", + "--attention-backend", + "triton", + "--enable-beta-spec", + "--speculative-algorithm", + "EAGLE", + "--speculative-draft-model", + "lmzheng/sglang-EAGLE-llama2-chat-7B", + "--speculative-num-steps", + "5", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "6", + "--mem-fraction-static", + "0.75", + "--max-running-requests", + str(max_running_requests), + "--cuda-graph-bs", + *[str(i) for i in range(1, max_running_requests + 1)], + ] + + @classmethod + def setUpClass(cls): + cls.model = "meta-llama/Llama-2-7b-chat-hf" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=cls.other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=self.num_questions, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"TestEagleLargeBS -- {metrics=}") + self.assertGreater( + metrics["accuracy"], 0.23 + ) # 0.3333 for 60 questions; 0.234 for 1319 questions + + +if __name__ == "__main__": + unittest.main()