diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index de846066e..dbe52399f 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -230,7 +230,7 @@ def extend(reqs, model_runner):
     batch = ScheduleBatch.init_new(
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
-        token_to_kv_pool=model_runner.token_to_kv_pool,
+        token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
         tree_cache=None,
         model_config=model_runner.model_config,
         enable_overlap=False,
@@ -326,7 +326,7 @@ def latency_test_run_once(

     # Clear the pools.
     model_runner.req_to_token_pool.clear()
-    model_runner.token_to_kv_pool.clear()
+    model_runner.token_to_kv_pool_allocator.clear()

     measurement_results = {
         "run_name": run_name,
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index cf33ee257..b48fecb26 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -20,14 +20,15 @@ import triton.language as tl

 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.utils import is_flashinfer_available

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo

 if is_flashinfer_available():
     from flashinfer import (
@@ -36,6 +37,7 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
+    from flashinfer.decode import PosEncodingMode


 class WrapperDispatch(Enum):
@@ -113,6 +115,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 device=model_runner.device,
             )
             self.workspace_buffer = global_workspace_buffer
+
         max_bs = model_runner.req_to_token_pool.size
         if kv_indptr_buf is None:
             self.kv_indptr = [
@@ -133,10 +136,13 @@ class FlashInferAttnBackend(AttentionBackend):
             assert self.num_wrappers == 1
             self.kv_last_page_len = kv_last_page_len_buf

-        self.qo_indptr = [
-            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
-            for _ in range(self.num_wrappers)
-        ]
+        if not self.skip_prefill:
+            self.qo_indptr = [
+                torch.zeros(
+                    (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                )
+                for _ in range(self.num_wrappers)
+            ]

         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
             self.workspace_buffer, "NHD"
@@ -276,7 +282,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
@@ -346,7 +352,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
@@ -526,7 +532,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -538,7 +544,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -558,7 +564,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -592,7 +598,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -623,7 +629,7 @@ class FlashInferIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if spec_info is None:
             bs = len(req_pool_indices)
@@ -642,9 +648,9 @@ class FlashInferIndicesUpdaterDecode:
                 self.req_to_token.shape[1],
             )
         else:
+            assert isinstance(spec_info, EagleDraftInput)
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1
-
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
@@ -699,7 +705,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -713,7 +719,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
@@ -746,7 +752,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -787,7 +793,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -829,10 +835,11 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
-        bs = len(req_pool_indices)
+        bs = len(seq_lens)
         if spec_info is None:
+            assert len(seq_lens) == len(req_pool_indices)
             # Normal extend
             kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
@@ -855,10 +862,14 @@ class FlashInferIndicesUpdaterPrefill:
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
+            assert isinstance(spec_info, EagleDraftInput) or isinstance(
+                spec_info, EagleVerifyInput
+            )
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
                     paged_kernel_lens,
+                    paged_kernel_lens_sum,
                     self.req_to_token,
                 )
             )
@@ -890,6 +901,11 @@ class FlashInferIndicesUpdaterPrefill:
         )


+# Use as a fast path to override the indptr in flashinfer's plan function
+# This is used to remove some host-to-device copy overhead.
+global global_override_indptr_cpu
+
+
 class FlashInferMultiStepDraftBackend:
     """
     Wrap multiple flashinfer attention backends as one for multiple consecutive
@@ -907,6 +923,7 @@ class FlashInferMultiStepDraftBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+
         max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (
@@ -929,7 +946,9 @@ class FlashInferMultiStepDraftBackend:
                     kv_last_page_len_buf=self.kv_last_page_len,
                 )
             )
+
         self.max_context_len = self.attn_backends[0].max_context_len
+
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]

@@ -959,13 +978,23 @@ class FlashInferMultiStepDraftBackend:
             triton.next_power_of_2(bs),
         )

+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
+        indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
+        global global_override_indptr_cpu
+
         for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
             ]
+            global_override_indptr_cpu = indptr_cpu_whole[i]
             call_fn(i, forward_batch)
+        global_override_indptr_cpu = None
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         kv_indices = torch.zeros(
             (
@@ -977,6 +1006,8 @@ class FlashInferMultiStepDraftBackend:
         )

         def call_fn(i, forward_batch):
+            assert forward_batch.spec_info is not None
+            assert isinstance(forward_batch.spec_info, EagleDraftInput)
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
             )
@@ -993,6 +1024,7 @@ class FlashInferMultiStepDraftBackend:
             dtype=torch.int32,
             device="cuda",
         )
+
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
@@ -1031,43 +1063,6 @@ class FlashInferMultiStepDraftBackend:
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)


-@triton.jit
-def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,  # [max_batch, max_context_len]
-    req_pool_indices_ptr,
-    page_kernel_lens_ptr,
-    kv_indptr,
-    kv_start_idx,
-    kv_indices_ptr,
-    req_to_token_ptr_stride: tl.constexpr,
-):
-    BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(axis=0)
-
-    req_pool_index = tl.load(req_pool_indices_ptr + pid)
-    kv_indices_offset = tl.load(kv_indptr + pid)
-
-    kv_start = 0
-    kv_end = 0
-    if kv_start_idx:
-        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
-        kv_end = kv_start
-    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
-
-    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for i in range(num_loop):
-        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = offset < kv_end - kv_start
-        data = tl.load(
-            req_to_token_ptr
-            + req_pool_index * req_to_token_ptr_stride
-            + kv_start
-            + offset,
-            mask=mask,
-        )
-        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
-
-
 def should_use_tensor_core(
     kv_cache_dtype: torch.dtype,
     num_attention_heads: int,
@@ -1089,6 +1084,21 @@ def should_use_tensor_core(
     if env_override is not None:
         return env_override.lower() == "true"

+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
+
     # Calculate GQA group size
     gqa_group_size = num_attention_heads // num_kv_heads

@@ -1118,12 +1128,18 @@ def fast_decode_plan(
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
-    **kwargs,
+    non_blocking: bool = True,
 ) -> None:
-    """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
+    """
+    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
+    Modifications:
+    - Remove unnecessary device-to-device copy for the cuda graph buffers.
+    - Remove unnecessary host-to-device copy for the metadata buffers.
+ """ batch_size = len(last_page_len) if logits_soft_cap is None: logits_soft_cap = 0.0 + if self.is_cuda_graph_enabled: if batch_size != self._fixed_batch_size: raise ValueError( @@ -1136,13 +1152,19 @@ def fast_decode_plan( raise ValueError( "The size of indices should be less than or equal to the allocated buffer" ) + # Skip these copies + # self._paged_kv_indptr_buf.copy_(indptr) + # self._paged_kv_indices_buf[: len(indices)] = indices + # self._paged_kv_last_page_len_buf.copy_(last_page_len) else: self._paged_kv_indptr_buf = indptr self._paged_kv_indices_buf = indices self._paged_kv_last_page_len_buf = last_page_len + # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info if not q_data_type: q_data_type = data_type + if not hasattr(self, "empty_q_data"): self.empty_q_data = torch.empty( 0, @@ -1159,6 +1181,7 @@ def fast_decode_plan( ), ) self.last_page_len = torch.ones(32768, dtype=torch.int32) + empty_q_data = self.empty_q_data empty_kv_cache = self.empty_kv_cache stream = torch.cuda.current_stream() diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 23387b5a1..a8cf3abb8 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -156,6 +156,7 @@ class TritonAttnBackend(AttentionBackend): spec_info.generate_attn_arg_prefill( forward_batch.req_pool_indices, forward_batch.seq_lens, + None, self.req_to_token, ) ) diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 1a8ad6a6e..993c9e5c2 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -22,7 +22,7 @@ from typing import List, Optional import torch -from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost +from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost logger = logging.getLogger(__name__) @@ -128,7 +128,7 @@ class HiCacheController: def __init__( self, mem_pool_device: MHATokenToKVPool, - mem_pool_host: MLATokenToKVPoolHost, + mem_pool_host: MHATokenToKVPoolHost, write_policy: str = "write_through_selective", ): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 23d1454ae..3f09915ba 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -44,18 +44,16 @@ from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache -from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool +from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs if TYPE_CHECKING: - from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.spec_info import SpeculativeAlgorithm - INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 # Put some global args for easy access @@ -523,7 +521,7 @@ class ScheduleBatch: # Request, memory 
     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool = None
-    token_to_kv_pool: BaseTokenToKVPool = None
+    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
     tree_cache: BasePrefixCache = None

     # Batch configs
@@ -596,7 +594,7 @@ class ScheduleBatch:
         cls,
         reqs: List[Req],
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,
@@ -606,7 +604,7 @@ class ScheduleBatch:
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
-            token_to_kv_pool=token_to_kv_pool,
+            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
             tree_cache=tree_cache,
             model_config=model_config,
             enable_overlap=enable_overlap,
@@ -637,19 +635,19 @@ class ScheduleBatch:
         return req_pool_indices

     def alloc_token_slots(self, num_tokens: int):
-        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)

         if out_cache_loc is None:
             if self.tree_cache is not None:
-                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
-                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+                self.tree_cache.evict(num_tokens, self.token_to_kv_pool_allocator.free)
+                out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)

             if out_cache_loc is None:
                 phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                 logger.error(
                     f"{phase_str} out of memory. Try to lower your batch size.\n"
                     f"Try to allocate {num_tokens} tokens.\n"
-                    f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
+                    f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                 )
                 if self.tree_cache is not None:
                     self.tree_cache.pretty_print()
@@ -917,12 +915,12 @@ class ScheduleBatch:
     def check_decode_mem(self, buf_multiplier=1):
         bs = len(self.reqs) * buf_multiplier
-        if self.token_to_kv_pool.available_size() >= bs:
+        if self.token_to_kv_pool_allocator.available_size() >= bs:
             return True

-        self.tree_cache.evict(bs, self.token_to_kv_pool.free)
+        self.tree_cache.evict(bs, self.token_to_kv_pool_allocator.free)

-        if self.token_to_kv_pool.available_size() >= bs:
+        if self.token_to_kv_pool_allocator.available_size() >= bs:
             return True

         return False
@@ -945,6 +943,10 @@ class ScheduleBatch:
             reverse=True,
         )

+        retracted_reqs = []
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        first_iter = True
+
        def get_required_tokens(num_reqs: int):
             headroom_for_spec_decode = 0
             if server_args.speculative_algorithm:
@@ -958,18 +960,15 @@ class ScheduleBatch:
                 num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
             )

-        retracted_reqs = []
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
-        first_iter = True
         while (
-            self.token_to_kv_pool.available_size()
+            self.token_to_kv_pool_allocator.available_size()
             < get_required_tokens(len(sorted_indices))
             or first_iter
         ):
             if len(sorted_indices) == 1:
                 # Corner case: only one request left
                 assert (
-                    self.token_to_kv_pool.available_size() > 0
+                    self.token_to_kv_pool_allocator.available_size() > 0
                 ), "No space left for only one request"
                 break

@@ -983,7 +982,7 @@ class ScheduleBatch:
                 token_indices = self.req_to_token_pool.req_to_token[
                     req.req_pool_idx, : seq_lens_cpu[idx]
                 ]
-                self.token_to_kv_pool.free(token_indices)
+                self.token_to_kv_pool_allocator.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)
                 del self.tree_cache.entries[req.rid]
             else:
@@ -992,7 +991,7 @@ class ScheduleBatch:
                 token_indices = self.req_to_token_pool.req_to_token[
                     req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                 ]
-                self.token_to_kv_pool.free(token_indices)
+                self.token_to_kv_pool_allocator.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)

                 # release the last node
@@ -1001,10 +1000,13 @@ class ScheduleBatch:
                 # NOTE(lsyin): we should use the newly evictable memory instantly.
                 residual_size = (
                     len(sorted_indices) * global_config.retract_decode_steps
-                    - self.token_to_kv_pool.available_size()
+                    - self.token_to_kv_pool_allocator.available_size()
                 )
                 residual_size = max(0, residual_size)
-                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
+                self.tree_cache.evict(
+                    residual_size, self.token_to_kv_pool_allocator.free
+                )
+
             req.reset_for_retract()

         self.filter_batch(keep_indices=sorted_indices)
@@ -1183,7 +1185,7 @@ class ScheduleBatch:
         if self.spec_info:
             self.spec_info.merge_batch(other.spec_info)

-    def get_model_worker_batch(self):
+    def get_model_worker_batch(self) -> ModelWorkerBatch:
         if self.forward_mode.is_decode_or_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
@@ -1273,7 +1275,7 @@ class ModelWorkerBatch:
     req_pool_indices: torch.Tensor
     # The sequence length
     seq_lens: torch.Tensor
-    # The indices of output tokens in the token_to_kv_pool
+    # The indices of output tokens in the token_to_kv_pool_allocator
     out_cache_loc: torch.Tensor

     # The sum of all sequence lengths
diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py
index 916692446..de43c98f9 100644
--- a/python/sglang/srt/managers/schedule_policy.py
+++ b/python/sglang/srt/managers/schedule_policy.py
@@ -22,9 +22,13 @@ from typing import Dict, List, Optional, Set, Union

 import torch

-from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.managers.schedule_batch import (
+    Req,
+    ScheduleBatch,
+    global_server_args_dict,
+)
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool
+from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

 # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
@@ -75,7 +79,7 @@ class SchedulePolicy:
         # It is used to find the matching prefix for in-batch prefix caching.
         self.waiting_queue_radix_tree = RadixCache(
-            req_to_token_pool=None, token_to_kv_pool=None, disable=False
+            req_to_token_pool=None, token_to_kv_pool_allocator=None, disable=False
         )

     def calc_priority(self, waiting_queue: List[Req]) -> bool:
@@ -251,7 +255,7 @@ class PrefillAdder:
     def __init__(
         self,
         tree_cache: BasePrefixCache,
-        token_to_kv_pool: BaseTokenToKVPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         running_batch: ScheduleBatch,
         new_token_ratio: float,
         rem_input_tokens: int,
@@ -259,7 +263,7 @@ class PrefillAdder:
         mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
-        self.token_to_kv_pool = token_to_kv_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
@@ -291,7 +295,7 @@ class PrefillAdder:
     @property
     def rem_total_tokens(self):
         return (
-            self.token_to_kv_pool.available_size()
+            self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
             - self.rem_total_token_offset
         )
@@ -299,7 +303,7 @@ class PrefillAdder:
     @property
     def cur_rem_tokens(self):
         return (
-            self.token_to_kv_pool.available_size()
+            self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
             - self.cur_rem_token_offset
         )
@@ -332,7 +336,6 @@ class PrefillAdder:
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
         self.can_run_list.append(req)
-
         self._prefill_one_req(
             0,
             req.extend_input_len,
@@ -400,8 +403,8 @@ class PrefillAdder:
             tokens_freed += tokens_occupied

         if (
-            self.rem_chunk_tokens is None
-            or req.extend_input_len <= self.rem_chunk_tokens
+            self.rem_chunk_tokens is None  # chunked prefill is disabled
+            or req.extend_input_len <= self.rem_chunk_tokens  # it is the last chunk
         ):
             # Non-chunked prefill
             self.can_run_list.append(req)
@@ -411,10 +414,11 @@ class PrefillAdder:
                 min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
             )
         else:
+            if self.rem_chunk_tokens == 0:
+                return AddReqResult.OTHER
+
             # Chunked prefill
             trunc_len = self.rem_chunk_tokens
-            if trunc_len == 0:
-                return AddReqResult.OTHER

             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[:trunc_len]
@@ -457,10 +461,11 @@ class PrefillAdder:
                 ),
             )
         else:
+            if self.rem_chunk_tokens == 0:
+                return AddReqResult.OTHER
+
             # Chunked prefill
             trunc_len = self.rem_chunk_tokens
-            if trunc_len == 0:
-                return AddReqResult.OTHER

             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index c47405c43..1ee04d3a7 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -164,7 +164,7 @@ class Scheduler:
                     self.server_args.speculative_num_draft_tokens
                     + (
                         self.server_args.speculative_eagle_topk
-                        * self.server_args.speculative_num_steps
+                        * self.server_args.speculative_num_draft_tokens
                     )
                 )
                 if not self.spec_algorithm.is_none()
@@ -309,7 +309,9 @@ class Scheduler:
         )

         # Init memory pool and cache
-        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            self.tp_worker.get_memory_pool()
+        )

         if (
             server_args.chunked_prefill_size is not None
@@ -317,18 +319,18 @@ class Scheduler:
         ):
             self.tree_cache = ChunkCache(
                 req_to_token_pool=self.req_to_token_pool,
-                token_to_kv_pool=self.token_to_kv_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
             )
         else:
             if self.enable_hierarchical_cache:
                 self.tree_cache = HiRadixCache(
                     req_to_token_pool=self.req_to_token_pool,
-                    token_to_kv_pool=self.token_to_kv_pool,
+                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 )
             else:
                 self.tree_cache = RadixCache(
                     req_to_token_pool=self.req_to_token_pool,
-                    token_to_kv_pool=self.token_to_kv_pool,
+                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                     disable=server_args.disable_radix_cache,
                 )

@@ -458,7 +460,6 @@ class Scheduler:
                 (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
-                (SetInternalStateReq, self.set_internal_state),
             ]
         )

@@ -809,7 +810,8 @@ class Scheduler:
         running_bs: int,
     ):
         num_used = self.max_total_num_tokens - (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+            self.token_to_kv_pool_allocator.available_size()
+            + self.tree_cache.evictable_size()
         )
         self._largest_prefill_len = max(
             self._largest_prefill_len, adder.log_input_tokens
@@ -844,7 +846,8 @@ class Scheduler:
         self.num_generated_tokens = 0
         num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
         num_used = self.max_total_num_tokens - (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+            self.token_to_kv_pool_allocator.available_size()
+            + self.tree_cache.evictable_size()
         )

         if RECORD_STEP_TIME:
@@ -894,7 +897,8 @@ class Scheduler:
     def check_memory(self):
         available_size = (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+            self.token_to_kv_pool_allocator.available_size()
+            + self.tree_cache.evictable_size()
         )
         protected_size = self.tree_cache.protected_size()
         memory_leak = available_size != (
@@ -999,7 +1003,7 @@ class Scheduler:
         # Prefill policy
         adder = PrefillAdder(
             self.tree_cache,
-            self.token_to_kv_pool,
+            self.token_to_kv_pool_allocator,
             self.running_batch,
             self.new_token_ratio,
             self.max_prefill_tokens,
@@ -1099,7 +1103,7 @@ class Scheduler:
         new_batch = ScheduleBatch.init_new(
             can_run_list,
             self.req_to_token_pool,
-            self.token_to_kv_pool,
+            self.token_to_kv_pool_allocator,
             self.tree_cache,
             self.model_config,
             self.enable_overlap,
@@ -1143,8 +1147,6 @@ class Scheduler:
             retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
             self.new_token_ratio = new_token_ratio

-            if self.draft_worker:
-                self.draft_worker.finish_request(retracted_reqs)
             logger.info(
" @@ -1184,11 +1186,12 @@ class Scheduler: logits_output, next_token_ids = self.tp_worker.forward_batch_generation( model_worker_batch ) + bid = model_worker_batch.bid else: ( logits_output, next_token_ids, - model_worker_batch, + bid, num_accepted_tokens, ) = self.draft_worker.forward_batch_speculative_generation(batch) self.spec_num_total_accepted_tokens += ( @@ -1214,7 +1217,7 @@ class Scheduler: next_token_ids=next_token_ids, extend_input_len_per_req=extend_input_len_per_req, extend_logprob_start_len_per_req=extend_logprob_start_len_per_req, - bid=model_worker_batch.bid, + bid=bid, ) else: # embedding or reward model model_worker_batch = batch.get_model_worker_batch() @@ -1230,6 +1233,7 @@ class Scheduler: result: Union[GenerationBatchResult, EmbeddingBatchResult], ): if batch.forward_mode.is_decode(): + assert isinstance(result, GenerationBatchResult) self.process_batch_result_decode(batch, result) if batch.is_empty(): self.running_batch = None @@ -1302,7 +1306,7 @@ class Scheduler: if self.is_mixed_chunk and self.enable_overlap and req.finished(): # Free the one delayed token for the mixed decode batch j = len(batch.out_cache_loc) - len(batch.reqs) + i - self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1]) + self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1]) continue if req.is_chunked <= 0: @@ -1420,23 +1424,27 @@ class Scheduler: self.num_generated_tokens += len(batch.reqs) if self.enable_overlap: + assert batch.spec_algorithm.is_none() logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid) next_token_logprobs = logits_output.next_token_logprobs - else: + elif batch.spec_algorithm.is_none(): + # spec decoding handles output logprobs inside verify process. next_token_ids = next_token_ids.tolist() if batch.return_logprob: next_token_logprobs = logits_output.next_token_logprobs.tolist() - self.token_to_kv_pool.free_group_begin() + self.token_to_kv_pool_allocator.free_group_begin() # Check finish condition + # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding. + # We should ignore using next_token_ids for spec decoding cases. 
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             if req.is_retracted:
                 continue

             if self.enable_overlap and req.finished():
                 # Free the one delayed token
-                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
+                self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
                 continue

             if batch.spec_algorithm.is_none():
@@ -1479,7 +1487,7 @@ class Scheduler:
             batch.next_batch_sampling_info.sampling_info_done.set()

         self.stream_output(batch.reqs, batch.return_logprob)

-        self.token_to_kv_pool.free_group_end()
+        self.token_to_kv_pool_allocator.free_group_end()

         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
         if (
@@ -1718,9 +1726,6 @@ class Scheduler:
                     and not self.model_config.is_multimodal_gen
                 )
             ):
-                if self.draft_worker and req.finished():
-                    self.draft_worker.finish_request(req)
-
                 rids.append(req.rid)
                 finished_reasons.append(
                     req.finished_reason.to_json() if req.finished_reason else None
@@ -1860,7 +1865,7 @@ class Scheduler:
         idle_batch = ScheduleBatch.init_new(
             [],
             self.req_to_token_pool,
-            self.token_to_kv_pool,
+            self.token_to_kv_pool_allocator,
             self.tree_cache,
             self.model_config,
             self.enable_overlap,
@@ -1916,11 +1921,11 @@ class Scheduler:
         if self.grammar_backend:
             self.grammar_backend.reset()
         self.req_to_token_pool.clear()
-        self.token_to_kv_pool.clear()
+        self.token_to_kv_pool_allocator.clear()

         if not self.spec_algorithm.is_none():
             self.draft_worker.model_runner.req_to_token_pool.clear()
-            self.draft_worker.model_runner.token_to_kv_pool.clear()
+            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

         self.num_generated_tokens = 0
         self.forward_ct_decode = 0
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index e7e062a3e..486f1d24c 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -82,8 +82,6 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
     SessionParams,
-    SetInternalStateReq,
-    SetInternalStateReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
@@ -257,9 +255,6 @@ class TokenizerManager:
         self.get_internal_state_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
-        self.set_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )

         self._result_dispatcher = TypeBasedDispatcher(
             [
@@ -309,10 +304,6 @@ class TokenizerManager:
                     GetInternalStateReqOutput,
                     self.get_internal_state_communicator.handle_recv,
                 ),
-                (
-                    SetInternalStateReqOutput,
-                    self.set_internal_state_communicator.handle_recv,
-                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )
@@ -774,14 +765,6 @@ class TokenizerManager:
         )
         return res[0].internal_state

-    async def set_internal_state(
-        self, obj: SetInternalStateReq
-    ) -> SetInternalStateReqOutput:
-        res: List[SetInternalStateReqOutput] = (
-            await self.set_internal_state_communicator(obj)
-        )
-        return res[0]
-
     def get_log_request_metadata(self):
         max_length = None
         skip_names = None
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index ddb8a7c2e..1423f253f 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
@@ -49,6 +50,8 @@ class TpModelWorker:
         dp_rank: Optional[int],
         nccl_port: int,
         is_draft_worker: bool = False,
+        req_to_token_pool: Optional[ReqToTokenPool] = None,
+        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.tp_rank = tp_rank
@@ -77,6 +80,8 @@ class TpModelWorker:
             nccl_port=nccl_port,
             server_args=server_args,
             is_draft_worker=is_draft_worker,
+            req_to_token_pool=req_to_token_pool,
+            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
         )
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
@@ -154,7 +159,7 @@ class TpModelWorker:
     def get_memory_pool(self):
         return (
             self.model_runner.req_to_token_pool,
-            self.model_runner.token_to_kv_pool,
+            self.model_runner.token_to_kv_pool_allocator,
         )

     def forward_batch_generation(
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 74a2be5a2..aaaa28e22 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -100,7 +100,7 @@ class TpModelWorkerClient:
     def get_memory_pool(self):
         return (
             self.worker.model_runner.req_to_token_pool,
-            self.worker.model_runner.token_to_kv_pool,
+            self.worker.model_runner.token_to_kv_pool_allocator,
         )

     def forward_thread_func(self):
diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py
index 8f58e146c..a89fa93a1 100644
--- a/python/sglang/srt/mem_cache/chunk_cache.py
+++ b/python/sglang/srt/mem_cache/chunk_cache.py
@@ -1,13 +1,12 @@
 from __future__ import annotations

 """Cache for chunked prefill, used when RadixCache is disabled."""
-
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple

 import torch

 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -21,11 +20,13 @@ class ChunkCacheEntry:

 class ChunkCache(BasePrefixCache):
     def __init__(
-        self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
     ):
         self.disable = True
         self.req_to_token_pool = req_to_token_pool
-        self.token_to_kv_pool = token_to_kv_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.entries: Dict[str, ChunkCacheEntry] = {}

         self.reset()
@@ -51,7 +52,7 @@ class ChunkCache(BasePrefixCache):
             req.req_pool_idx, :token_id_len
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
-        self.token_to_kv_pool.free(kv_indices)
+        self.token_to_kv_pool_allocator.free(kv_indices)

         if req.rid in self.entries:
             del self.entries[req.rid]
@@ -91,3 +92,6 @@ class ChunkCache(BasePrefixCache):

     def protected_size(self):
         return 0
+
+    def pretty_print(self):
+        return ""
diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py
index 4a57eacd1..051f66f77 100644
--- a/python/sglang/srt/mem_cache/hiradix_cache.py
+++ b/python/sglang/srt/mem_cache/hiradix_cache.py
@@ -7,8 +7,8 @@ import torch

 from sglang.srt.managers.cache_controller import HiCacheController
 from sglang.srt.mem_cache.memory_pool import (
-    BaseTokenToKVPool,
-    MLATokenToKVPoolHost,
+    MHATokenToKVPool,
+    MHATokenToKVPoolHost,
     ReqToTokenPool,
 )
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
@@ -21,9 +21,9 @@ class HiRadixCache(RadixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool: BaseTokenToKVPool,
+        token_to_kv_pool: MHATokenToKVPool,
     ):
-        self.token_to_kv_pool_host = MLATokenToKVPoolHost(token_to_kv_pool)
+        self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
         self.cache_controller = HiCacheController(
             token_to_kv_pool, self.token_to_kv_pool_host
         )
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 97af2b386..0d9e6275d 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -20,9 +20,12 @@ Memory pool.

 SGLang has two levels of memory pool.
 ReqToTokenPool maps a a request to its token locations.
-BaseTokenToKVPool maps a token location to its KV cache data.
+TokenToKVPoolAllocator maps a token location to its KV cache data.
+KVCache actually holds the physical KV cache; the allocation indices are
+managed by TokenToKVPoolAllocator.
 """

+import abc
 import logging
 import threading
 from enum import IntEnum
@@ -89,7 +92,7 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))


-class BaseTokenToKVPool:
+class TokenToKVPoolAllocator:
     """A memory pool that maps a token location to its kv cache data."""

     def __init__(
@@ -100,11 +103,6 @@ class TokenToKVPoolAllocator:
     ):
         self.size = size
         self.dtype = dtype
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
         self.device = device

         self.free_slots = None
@@ -148,15 +146,22 @@ class TokenToKVPoolAllocator:
         self.is_in_free_group = False
         self.free_group = []

+
+class KVCache(abc.ABC):
+
+    @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()

+    @abc.abstractmethod
     def get_value_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()

+    @abc.abstractmethod
     def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError()

+    @abc.abstractmethod
     def set_kv_buffer(
         self,
         layer: RadixAttention,
@@ -167,7 +172,7 @@ class KVCache(abc.ABC):
         raise NotImplementedError()


-class MHATokenToKVPool(BaseTokenToKVPool):
+class MHATokenToKVPool(KVCache):

     def __init__(
         self,
@@ -179,8 +184,14 @@ class MHATokenToKVPool(KVCache):
         device: str,
         enable_memory_saver: bool,
     ):
-        super().__init__(size, dtype, device)
-
+        self.size = size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
@@ -297,7 +308,7 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_2[loc] = src_2.to(dtype).view(store_dtype)


-class MLATokenToKVPool(BaseTokenToKVPool):
+class MLATokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
@@ -308,8 +319,14 @@ class MLATokenToKVPool(KVCache):
         device: str,
         enable_memory_saver: bool,
     ):
-        super().__init__(size, dtype, device)
-
+        self.size = size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
         self.kv_lora_rank = kv_lora_rank

         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
@@ -356,7 +373,7 @@ class MLATokenToKVPool(KVCache):
             self.kv_buffer[layer_id][loc] = cache_k


-class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
+class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
@@ -368,8 +385,14 @@ class DoubleSparseTokenToKVPool(KVCache):
         heavy_channel_num: int,
         enable_memory_saver: bool,
     ):
-        super().__init__(size, dtype, device)
-
+        self.size = size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
@@ -437,12 +460,12 @@ def synchronized(func):
     return wrapper


-class MLATokenToKVPoolHost:
+class MHATokenToKVPoolHost:

     def __init__(
         self,
         device_pool: MHATokenToKVPool,
-        host_to_device_ratio: float = 4.0,
+        host_to_device_ratio: float = 2.0,
         pin_memory: bool = False,  # no need to use pin memory with the double buffering
         device: str = "cpu",
     ):
diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py
index 3bf87b542..c99a47516 100644
--- a/python/sglang/srt/mem_cache/radix_cache.py
+++ b/python/sglang/srt/mem_cache/radix_cache.py
@@ -26,8 +26,9 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

 import torch

+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -79,11 +80,11 @@ class RadixCache(BasePrefixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool: BaseTokenToKVPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         disable: bool = False,
     ):
         self.req_to_token_pool = req_to_token_pool
-        self.token_to_kv_pool = token_to_kv_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.disable = disable
         self.reset()
@@ -139,7 +140,7 @@ class RadixCache(BasePrefixCache):
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, :token_ids_len
         ]
-        self.token_to_kv_pool.free(kv_indices)
+        self.token_to_kv_pool_allocator.free(kv_indices)
         self.req_to_token_pool.free(req.req_pool_idx)
         return

@@ -151,7 +152,9 @@ class RadixCache(BasePrefixCache):

         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(token_ids, kv_indices.clone())
-        self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+        self.token_to_kv_pool_allocator.free(
+            kv_indices[len(req.prefix_indices) : new_prefix_len]
+        )

         # Remove req slot release the cache lock
         self.req_to_token_pool.free(req.req_pool_idx)
@@ -171,7 +174,9 @@ class RadixCache(BasePrefixCache):

         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(token_ids, kv_indices.clone())
-        self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+        self.token_to_kv_pool_allocator.free(
+            kv_indices[len(req.prefix_indices) : new_prefix_len]
+        )

         # The prefix indices could be updated, reuse it
         new_indices, new_last_node = self.match_prefix(token_ids)
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 6d0a416bf..1aaae45b1 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -55,6 +55,7 @@ from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
+    TokenToKVPoolAllocator,
 )
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -98,6 +99,8 @@ class ModelRunner:
         nccl_port: int,
         server_args: ServerArgs,
         is_draft_worker: bool = False,
+        req_to_token_pool: Optional[ReqToTokenPool] = None,
+        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.model_config = model_config
@@ -115,6 +118,8 @@ class ModelRunner:
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator

         # Model-specific adjustment
         if (
@@ -257,8 +262,8 @@ class ModelRunner:

     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
-
         torch.get_device_module(self.device).set_device(self.gpu_id)
+
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
@@ -660,12 +665,25 @@ class ModelRunner:
         if not self.spec_algorithm.is_none():
             if self.is_draft_worker:
                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+                max_num_reqs = self.server_args.max_num_reqs
             else:
+                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
+                # can be concurrently allocated, so we should give it some headroom.
                 self.server_args.draft_runner_cache_size = (
                     self.max_total_num_tokens
-                    + max_num_reqs * self.server_args.speculative_num_steps
+                    # draft
+                    + max_num_reqs
+                    * self.server_args.speculative_num_steps
+                    * self.server_args.speculative_eagle_topk
+                    # verify
+                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
+                    # buffer
                     + 100
                 )
+                # The target worker and the draft worker share the same indices for the
+                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
+                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+                self.server_args.max_num_reqs = max_num_reqs

         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:
@@ -681,12 +699,25 @@ class ModelRunner:
                 "Not enough memory. Please try to increase --mem-fraction-static."
             )

-        self.req_to_token_pool = ReqToTokenPool(
-            size=max_num_reqs + 1,
-            max_context_len=self.model_config.context_len + 4,
-            device=self.device,
-            enable_memory_saver=self.server_args.enable_memory_saver,
-        )
+        if self.req_to_token_pool is None:
+            self.req_to_token_pool = ReqToTokenPool(
+                size=max_num_reqs + 1,
+                max_context_len=self.model_config.context_len + 4,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+            )
+        else:
+            # The draft worker shares req_to_token_pool with the target worker.
+            assert self.is_draft_worker
+
+        if self.token_to_kv_pool_allocator is None:
+            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            )
+        else:
+            assert self.is_draft_worker

         if (
             self.model_config.attention_arch == AttentionArch.MLA
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index ac433b1eb..9b6a6dbdf 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -280,11 +280,16 @@ class ServerArgs:
             self.disable_overlap_schedule = True
             self.prefill_only_one_req = True
             self.disable_cuda_graph_padding = True
-            self.disable_radix_cache = True
-            self.chunked_prefill_size = -1
+            if self.max_running_requests is None:
+                self.max_running_requests = 32
             logger.info(
-                f"The radix cache, chunked prefill, and overlap scheduler are disabled because of using {self.speculative_algorithm} speculative decoding."
+                "The overlap scheduler is disabled because of using "
+                "eagle speculative decoding. "
+                "The max running requests is set to 32 because of using eagle speculative decoding."
             )
+            # The token generated from the verify step is counted.
+            # If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
+            assert self.speculative_num_steps < self.speculative_num_draft_tokens

         # GGUF
         if (
diff --git a/python/sglang/srt/speculative/build_eagle_tree.py b/python/sglang/srt/speculative/build_eagle_tree.py
index 87896cb6c..027838ab1 100644
--- a/python/sglang/srt/speculative/build_eagle_tree.py
+++ b/python/sglang/srt/speculative/build_eagle_tree.py
@@ -3,14 +3,8 @@ from typing import List

 import torch
-
-from sglang.srt.utils import is_cuda_available
-
-if is_cuda_available():
-    from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
-    from sgl_kernel import (
-        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
-    )
+from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
+from sgl_kernel import build_tree_kernel_efficient as sgl_build_tree_kernel_efficient


 def build_tree_kernel_efficient_preprocess(
diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
index 5a79a9809..c3ecb80a4 100644
--- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
@@ -21,7 +21,6 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.speculative.eagle_utils import EagleDraftInput

 if TYPE_CHECKING:
-    from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.speculative.eagle_worker import EAGLEWorker


diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py
index 7ea1ea9b8..17e688085 100644
--- a/python/sglang/srt/speculative/eagle_utils.py
+++ b/python/sglang/srt/speculative/eagle_utils.py
@@ -1,16 +1,17 @@
 from __future__ import annotations

-import dataclasses
-from typing import TYPE_CHECKING, List
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List

 import torch
 import torch.nn.functional as F
 import triton
 import triton.language as tl

-from sglang.srt.layers.attention.flashinfer_backend import (
-    create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import (
     build_tree_kernel,
@@ -25,7 +26,7 @@ if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch


-@dataclasses.dataclass
+@dataclass
 class EagleDraftInput:
     # The inputs for decode
     # shape: (b, topk)
@@ -46,57 +47,46 @@ class EagleDraftInput:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None

+    # indices of unfinished requests during extend-after-decode
+    # e.g. [0, 2, 3, 4] if the 2nd request (index 1) is finished
+    keep_indices: List[int] = None
+
     def prepare_for_extend(self, batch: ScheduleBatch):
-        req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
-        out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
-        batch.out_cache_loc = out_cache_loc
+        assert batch.input_ids.numel() == batch.out_cache_loc.shape[0]
+        # Prefill only generates 1 token.
+        assert len(self.verified_id) == len(batch.seq_lens)

         pt = 0
-        for i, req in enumerate(batch.reqs):
-            req.req_pool_idx = req_pool_indices[i]
-            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
-            assert seq_len - pre_len == req.extend_input_len
-
-            if pre_len > 0:
-                batch.req_to_token_pool.req_to_token[req.req_pool_idx][
-                    :pre_len
-                ] = req.prefix_indices
-
-            batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
-                out_cache_loc[pt : pt + req.extend_input_len]
+        for i, extend_len in enumerate(batch.extend_lens):
+            input_ids = batch.input_ids[pt : pt + extend_len]
+            batch.input_ids[pt : pt + extend_len] = torch.concat(
+                (input_ids[1:], self.verified_id[i].reshape(1))
             )
-            pt += req.extend_input_len
-
-        # TODO: support batching inputs
-        assert len(batch.extend_lens) == 1
-        batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))

     def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
-        batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
+        assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
         accept_length_cpu = batch.spec_info.accept_length_cpu
         batch.extend_lens = [x + 1 for x in accept_length_cpu]
+        batch.extend_num_tokens = sum(batch.extend_lens)
         batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
-        batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
         seq_lens_cpu = batch.seq_lens.tolist()
+        assert len(batch.req_pool_indices) == len(batch.reqs)

         pt = 0
         i = 0
-        for req in batch.reqs:
+        self.keep_indices = []
+        for idx, req in enumerate(batch.reqs):
             if req.finished():
                 continue
+            self.keep_indices.append(idx)

             # assert seq_len - pre_len == req.extend_input_len
             input_len = batch.extend_lens[i]
             seq_len = seq_lens_cpu[i]
-            batch.req_to_token_pool.req_to_token[req.req_pool_idx][
-                seq_len - input_len : seq_len
-            ] = batch.out_cache_loc[pt : pt + input_len]
             pt += input_len
             i += 1
-        assert pt == batch.out_cache_loc.shape[0]

-        self.positions = torch.empty_like(self.verified_id)
-        new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
+        self.positions = torch.empty_like(self.verified_id, dtype=torch.long)
+        new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
         self.accept_length.add_(1)

         create_extend_spec_info[(self.accept_length.numel(),)](
@@ -117,14 +107,22 @@ class EagleDraftInput:
         self,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
         req_to_token: torch.Tensor,
     ):
         bs = self.accept_length.numel()
+        keep_indices = torch.tensor(self.keep_indices, device=req_pool_indices.device)
+        req_pool_indices = req_pool_indices[keep_indices]
+        assert req_pool_indices.shape[0] == bs
+        assert req_pool_indices.shape[0] == paged_kernel_lens.shape[0]
+
         qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)

         cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+
+        # TODO: replace cum_kv_seq_len[-1] with paged_kernel_lens_sum to avoid the device sync.
         kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")

         create_flashinfer_kv_indices_triton[(bs,)](
@@ -162,7 +160,21 @@ class EagleDraftInput:
         self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])


-@dataclasses.dataclass
+@dataclass
+class EagleVerifyOutput:
+    # Draft input batch
+    draft_input: EagleDraftInput
+    # Logit outputs from target worker
+    logits_output: LogitsProcessorOutput
+    # Accepted token ids including the bonus token
+    verified_id: torch.Tensor
+    # Accepted token length per sequence in a batch, on CPU
+    accept_length_per_req_cpu: List[int]
+    # Accepted indices from logits_output.next_token_logits
+    accepeted_indices_cpu: List[int]
+
+
+@dataclass
 class EagleVerifyInput:
     draft_token: torch.Tensor
     custom_mask: torch.Tensor
@@ -267,6 +279,7 @@ class EagleVerifyInput:
         self,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
         req_to_token: torch.Tensor,
     ):
         batch_size = len(req_pool_indices)
@@ -285,7 +298,11 @@ class EagleVerifyInput:
             paged_kernel_lens = paged_kernel_lens + self.draft_token_num
         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)

-        kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum + self.draft_token_num * batch_size,
+            dtype=torch.int32,
+            device="cuda",
+        )

         create_flashinfer_kv_indices_triton[(batch_size,)](
             req_to_token,
@@ -298,7 +315,21 @@ class EagleVerifyInput:
         )
         return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask

-    def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
+    def verify(
+        self,
+        batch: ScheduleBatch,
+        logits_output: torch.Tensor,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+    ) -> torch.Tensor:
+        """WARNING: This API in-place modifies the states of logits_output
+
+        Verify and find accepted tokens based on logits output and batch
+        (which contains spec decoding information).
+
+        This API updates values inside logits_output based on the accepted
+        tokens. I.e., logits_output.next_token_logits only contains
+        accepted token logits.
+ """ draft_token = torch.cat( [self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")], dim=-1, @@ -367,7 +398,6 @@ class EagleVerifyInput: new_accept_index = [] unfinished_index = [] - finished_extend_len = {} # {rid:accept_length + 1} accept_index_cpu = accept_index.tolist() predict_cpu = predict.tolist() has_finished = False @@ -382,7 +412,6 @@ class EagleVerifyInput: id = predict_cpu[idx] # if not found_finished: req.output_ids.append(id) - finished_extend_len[req.rid] = j + 1 req.check_finished() if req.finished(): has_finished = True @@ -400,11 +429,10 @@ class EagleVerifyInput: accept_index = accept_index[accept_index != -1] accept_length_cpu = accept_length.tolist() verified_id = predict[accept_index] - evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) evict_mask[accept_index] = False mem_need_free_idx = batch.out_cache_loc[evict_mask] - batch.token_to_kv_pool.free(mem_need_free_idx) + token_to_kv_pool_allocator.free(mem_need_free_idx) assign_req_to_token_pool[(bs,)]( batch.req_pool_indices, batch.req_to_token_pool.req_to_token, @@ -427,20 +455,16 @@ class EagleVerifyInput: ] if has_finished: draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index] - draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[ - unfinished_index - ] else: draft_input.seq_lens_for_draft_extend = batch.seq_lens - draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices + batch.out_cache_loc = batch.out_cache_loc[new_accept_index] - logits_output.next_token_logits = logits_output.next_token_logits[accept_index] - return ( - draft_input, - logits_output, - verified_id, - finished_extend_len, - accept_length_cpu, + return EagleVerifyOutput( + draft_input=draft_input, + logits_output=logits_output, + verified_id=verified_id, + accept_length_per_req_cpu=accept_length_cpu, + accepeted_indices_cpu=accept_index, ) @@ -456,6 +480,18 @@ def eagle_verify_retrive( draft_token_num: tl.constexpr, max_len_upper: tl.constexpr, ): + """ + Args: + retrive_index: Pointer to indices of draft tokens + accept_mask: Mask indicating which tokens were accepted + retrive_cum_len: Cumulative lengths of token sequences in a batch + accept_index (out): Accept token indices + accept_length (out): Length of accepted tokens per sequence in a batch + extract_index (out): Index for last accepted tokens + max_len: Maximum length in a batch + draft_token_num: Number of tokens speculatively generated + max_len_upper An upper bound for token sequence length + """ pid = tl.program_id(axis=0) retrive_end = tl.load(retrive_cum_len + pid + 1) @@ -649,7 +685,7 @@ def generate_draft_decode_kv_indices( tl.store(kv_indptr + zid, base + zid * iters) -@torch.compile +@torch.compile(dynamic=True) def select_top_k_tokens( i: int, topk_p: torch.Tensor, @@ -671,13 +707,11 @@ def select_top_k_tokens( .unsqueeze(0) .repeat(topk_p.shape[0], 1), # shape: (b, topk + 1) ) - else: # The later decode steps expand_scores = torch.mul( scores.unsqueeze(2), topk_p.reshape(-1, topk, topk) ) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk) - topk_cs_p, topk_cs_index = fast_topk( expand_scores.flatten(start_dim=1), topk, dim=-1 ) # (b, topk) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index eac8cf891..4dce896c0 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -1,7 +1,7 @@ import logging import os import time -from typing import List, Optional, Union +from 
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index eac8cf891..4dce896c0 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -1,7 +1,7 @@
 import logging
 import os
 import time
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 from huggingface_hub import snapshot_download
@@ -22,11 +22,13 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
 from sglang.srt.speculative.eagle_utils import (
     EagleDraftInput,
     EagleVerifyInput,
+    EagleVerifyOutput,
     assign_draft_cache_locs,
     fast_topk,
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.utils import get_available_gpu_memory

 logger = logging.getLogger(__name__)

@@ -42,12 +44,16 @@ class EAGLEWorker(TpModelWorker):
         nccl_port: int,
         target_worker: TpModelWorker,
     ):
+        # Override the context length with the target model's context length.
+        server_args.context_length = target_worker.model_runner.model_config.context_len
+        os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
+
         # Do not capture cuda graph in `super().__init__()`
         # We will capture it later
         backup_disable_cuda_graph = server_args.disable_cuda_graph
         server_args.disable_cuda_graph = True

-        # Load hot token ids
+        # Lossy optimization: restrict the draft model's output vocabulary to hot tokens.
         if server_args.speculative_token_map is not None:
             self.hot_token_id = load_token_map(server_args.speculative_token_map)
             server_args.json_model_override_args = (
@@ -56,6 +62,12 @@ class EAGLEWorker(TpModelWorker):
         else:
             self.hot_token_id = None

+        # We share the allocator with the target worker; the draft and target
+        # workers each own their own KV cache.
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            target_worker.get_memory_pool()
+        )
+
         # Init target worker
         super().__init__(
             gpu_id=gpu_id,
@@ -64,9 +76,10 @@ class EAGLEWorker(TpModelWorker):
             nccl_port=nccl_port,
             dp_rank=dp_rank,
             is_draft_worker=True,
+            req_to_token_pool=self.req_to_token_pool,
+            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
         )
         self.target_worker = target_worker
-        self.finish_extend_len = []

         # Parse arguments
         self.topk = server_args.speculative_eagle_topk
@@ -75,6 +88,9 @@ class EAGLEWorker(TpModelWorker):
             server_args.speculative_algorithm
         )
         self.server_args = server_args
+        self.use_nan_detection = self.server_args.enable_nan_detection
+        self.device = self.model_runner.device
+        self.gpu_id = self.model_runner.gpu_id

         # Share the embedding and lm_head
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()
@@ -82,8 +98,10 @@ class EAGLEWorker(TpModelWorker):
             head = head.clone()
             self.hot_token_id = self.hot_token_id.to(head.device)
             head.data = head.data[self.hot_token_id]
-        self.model_runner.model.set_embed_and_head(embed, head)
-        self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
+        self.draft_model_runner.model.set_embed_and_head(embed, head)
+        self.draft_model_runner.server_args.disable_cuda_graph = (
+            backup_disable_cuda_graph
+        )

         # Create multi-step attn backends and cuda graph runners
         if server_args.attention_backend == "flashinfer":
@@ -111,7 +129,7 @@ class EAGLEWorker(TpModelWorker):
                 f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
             )

-        self.model_runner.draft_attn_backend = self.draft_attn_backend
+        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
         self.init_cuda_graphs()

     def init_cuda_graphs(self):
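The draft model never carries its own embedding or LM head here; it borrows the target model's, and with a hot-token map the head is first cloned and sliced down to the hot vocabulary. A hedged sketch of that sharing (`get_embed_and_head`/`set_embed_and_head` mirror the accessors used above; other models would need equivalent hooks):

```python
import torch

def share_embed_and_head(target_model, draft_model, hot_token_id=None):
    embed, head = target_model.get_embed_and_head()
    if hot_token_id is not None:
        # Clone before slicing so the target model's head is left untouched.
        head = head.clone()
        head.data = head.data[hot_token_id.to(head.device)]
    draft_model.set_embed_and_head(embed, head)
```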
@@ -122,55 +140,81 @@ class EAGLEWorker(TpModelWorker):
             return

         tic = time.time()
-        logger.info("Capture cuda graph begin. This can take up to several minutes.")
+        logger.info(
+            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
         self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
-        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
+        logger.info(
+            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )

-    def forward_batch_speculative_generation(self, batch: ScheduleBatch):
+    @property
+    def draft_model_runner(self):
+        return self.model_runner
+
+    def forward_batch_speculative_generation(
+        self, batch: ScheduleBatch
+    ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
+        """Run a speculative decoding forward pass.
+
+        NOTE: Many states of the batch are modified in place as this runs;
+        do not assume the batch comes out in the same state it went in.
+
+        Args:
+            batch: The batch to run forward. The state of the batch is modified as it runs.
+        Returns:
+            A tuple of the final logits output of the target model, the accepted
+            next tokens, the batch id (used for the overlap schedule), and the
+            number of accepted tokens.
+        """
+        assert not batch.spec_algorithm.is_none()
         if batch.forward_mode.is_decode():
-            # Draft
-            spec_info: EagleVerifyInput = self.draft(batch)
-
-            # Verify
-            (
-                next_draft_input,
-                logits_output,
-                verified_id,
-                self.finish_extend_len,
-                accept_length_cpu,
-                model_worker_batch,
-            ) = self.verify(batch, spec_info)
-            batch.spec_info = next_draft_input
-            # if it is None, means all requsets are finished
+            spec_info, to_free_cache_loc = self.draft(batch)
+            logits_output, verify_output, model_worker_batch = self.verify(
+                batch, spec_info
+            )
+            # Free the cache locations here (instead of earlier) to avoid a
+            # synchronization point and to hide the kernel launch overhead.
+            self.token_to_kv_pool_allocator.free(to_free_cache_loc)
+            # If it is None, it means all requests are finished.
             if batch.spec_info.verified_id is not None:
                 self.forward_draft_extend_after_decode(batch)
+
             return (
                 logits_output,
-                verified_id,
-                model_worker_batch,
-                sum(accept_length_cpu),
+                verify_output.verified_id,
+                model_worker_batch.bid,
+                sum(verify_output.accept_length_per_req_cpu),
             )
         else:
-            # Forward with the target model and get hidden states.
-            # We need the full hidden states to prefill the KV cache of the draft model.
-            model_worker_batch = batch.get_model_worker_batch()
-            model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-            logits_output, next_token_ids = self.target_worker.forward_batch_generation(
-                model_worker_batch
+            logits_output, next_token_ids, bid = self.forward_target_extend(batch)
+            self.forward_draft_extend(
+                batch, logits_output.hidden_states, next_token_ids
             )
+            return logits_output, next_token_ids, bid, 0

-            # Forward with the draft model.
-            batch.spec_info = EagleDraftInput(
-                hidden_states=logits_output.hidden_states,
-                verified_id=next_token_ids,
-            )
-            self.forward_draft_extend(batch)
-            return logits_output, next_token_ids, model_worker_batch, 0
+    def forward_target_extend(
+        self, batch: ScheduleBatch
+    ) -> Tuple[LogitsProcessorOutput, List[int], int]:
+        """Run the target extend.
+
+        Args:
+            batch: The batch to run. Its states could be modified.
+
+        Returns:
+            logits_output: The logits output from the target model. It contains
+                the full hidden states.
+            next_token_ids: The generated next token ids.
+            bid: The model batch ID, used for the overlap schedule.
+        """
+        # Forward with the target model and get hidden states.
+        # We need the full hidden states to prefill the KV cache of the draft model.
+        model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
+        logits_output, next_token_ids = self.target_worker.forward_batch_generation(
+            model_worker_batch
+        )
+        return logits_output, next_token_ids, model_worker_batch.bid
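Taken together, the decode path now has a fixed three-beat shape: draft a token tree, verify it with the target model, then free the rejected slots and refresh the draft state for the unfinished requests. A condensed sketch of the control flow (method names as in this patch; error handling elided):

```python
def speculative_decode_step(worker, batch):
    # 1) Draft: propose a token tree and remember which KV slots it used.
    spec_info, to_free_cache_loc = worker.draft(batch)
    # 2) Verify: the target model scores the tree; verify() frees the slots of
    #    rejected tokens, and the draft's own scratch slots are freed right after.
    logits_output, verify_output, model_worker_batch = worker.verify(batch, spec_info)
    worker.token_to_kv_pool_allocator.free(to_free_cache_loc)
    # 3) Draft-extend: recompute the draft inputs for the surviving requests.
    if batch.spec_info.verified_id is not None:
        worker.forward_draft_extend_after_decode(batch)
    return logits_output, verify_output, model_worker_batch
```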
     def draft(self, batch: ScheduleBatch):
-        self._set_mem_pool(batch, self.model_runner)
-
         # Parse args
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info
@@ -188,7 +232,6 @@ class EAGLEWorker(TpModelWorker):
             self.topk,
             self.speculative_num_steps,
         )
-
         batch.out_cache_loc = out_cache_loc
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
         spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
@@ -196,11 +239,12 @@ class EAGLEWorker(TpModelWorker):
         # Get forward batch
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        forward_batch = ForwardBatch.init_new(
+            model_worker_batch, self.draft_model_runner
+        )
         can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
             forward_batch
         )
-
         if can_cuda_graph:
             score_list, token_list, parents_list = self.cuda_graph_runner.replay(
                 forward_batch
@@ -208,7 +252,9 @@ class EAGLEWorker(TpModelWorker):
         else:
             # Initialize attention backend
             self.draft_attn_backend.init_forward_metadata(forward_batch)
-
+            forward_batch = ForwardBatch.init_new(
+                model_worker_batch, self.draft_model_runner
+            )
             # Run forward steps
             score_list, token_list, parents_list = self.draft_forward(forward_batch)

@@ -225,10 +271,7 @@ class EAGLEWorker(TpModelWorker):
             batch.sampling_info.is_all_greedy,
         )

-        # Free cache locations
-        batch.token_to_kv_pool.free(out_cache_loc)
-        self._set_mem_pool(batch, self.target_worker.model_runner)
-        return ret
+        return ret, out_cache_loc

     def draft_forward(self, forward_batch: ForwardBatch):
         # Parse args
@@ -278,6 +321,7 @@ class EAGLEWorker(TpModelWorker):
             logits_output = self.model_runner.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
             )
+            self._detect_nan_if_needed(logits_output)
             probs = torch.softmax(logits_output.next_token_logits, dim=-1)
             topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
             if self.hot_token_id is not None:
@@ -294,71 +338,88 @@ class EAGLEWorker(TpModelWorker):
         logits_output, _ = self.target_worker.forward_batch_generation(
             model_worker_batch, skip_sample=True
         )
+        self._detect_nan_if_needed(logits_output)
         spec_info.hidden_states = logits_output.hidden_states
-        res = spec_info.verify(batch, logits_output)
-        batch.forward_mode = ForwardMode.DECODE
-        return res + (model_worker_batch,)
+        res: EagleVerifyOutput = spec_info.verify(
+            batch, logits_output, self.token_to_kv_pool_allocator
+        )

-    def forward_draft_extend(self, batch: ScheduleBatch):
-        self._set_mem_pool(batch, self.model_runner)
+        # Post process based on the verified outputs.
+        # Pick the indices that we care about (i.e., the accepted tokens).
+        logits_output.next_token_logits = logits_output.next_token_logits[
+            res.accepted_indices_cpu
+        ]
+        logits_output.hidden_states = logits_output.hidden_states[
+            res.accepted_indices_cpu
+        ]
+
+        # Prepare the batch for the next draft forward passes.
+        batch.forward_mode = ForwardMode.DECODE
+        batch.spec_info = res.draft_input
+
+        return logits_output, res, model_worker_batch
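`draft_forward` drives `select_top_k_tokens` once per speculative step: step 0 seeds the tree from the verified token's top-k, and every later step multiplies the cumulative path scores by the new per-node top-k probabilities before pruning back to `topk` paths. A shape-only sketch of the later-step expansion (random toy values; `fast_topk` approximated with `torch.topk`):

```python
import torch

b, topk = 2, 4
scores = torch.rand(b, topk)         # cumulative path scores from the previous step
topk_p = torch.rand(b, topk * topk)  # this step's top-k probs for each of the topk nodes

expand_scores = scores.unsqueeze(2) * topk_p.reshape(-1, topk, topk)  # (b, topk, topk)
topk_cs_p, topk_cs_index = torch.topk(
    expand_scores.flatten(start_dim=1), topk, dim=-1
)  # keep the topk best paths overall; both outputs are (b, topk)
```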
+
+    def forward_draft_extend(
+        self,
+        batch: ScheduleBatch,
+        hidden_states: torch.Tensor,
+        next_token_ids: List[int],
+    ):
+        """Run the draft model extend. This API modifies the states of the batch in place.
+
+        Args:
+            batch: The batch to run.
+            hidden_states: The hidden states from the target model forward.
+            next_token_ids: The next token ids generated by the target forward.
+        """
+        batch.spec_info = EagleDraftInput(
+            hidden_states=hidden_states,
+            verified_id=next_token_ids,
+        )
         batch.spec_info.prepare_for_extend(batch)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
-        self.capture_for_decode(logits_output, forward_batch)
-        self._set_mem_pool(batch, self.target_worker.model_runner)
-
-    def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
-        batch.token_to_kv_pool = runner.token_to_kv_pool
-        batch.req_to_token_pool = runner.req_to_token_pool
+        forward_batch = ForwardBatch.init_new(
+            model_worker_batch, self.draft_model_runner
+        )
+        logits_output = self.draft_model_runner.forward(forward_batch)
+        self._detect_nan_if_needed(logits_output)
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info is batch.spec_info
+        self.capture_for_decode(logits_output, forward_batch.spec_info)

     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
         seq_lens_backup = batch.seq_lens
-        req_pool_indices_backup = batch.req_pool_indices
-
-        self._set_mem_pool(batch, self.model_runner)
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
         batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        # We don't need logprob for this extend.
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
-        self.capture_for_decode(logits_output, forward_batch)
-        self._set_mem_pool(batch, self.target_worker.model_runner)
+        forward_batch = ForwardBatch.init_new(
+            model_worker_batch, self.draft_model_runner
+        )
+        logits_output = self.draft_model_runner.forward(forward_batch)
+        self._detect_nan_if_needed(logits_output)
+        assert forward_batch.spec_info is batch.spec_info
+        self.capture_for_decode(logits_output, forward_batch.spec_info)

         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
         batch.forward_mode = ForwardMode.DECODE
         batch.seq_lens = seq_lens_backup
-        batch.req_pool_indices = req_pool_indices_backup

     def capture_for_decode(
-        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
+        self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
     ):
         probs = torch.softmax(logits_output.next_token_logits, dim=-1)
-        spec_info = forward_batch.spec_info
-        spec_info.topk_p, spec_info.topk_index = fast_topk(probs, self.topk, dim=-1)
-        spec_info.hidden_states = logits_output.hidden_states
+        draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
+        draft_input.hidden_states = logits_output.hidden_states

-    # Don't support prefix share now.
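`prepare_extend_after_decode` (invoked above) walks the batch once, skipping finished requests and recording the survivors' positions in `keep_indices`, which `generate_attn_arg_prefill` later uses to filter `req_pool_indices` down to the same subset. A minimal sketch of that bookkeeping on toy data:

```python
import torch

finished = [False, True, False, False]        # per-request finished flags (toy data)
req_pool_indices = torch.tensor([5, 9, 3, 7])

keep_indices = [i for i, done in enumerate(finished) if not done]   # [0, 2, 3]
kept = req_pool_indices[torch.tensor(keep_indices)]                 # tensor([5, 3, 7])
```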
-    def finish_request(self, reqs: Union[Req, List[Req]]):
-        if not isinstance(reqs, List):
-            reqs = [reqs]
-        for req in reqs:
-            if req.rid not in self.finish_extend_len:
-                continue
-            req_len = (
-                len(req.origin_input_ids)
-                + len(req.output_ids)
-                - self.finish_extend_len[req.rid]
-                - 1
-            )
-            kv_indices = self.model_runner.req_to_token_pool.req_to_token[
-                req.req_pool_idx
-            ][:req_len]
-            self.model_runner.token_to_kv_pool.free(kv_indices)
-            self.model_runner.req_to_token_pool.free(req.req_pool_idx)
+    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
+        if self.use_nan_detection:
+            logits = logits_output.next_token_logits
+            if torch.any(torch.isnan(logits)):
+                logger.warning("Detected errors during sampling! NaN in the logits.")
+                raise ValueError("Detected errors during sampling! NaN in the logits.")


 def load_token_map(token_map_path: str) -> List[int]:
diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py
index af45ac423..4eead0c6b 100644
--- a/python/sglang/srt/speculative/spec_info.py
+++ b/python/sglang/srt/speculative/spec_info.py
@@ -20,7 +20,3 @@ class SpeculativeAlgorithm(IntEnum):
         if name is not None:
             name = name.upper()
         return name_map[name]
-
-
-class SpecInfo:
-    pass
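With the empty `SpecInfo` marker class gone, call sites spell the union out, as in the `Optional[Union[EagleDraftInput, EagleVerifyInput]]` annotations earlier in this patch. If that union keeps spreading, a type alias would keep the signatures short; this is an editorial sketch, not something the patch introduces:

```python
from typing import Optional, Union

# Hypothetical alias; the patch writes the Union out at each call site instead.
SpecInput = Union["EagleDraftInput", "EagleVerifyInput"]

def update(spec_info: Optional[SpecInput]) -> None:
    ...
```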
diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py
index 2347c3a1e..9571faf22 100644
--- a/test/srt/test_eagle_infer.py
+++ b/test/srt/test_eagle_infer.py
@@ -1,3 +1,4 @@
+import multiprocessing as mp
 import random
 import threading
 import time
@@ -18,6 +19,8 @@ from sglang.test.test_utils import (
     popen_launch_server,
 )

+acc_rate_tolerance = 0.15
+

 class TestEAGLEEngine(unittest.TestCase):
     BASE_CONFIG = {
@@ -43,13 +46,19 @@ class TestEAGLEEngine(unittest.TestCase):
         configs = [
             self.BASE_CONFIG,
             {**self.BASE_CONFIG, "disable_cuda_graph": True},
+            {**self.BASE_CONFIG, "chunked_prefill_size": 2},
         ]

         for config in configs:
             with self.subTest(
                 cuda_graph=(
                     "enabled" if len(config) == len(self.BASE_CONFIG) else "disabled"
-                )
+                ),
+                chunked_prefill_size=(
+                    config["chunked_prefill_size"]
+                    if "chunked_prefill_size" in config
+                    else "default"
+                ),
             ):
                 engine = sgl.Engine(**config)
                 try:
@@ -125,6 +134,8 @@ class TestEAGLEServer(unittest.TestCase):
                 "64",
                 "--mem-fraction-static",
                 "0.7",
+                "--chunked-prefill-size",
+                "128",
                 "--cuda-graph-max-bs",
                 "32",
             ],
@@ -196,6 +207,137 @@ class TestEAGLEServer(unittest.TestCase):
         self.assertGreater(metrics["accuracy"], 0.20)


+def measure_acc_rate(engine):
+    tic = time.time()
+    prompt = [
+        "Human: Give me a fully functional FastAPI server. Show the python code.<|separator|>\n\nAssistant:"
+    ]
+    sampling_params = {"temperature": 0, "max_new_tokens": 512}
+    output = engine.generate(prompt, sampling_params)
+    output = output[0]
+    latency = time.time() - tic
+
+    if "spec_verify_ct" in output["meta_info"]:
+        base_acc_length = (
+            output["meta_info"]["completion_tokens"]
+            / output["meta_info"]["spec_verify_ct"]
+        )
+    else:
+        base_acc_length = 0.0
+
+    base_speed = output["meta_info"]["completion_tokens"] / latency
+    return base_acc_length, base_speed
+
+
+class TestEagleAcceptanceRate(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        mp.set_start_method("spawn", force=True)
+        ref_engine = sgl.Engine(
+            model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+            speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+            speculative_algorithm="EAGLE",
+            speculative_num_steps=5,
+            speculative_eagle_topk=8,
+            speculative_num_draft_tokens=64,
+            mem_fraction_static=0.7,
+            disable_radix_cache=True,
+        )
+        cls.base_acc_length, cls.base_speed = measure_acc_rate(ref_engine)
+        ref_engine.shutdown()
+        assert cls.base_acc_length > 4.45
+
+    def test_acc_rate(self):
+        base_acc_length, base_speed = self.base_acc_length, self.base_speed
+        chunk_engine = sgl.Engine(
+            model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+            speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+            speculative_algorithm="EAGLE",
+            speculative_num_steps=5,
+            speculative_eagle_topk=8,
+            speculative_num_draft_tokens=64,
+            mem_fraction_static=0.7,
+            chunked_prefill_size=2,
+            disable_radix_cache=True,
+        )
+        chunked_acc_length, chunked_base_speed = measure_acc_rate(chunk_engine)
+        chunk_engine.shutdown()
+        print(base_acc_length, base_speed)
+        print(chunked_acc_length, chunked_base_speed)
+        assert abs(base_acc_length - chunked_acc_length) < acc_rate_tolerance
+
+    def test_acc_rate_prefix_caching(self):
+        base_acc_length, base_speed = self.base_acc_length, self.base_speed
+        prefix_caching_engine = sgl.Engine(
+            model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+            speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+            speculative_algorithm="EAGLE",
+            speculative_num_steps=5,
+            speculative_eagle_topk=8,
+            speculative_num_draft_tokens=64,
+            mem_fraction_static=0.7,
+            chunked_prefill_size=4,
+            schedule_policy="lpm",
+        )
+        for _ in range(10):
+            acc_length, _ = measure_acc_rate(prefix_caching_engine)
+            print(f"{acc_length=}")
+            assert abs(base_acc_length - acc_length) < acc_rate_tolerance
+            # Iterations after the first should hit the prefix cache.
+        prefix_caching_engine.shutdown()
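All of the acceptance-rate assertions above reduce to one ratio: `completion_tokens / spec_verify_ct`, the average number of tokens committed per verification pass (the bonus token makes this at least 1). With illustrative numbers:

```python
meta_info = {"completion_tokens": 512, "spec_verify_ct": 100}  # illustrative values
acc_length = meta_info["completion_tokens"] / meta_info["spec_verify_ct"]
print(acc_length)  # 5.12, which would pass the `> 4.45` baseline check above
```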
+
+
+class TestEAGLERetract(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--speculative-algorithm",
+                "EAGLE",
+                "--speculative-draft-model-path",
+                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+                "--speculative-num-steps",
+                "5",
+                "--speculative-eagle-topk",
+                "8",
+                "--speculative-num-draft-tokens",
+                "64",
+                "--mem-fraction-static",
+                "0.7",
+                "--chunked-prefill-size",
+                "128",
+                "--max-running-requests",
+                "64",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval(args)
+        print(f"{metrics=}")
+
+        self.assertGreater(metrics["accuracy"], 0.20)
+        # Wait a little bit so that the memory check happens.
+        time.sleep(5)
+
+
 class TestEAGLEServerTriton(TestEAGLEServer):
     @classmethod
     def setUpClass(cls):