diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index 7a6219527..3c60baf0a 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -36,7 +36,7 @@ fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
 _is_cuda = is_cuda()
 if _is_cuda:
-    import deep_gemm
+    import deep_gemm  # `pip install "sgl-kernel>=0.0.4.post3"`
     from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
 
 logger = logging.getLogger(__name__)
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index f4667f574..1f00bd646 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -49,6 +49,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_compiler_backend, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -273,7 +274,6 @@ class Req:
             "__req__": self
         }
         self.sampling_params = sampling_params
-        self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states
@@ -331,6 +331,8 @@ class Req:
         self.logprob_start_len = 0
         self.top_logprobs_num = top_logprobs_num
         self.token_ids_logprob = token_ids_logprob
+        self.temp_scaled_logprobs = False
+        self.top_p_normalized_logprobs = False
 
         # Logprobs (return values)
         self.input_token_logprobs_val: Optional[List[float]] = None
@@ -524,19 +526,23 @@ class ScheduleBatch:
     model_config: ModelConfig = None
     forward_mode: ForwardMode = None
     enable_overlap: bool = False
+    # Tell whether the current running batch is full so that we can skip
+    # the check of whether to prefill new requests.
+    # This is an optimization to reduce the overhead of the prefill check.
+    batch_is_full: bool = False
 
     # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
-    input_ids: torch.Tensor = None  # shape: [b], int32
+    input_ids: torch.Tensor = None  # shape: [b], int64
     input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
-    req_pool_indices: torch.Tensor = None  # shape: [b], int32
+    req_pool_indices: torch.Tensor = None  # shape: [b], int64
     seq_lens: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
-    out_cache_loc: torch.Tensor = None  # shape: [b], int32
-    output_ids: torch.Tensor = None  # shape: [b], int32
+    out_cache_loc: torch.Tensor = None  # shape: [b], int64
+    output_ids: torch.Tensor = None  # shape: [b], int64
 
     # The sum of all sequence lengths
     seq_lens_sum: int = None
@@ -551,6 +557,10 @@ class ScheduleBatch:
     top_logprobs_nums: Optional[List[int]] = None
     token_ids_logprobs: Optional[List[List[int]]] = None
 
+    # For logits and logprob post processing
+    temp_scaled_logprobs: bool = False
+    top_p_normalized_logprobs: bool = False
+
     # For extend and mixed chunked prefill
     prefix_lens: List[int] = None
    extend_lens: List[int] = None
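The two flags added to Req and ScheduleBatch only tag requests for logprob post-processing; the code that consumes them lives elsewhere in this change. As a rough, hypothetical illustration of what temperature-scaled, top-p-renormalized logprobs mean (a sketch, not the PR's implementation):

    import torch

    def post_process_logprobs(logits, temperature, top_p):
        # logits: [b, vocab], temperature: [b, 1], top_p: [b]
        # Temperature scaling happens before the log-softmax.
        logprobs = torch.log_softmax(logits / temperature, dim=-1)
        # Renormalize the distribution over the top-p nucleus only.
        probs = logprobs.exp()
        sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
        cum = sorted_probs.cumsum(dim=-1) - sorted_probs  # exclusive cumsum
        sorted_probs[cum > top_p.unsqueeze(-1)] = 0.0
        probs = torch.zeros_like(probs).scatter(-1, sorted_idx, sorted_probs)
        return (probs / probs.sum(dim=-1, keepdim=True)).clamp_min(1e-10).log()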
@@ -560,7 +570,7 @@ class ScheduleBatch:
     # It is an empty list if logprob is not required.
     extend_input_logprob_token_ids: Optional[torch.Tensor] = None
 
-    # For encoder-decoder
+    # For encoder-decoder architectures
     encoder_cached: Optional[List[bool]] = None
     encoder_lens: Optional[torch.Tensor] = None
     encoder_lens_cpu: Optional[List[int]] = None
@@ -597,6 +607,8 @@ class ScheduleBatch:
         spec_algorithm: SpeculativeAlgorithm,
         enable_custom_logit_processor: bool,
     ):
+        return_logprob = any(req.return_logprob for req in reqs)
+
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
             tree_cache=tree_cache,
             model_config=model_config,
             enable_overlap=enable_overlap,
-            return_logprob=any(req.return_logprob for req in reqs),
+            return_logprob=return_logprob,
             has_stream=any(req.stream for req in reqs),
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
@@ -631,24 +643,83 @@ class ScheduleBatch:
         return req_pool_indices
 
     def alloc_token_slots(self, num_tokens: int):
+        if self.token_to_kv_pool_allocator.available_size() < num_tokens:
+            if self.tree_cache is not None:
+                self.tree_cache.evict(num_tokens)
+
         out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
+        if out_cache_loc is None:
+            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
+            error_msg = (
+                f"{phase_str} out of memory. Try to lower your batch size.\n"
+                f"Try to allocate {num_tokens} tokens.\n"
+                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+            )
+            logger.error(error_msg)
+            if self.tree_cache is not None:
+                self.tree_cache.pretty_print()
+            raise RuntimeError(error_msg)
+
+        return out_cache_loc
+
+    def alloc_paged_token_slots_extend(
+        self,
+        prefix_lens: torch.Tensor,
+        seq_lens: torch.Tensor,
+        last_loc: torch.Tensor,
+        extend_num_tokens: int,
+    ):
+        if (
+            self.token_to_kv_pool_allocator.available_size()
+            < extend_num_tokens
+            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
+        ):
+            if self.tree_cache is not None:
+                self.tree_cache.evict(
+                    extend_num_tokens
+                    + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
+                )
+
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
+            prefix_lens, seq_lens, last_loc, extend_num_tokens
+        )
+        if out_cache_loc is None:
+            error_msg = (
+                f"Prefill out of memory. Try to lower your batch size.\n"
+                f"Try to allocate {extend_num_tokens} tokens.\n"
+                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
+                f"{self.tree_cache.evictable_size()=}\n"
+            )
+            logger.error(error_msg)
+            raise RuntimeError(error_msg)
+        return out_cache_loc
+
+    def alloc_paged_token_slots_decode(
+        self,
+        seq_lens: torch.Tensor,
+        last_loc: torch.Tensor,
+    ):
+        if (
+            self.token_to_kv_pool_allocator.available_size()
+            < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
+        ):
+            if self.tree_cache is not None:
+                self.tree_cache.evict(
+                    len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
+                )
+
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
         if out_cache_loc is None:
-            if self.tree_cache is not None:
-                self.tree_cache.evict(num_tokens, self.token_to_kv_pool_allocator.free)
-                out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
-
-            if out_cache_loc is None:
-                phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
-                logger.error(
-                    f"{phase_str} out of memory. Try to lower your batch size.\n"
-                    f"Try to allocate {num_tokens} tokens.\n"
-                    f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
-                )
-                if self.tree_cache is not None:
-                    self.tree_cache.pretty_print()
-                exit(1)
-
+            error_msg = (
+                f"Decode out of memory. Try to lower your batch size.\n"
+                f"Try to allocate {len(seq_lens)} tokens.\n"
+                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
+                f"{self.tree_cache.evictable_size()=}\n"
+            )
+            logger.error(error_msg)
+            raise RuntimeError(error_msg)
        return out_cache_loc
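Both paged allocators evict with a headroom of one extra page per request before calling into the allocator: a request that does not end exactly on a page boundary can open one partial page that the raw token count alone does not account for. The page arithmetic, as a toy sketch (hypothetical helper, 4-token pages):

    def pages_needed(prefix_len: int, seq_len: int, page_size: int) -> int:
        # New pages a request opens when it grows from prefix_len to seq_len tokens.
        ceil_div = lambda a, b: -(-a // b)
        return ceil_div(seq_len, page_size) - ceil_div(prefix_len, page_size)

    # Extending 5 -> 13 tokens with page_size=4: the request already owns
    # ceil(5/4)=2 pages and needs ceil(13/4)=4, so it opens 2 new pages
    # even though only 8 new token slots are written.
    assert pages_needed(5, 13, 4) == 2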
 
     def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
@@ -699,7 +770,7 @@ class ScheduleBatch:
             pt += req.extend_input_len
 
         # Reassign
-        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
+        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
             self.device, non_blocking=True
         )
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
@@ -707,14 +778,14 @@ class ScheduleBatch:
         )
 
         if not decoder_out_cache_loc:
-            self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
+            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                 self.device, non_blocking=True
             )
         else:
             self.out_cache_loc = torch.cat(decoder_out_cache_loc)
 
         if not encoder_out_cache_loc:
-            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
+            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                 self.device, non_blocking=True
             )
         else:
@@ -725,25 +796,38 @@ class ScheduleBatch:
     def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND
 
+        # Allocate req slots
         bs = len(self.reqs)
+        req_pool_indices = self.alloc_req_slots(bs)
+
+        # Init tensors
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
-        seq_lens = []
-        pre_lens = []
+        seq_lens = [len(r.fill_ids) for r in reqs]
+        prefix_lens = [len(r.prefix_indices) for r in reqs]
+        extend_lens = [r.extend_input_len for r in reqs]
 
-        # Allocate memory
-        req_pool_indices = self.alloc_req_slots(bs)
-        out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
+            self.device, non_blocking=True
+        )
+        input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
+            self.device, non_blocking=True
+        )
+        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
+            self.device, non_blocking=True
+        )
+        prefix_lens_tensor = torch.tensor(
+            prefix_lens, dtype=torch.int64, device=self.device
+        )
+        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
 
+        # Copy prefix and do some basic checks
         input_embeds = []
         extend_input_logprob_token_ids = []
 
-        pt = 0
-        for i, req in enumerate(reqs):
+        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
             req.req_pool_idx = req_pool_indices[i]
-            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
-            seq_lens.append(seq_len)
             assert seq_len - pre_len == req.extend_input_len
 
             if pre_len > 0:
@@ -759,7 +843,7 @@ class ScheduleBatch:
             req.cached_tokens += pre_len - req.already_computed
             req.already_computed = seq_len
             req.is_retracted = False
-            pre_lens.append(pre_len)
+
             # Compute the relative logprob_start_len in an extend batch
             if req.logprob_start_len >= pre_len:
                 req.extend_logprob_start_len = min(
@@ -815,60 +899,62 @@ class ScheduleBatch:
         else:
            extend_input_logprob_token_ids = None
 
+        # Allocate memory
+        if self.token_to_kv_pool_allocator.page_size == 1:
+            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+        else:
+            last_loc = get_last_loc(
+                self.req_to_token_pool.req_to_token,
+                req_pool_indices_tensor,
+                prefix_lens_tensor,
+            )
+            out_cache_loc = self.alloc_paged_token_slots_extend(
+                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
+            )
+
         # Set fields
-        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
-        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int64).to(
-            self.device, non_blocking=True
-        )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
-            self.device, non_blocking=True
-        )
+        self.input_ids = input_ids_tensor
+        self.req_pool_indices = req_pool_indices_tensor
+        self.seq_lens = seq_lens_tensor
+        self.out_cache_loc = out_cache_loc
 
         self.input_embeds = (
             torch.tensor(input_embeds).to(self.device, non_blocking=True)
             if input_embeds
             else None
         )
-
-        self.out_cache_loc = out_cache_loc
-
         self.seq_lens_sum = sum(seq_lens)
+
         if self.return_logprob:
             self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
             self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
-        self.extend_num_tokens = extend_num_tokens
-        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
-        self.extend_lens = [r.extend_input_len for r in reqs]
+
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
+        self.extend_num_tokens = extend_num_tokens
+        self.prefix_lens = prefix_lens
+        self.extend_lens = extend_lens
         self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
 
         # Write to req_to_token_pool
-        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
-        extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
         if global_server_args_dict["attention_backend"] != "torch_native":
+            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
             write_req_to_token_pool_triton[(bs,)](
                 self.req_to_token_pool.req_to_token,
-                self.req_pool_indices,
-                pre_lens,
-                self.seq_lens,
-                extend_lens,
-                self.out_cache_loc,
+                req_pool_indices_tensor,
+                prefix_lens_tensor,
+                seq_lens_tensor,
+                extend_lens_tensor,
+                out_cache_loc,
                 self.req_to_token_pool.req_to_token.shape[1],
             )
         else:
             pt = 0
             for i in range(bs):
                 self.req_to_token_pool.write(
-                    (self.req_pool_indices[i], slice(pre_lens[i], self.seq_lens[i])),
-                    self.out_cache_loc[pt : pt + self.extend_lens[i]],
+                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
+                    out_cache_loc[pt : pt + extend_lens[i]],
                 )
-                pt += self.extend_lens[i]
-        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
+                pt += extend_lens[i]
 
         if self.model_config.is_encoder_decoder:
             self.prepare_encoder_info_extend(input_ids, seq_lens)
@@ -914,7 +1000,7 @@ class ScheduleBatch:
         if self.token_to_kv_pool_allocator.available_size() >= bs:
             return True
 
-        self.tree_cache.evict(bs, self.token_to_kv_pool_allocator.free)
+        self.tree_cache.evict(bs)
 
         if self.token_to_kv_pool_allocator.available_size() >= bs:
            return True
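Note the pattern shared by alloc_token_slots, the paged variants, and check_decode_mem above: try the fast path, evict cold prefixes from the tree cache, then re-check. Condensed into a sketch (hypothetical names):

    def ensure_capacity(allocator, tree_cache, num_tokens: int) -> bool:
        if allocator.available_size() >= num_tokens:
            return True
        tree_cache.evict(num_tokens)  # frees through the allocator it holds
        return allocator.available_size() >= num_tokens

Since evict() no longer receives a free callback, the cache is expected to release KV slots via its own allocator handle.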
@@ -939,10 +1025,6 @@ class ScheduleBatch:
             reverse=True,
         )
 
-        retracted_reqs = []
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
-        first_iter = True
-
         def get_required_tokens(num_reqs: int):
             headroom_for_spec_decode = 0
             if server_args.speculative_algorithm:
                 headroom_for_spec_decode += (
                     num_reqs
                     * self.server_args.speculative_num_steps
                     * self.server_args.speculative_eagle_topk
                 )
             return (
                 num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
             )
 
+        retracted_reqs = []
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        first_iter = True
         while (
             self.token_to_kv_pool_allocator.available_size()
             < get_required_tokens(len(sorted_indices))
@@ -980,7 +1065,6 @@ class ScheduleBatch:
                 ]
                 self.token_to_kv_pool_allocator.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)
-                del self.tree_cache.entries[req.rid]
             else:
                 # TODO: apply more fine-grained retraction
                 last_uncached_pos = len(req.prefix_indices)
@@ -999,9 +1083,7 @@ class ScheduleBatch:
                     - self.token_to_kv_pool_allocator.available_size()
                 )
                 residual_size = max(0, residual_size)
-                self.tree_cache.evict(
-                    residual_size, self.token_to_kv_pool_allocator.free
-                )
+                self.tree_cache.evict(residual_size)
 
                 req.reset_for_retract()
 
@@ -1024,9 +1106,9 @@ class ScheduleBatch:
     def prepare_for_idle(self):
         self.forward_mode = ForwardMode.IDLE
-        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
-        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
@@ -1037,6 +1119,8 @@ class ScheduleBatch:
     def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
 
+        bs = len(self.reqs)
+
         if self.spec_algorithm.is_eagle():
             # if spec decoding is used, the decode batch is prepared inside
             # `forward_batch_speculative_generation` after running draft models.
@@ -1065,33 +1149,39 @@ class ScheduleBatch:
                 self.output_ids.to(torch.int64)
             )
 
+        # Update fields
         self.input_ids = self.output_ids
         self.output_ids = None
 
-        # Alloc mem
-        bs = len(self.reqs)
-        self.out_cache_loc = self.alloc_token_slots(bs)
-
         if self.model_config.is_encoder_decoder:
             locs = self.encoder_lens + self.seq_lens
             self.prepare_encoder_info_decode()
         else:
-            locs = self.seq_lens
+            locs = self.seq_lens.clone()
 
         if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
-            self.req_to_token_pool.write(
-                (self.req_pool_indices, locs), self.out_cache_loc
-            )
             self.seq_lens = self.seq_lens + 1
         else:
             # A faster in-place version
-            self.req_to_token_pool.write(
-                (self.req_pool_indices, locs), self.out_cache_loc
-            )
             self.seq_lens.add_(1)
         self.seq_lens_sum += bs
 
+        # Allocate memory
+        if self.token_to_kv_pool_allocator.page_size == 1:
+            self.out_cache_loc = self.alloc_token_slots(bs)
+        else:
+            last_loc = self.req_to_token_pool.req_to_token[
+                self.req_pool_indices, self.seq_lens - 2
+            ]
+            self.out_cache_loc = self.alloc_paged_token_slots_decode(
+                self.seq_lens, last_loc
+            )
+
+        self.req_to_token_pool.write(
+            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
+        )
+
     def filter_batch(
         self,
         chunked_req_to_exclude: Optional[Req] = None,
@@ -1345,8 +1435,8 @@ def write_req_to_token_pool_triton(
     pre_len = tl.load(pre_lens + pid)
     seq_len = tl.load(seq_lens + pid)
 
-    # TODO: optimize this?
-    cumsum_start = 0
+    # NOTE: This can be slow for large bs
+    cumsum_start = tl.cast(0, tl.int64)
     for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)
@@ -1363,3 +1453,12 @@ def write_req_to_token_pool_triton(
         value,
         mask=mask,
     )
+
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def get_last_loc(req_to_token, req_pool_indices_tensor, prefix_lens_tensor):
+    return torch.where(
+        prefix_lens_tensor > 0,
+        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
+        torch.full_like(prefix_lens_tensor, -1),
+    )
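get_last_loc computes, for each request, the physical index of its last cached prefix token (-1 when there is no prefix); the paged allocators use it to keep filling a partially occupied page. A toy run with made-up indices:

    import torch

    req_to_token = torch.tensor([[10, 11, 12, 13],
                                 [20, 21, 22, 23]])
    req_pool_indices = torch.tensor([0, 1])
    prefix_lens = torch.tensor([3, 0])

    last_loc = torch.where(
        prefix_lens > 0,
        req_to_token[req_pool_indices, prefix_lens - 1],
        torch.full_like(prefix_lens, -1),
    )
    # last_loc == tensor([12, -1]): request 0 resumes right after slot 12,
    # request 1 has no prefix and must start on a fresh page.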
diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py
index 3f569088b..8922e6e9d 100644
--- a/python/sglang/srt/managers/schedule_policy.py
+++ b/python/sglang/srt/managers/schedule_policy.py
@@ -77,7 +77,7 @@ class SchedulePolicy:
         self,
         policy: str,
         tree_cache: BasePrefixCache,
-        enable_hierarchical_cache: bool = False,
+        enable_hierarchical_cache: bool,
     ):
         self.policy = self._validate_and_adjust_policy(policy, tree_cache)
         self.tree_cache = tree_cache
@@ -85,10 +85,17 @@ class SchedulePolicy:
 
         # It is used to find the matching prefix for in-batch prefix caching.
         self.waiting_queue_radix_tree = RadixCache(
-            req_to_token_pool=None, token_to_kv_pool_allocator=None, disable=False
+            req_to_token_pool=None,
+            token_to_kv_pool_allocator=None,
+            page_size=1,
+            disable=False,
         )
 
     def calc_priority(self, waiting_queue: List[Req]) -> bool:
+        if self.policy == CacheAgnosticPolicy.FCFS:
+            # A shortcut for FCFS
+            return
+
         policy = self._determine_active_policy(waiting_queue)
 
         prefix_computed = False
@@ -118,7 +125,7 @@ class SchedulePolicy:
         return prefix_computed
 
     def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
-        if len(waiting_queue) > 128 and self.policy == CacheAwarePolicy.LPM:
+        if self.policy == CacheAwarePolicy.LPM and len(waiting_queue) > 128:
             # Turn off the expensive prefix matching and sorting when the #queue is large.
             return CacheAgnosticPolicy.FCFS
         return self.policy
@@ -442,7 +449,7 @@ class PrefillAdder:
     def add_one_req(
         self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
     ):
-        if req.sampling_params.ignore_eos and self.tree_cache.disable:
+        if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
             return self.add_one_req_ignore_eos(req, has_chunked_req)
 
         total_tokens = req.extend_input_len + min(
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 445edb10f..68bef7e08 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -93,7 +93,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -103,6 +103,7 @@ from sglang.srt.utils import (
     crash_on_warnings,
     get_bool_env_var,
     get_zmq_socket,
+    kill_itself_when_parent_died,
     pyspy_dump_schedulers,
     set_gpu_proc_affinity,
     set_random_seed,
@@ -159,6 +160,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
         )
         self.gpu_id = gpu_id
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
+        self.page_size = server_args.page_size
 
         # Distributed rank info
         self.dp_size = server_args.dp_size
@@ -265,20 +267,23 @@ class Scheduler(SchedulerOutputProcessorMixin):
             f"context_len={self.model_config.context_len}"
         )
 
+        # Init memory pool and cache
         self.init_memory_pool_and_cache()
 
         # Init running status
         self.waiting_queue: List[Req] = []
         # The running decoding batch for continuous batching
-        self.running_batch: Optional[ScheduleBatch] = None
+        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
         # The current forward batch
         self.cur_batch: Optional[ScheduleBatch] = None
-        # The current forward batch
+        # The last forward batch
         self.last_batch: Optional[ScheduleBatch] = None
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
+        self.num_prefill_tokens = 0
         self.last_decode_stats_tic = time.time()
+        self.last_prefill_stats_tic = time.time()
         self.return_health_check_ct = 0
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":
@@ -307,7 +312,9 @@ class Scheduler(SchedulerOutputProcessorMixin):
 
         # Init schedule policy and new token estimation
         self.policy = SchedulePolicy(
-            self.schedule_policy, self.tree_cache, self.enable_hierarchical_cache
+            self.schedule_policy,
+            self.tree_cache,
+            self.enable_hierarchical_cache,
         )
         assert (
             server_args.schedule_conservativeness >= 0
@@ -327,11 +334,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
         ) / global_config.default_new_token_ratio_decay_steps
         self.new_token_ratio = self.init_new_token_ratio
 
-        # Tell whether the current running batch is full so that we can skip
-        # the check of whether to prefill new requests.
-        # This is an optimization to reduce the overhead of the prefill check.
-        self.batch_is_full = False
-
         # Init watchdog thread
         self.watchdog_timeout = server_args.watchdog_timeout
         t = threading.Thread(target=self.watchdog_thread, daemon=True)
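running_batch changes from Optional[ScheduleBatch] to an always-present, possibly empty batch that also carries batch_is_full, so the None checks in the scheduler collapse into is_empty() calls. The null-object idea in miniature (hypothetical class, not the PR's code):

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class Batch:
        reqs: List[object] = field(default_factory=list)
        batch_is_full: bool = False

        def is_empty(self) -> bool:
            return len(self.reqs) == 0

    running_batch = Batch()               # never None
    running_bs = len(running_batch.reqs)  # no `is not None` guard needed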
@@ -437,6 +439,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
             self.tree_cache = RadixCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
             )
 
@@ -458,6 +461,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
         # The largest context length (prefill + generation) of a single request
         self._largest_prefill_decode_len: int = 0
         self.last_gen_throughput: float = 0.0
+        self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
         self.spec_num_total_accepted_tokens = 0
         self.spec_num_total_forward_ct = 0
@@ -487,7 +491,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
             else:
-                # When the server is idle, so self-check and re-init some states
+                # When the server is idle, do self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
 
@@ -527,7 +531,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 )
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
-                # When the server is idle, so self-check and re-init some states
+                # When the server is idle, do self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
 
@@ -588,7 +592,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
         for recv_req in recv_reqs:
             # If it is a health check generation request and there are running requests, ignore it.
             if is_health_check_generate_req(recv_req) and (
-                self.chunked_req is not None or self.running_batch is not None
+                self.chunked_req is not None or not self.running_batch.is_empty()
             ):
                 self.return_health_check_ct += 1
                 continue
@@ -812,6 +816,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
         can_run_list: List[Req],
         running_bs: int,
     ):
+        gap_latency = time.time() - self.last_prefill_stats_tic
+        self.last_prefill_stats_tic = time.time()
+        self.last_input_throughput = self.num_prefill_tokens / gap_latency
+        self.num_prefill_tokens = 0
+
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
@@ -847,7 +856,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
         self.last_decode_stats_tic = time.time()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
-        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
+        num_running_reqs = len(self.running_batch.reqs)
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
@@ -911,8 +920,10 @@ class Scheduler(SchedulerOutputProcessorMixin):
         )
         if memory_leak:
             msg = (
-                "KV cache pool leak detected!"
+                "KV cache pool leak detected! "
                 f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
+                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
+                f"{self.tree_cache.evictable_size()=}\n"
             )
             warnings.warn(msg)
             if crash_on_warnings():
@@ -938,7 +949,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
         )
-        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
+        num_running_reqs = len(self.running_batch.reqs)
         self.stats.num_running_reqs = num_running_reqs
         self.stats.num_used_tokens = num_used
         self.stats.token_usage = num_used / self.max_total_num_tokens
@@ -956,20 +967,20 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 self.tree_cache.cache_unfinished_req(self.chunked_req)
                 # chunked request keeps its rid but will get a new req_pool_idx
                 self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
-                self.batch_is_full = False
+                self.running_batch.batch_is_full = False
 
         # Filter batch
         last_bs = self.last_batch.batch_size()
         self.last_batch.filter_batch()
         if self.last_batch.batch_size() < last_bs:
-            self.batch_is_full = False
+            self.running_batch.batch_is_full = False
 
         # Merge the new batch into the running batch
         if not self.last_batch.is_empty():
-            if self.running_batch is None:
+            if self.running_batch.is_empty():
                 self.running_batch = self.last_batch
             else:
-                # merge running_batch with prefill batch
+                # Merge running_batch with prefill batch
                 self.running_batch.merge_batch(self.last_batch)
 
         new_batch = self.get_new_batch_prefill()
@@ -978,11 +989,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
             ret = new_batch
         else:
             # Run decode
-            if self.running_batch is None:
-                ret = None
-            else:
+            if not self.running_batch.is_empty():
                 self.running_batch = self.update_running_batch(self.running_batch)
-                ret = self.running_batch
+                ret = self.running_batch if not self.running_batch.is_empty() else None
+            else:
+                ret = None
 
         # Handle DP attention
         if self.server_args.enable_dp_attention:
@@ -997,13 +1008,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
 
         # Handle the cases where prefill is not allowed
         if (
-            self.batch_is_full or len(self.waiting_queue) == 0
+            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
         ) and self.chunked_req is None:
             return None
 
-        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
+        running_bs = len(self.running_batch.reqs)
         if running_bs >= self.max_running_requests:
-            self.batch_is_full = True
+            self.running_batch.batch_is_full = True
             return None
 
         if self.enable_hierarchical_cache:
@@ -1025,17 +1036,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
             running_bs if self.is_mixed_chunk else 0,
         )
 
-        is_chunked = self.chunked_req is not None
-        if is_chunked:
+        if self.chunked_req is not None:
             self.chunked_req.init_next_round_input()
             self.chunked_req = adder.add_chunked_req(self.chunked_req)
 
         if self.lora_paths:
-            lora_set = (
-                set([req.lora_path for req in self.running_batch.reqs])
-                if self.running_batch is not None
-                else set([])
-            )
+            lora_set = set([req.lora_path for req in self.running_batch.reqs])
+
         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
@@ -1047,11 +1054,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 )
                 > self.max_loras_per_batch
             ):
-                self.batch_is_full = True
+                self.running_batch.batch_is_full = True
                 break
 
             if running_bs + len(adder.can_run_list) >= self.max_running_requests:
-                self.batch_is_full = True
+                self.running_batch.batch_is_full = True
                 break
 
            req.init_next_round_input(
@@ -1066,12 +1073,14 @@ class Scheduler(SchedulerOutputProcessorMixin):
             if res == AddReqResult.NO_TOKEN:
                 if self.enable_hierarchical_cache:
                     # Set batch_is_full after making sure there are requests that can be served
-                    self.batch_is_full = len(adder.can_run_list) > 0 or (
+                    self.running_batch.batch_is_full = len(
+                        adder.can_run_list
+                    ) > 0 or (
                         self.running_batch is not None
                         and not self.running_batch.is_empty()
                     )
                 else:
-                    self.batch_is_full = True
+                    self.running_batch.batch_is_full = True
                 break
 
         # Update waiting queue
@@ -1112,7 +1121,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
         # Mixed-style chunked prefill
         if (
             self.is_mixed_chunk
-            and self.running_batch is not None
+            and not self.running_batch.is_empty()
             and not (new_batch.return_logprob or self.running_batch.return_logprob)
         ):
             # TODO (lianmin): support return_logprob + mixed chunked prefill
@@ -1121,7 +1130,9 @@ class Scheduler(SchedulerOutputProcessorMixin):
             self.running_batch.prepare_for_decode()
             new_batch.mix_with_running(self.running_batch)
             new_batch.decoding_reqs = self.running_batch.reqs
-            self.running_batch = None
+            self.running_batch = ScheduleBatch(
+                reqs=[], batch_is_full=self.running_batch.batch_is_full
+            )
         else:
             new_batch.decoding_reqs = None
 
@@ -1133,8 +1144,8 @@ class Scheduler(SchedulerOutputProcessorMixin):
         batch.filter_batch()
         if batch.is_empty():
-            self.batch_is_full = False
-            return None
+            batch.batch_is_full = False
+            return batch
 
         # Check if decode out of memory
         if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
@@ -1158,7 +1169,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
             )
 
         if batch.batch_size() < initial_bs:
-            self.batch_is_full = False
+            batch.batch_is_full = False
 
         # Update batch tensors
         batch.prepare_for_decode()
@@ -1233,8 +1244,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
     ):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
-            if batch.is_empty():
-                self.running_batch = None
         elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_idle():
@@ -1375,9 +1384,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
 
     def flush_cache(self):
         """Flush the memory pool and cache."""
-        if len(self.waiting_queue) == 0 and (
-            self.running_batch is None or len(self.running_batch.reqs) == 0
-        ):
+        if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
             self.cur_batch = None
             self.last_batch = None
             self.tree_cache.reset()
@@ -1403,7 +1410,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
             logging.warning(
                 f"Cache not flushed because there are pending requests. "
                 f"#queue-req: {len(self.waiting_queue)}, "
-                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+                f"#running-req: {len(self.running_batch.reqs)}"
             )
             if_success = False
         return if_success
@@ -1453,24 +1460,24 @@ class Scheduler(SchedulerOutputProcessorMixin):
 
     def abort_request(self, recv_req: AbortReq):
         # Delete requests in the waiting queue
-        to_del = None
+        to_del = []
         for i, req in enumerate(self.waiting_queue):
-            if req.rid == recv_req.rid:
-                to_del = i
+            if req.rid.startswith(recv_req.rid):
+                to_del.append(i)
                 break
 
-        if to_del is not None:
-            del self.waiting_queue[to_del]
+        # Sort in reverse order to avoid index issues when deleting
+        for i in sorted(to_del, reverse=True):
+            req = self.waiting_queue.pop(i)
             logger.debug(f"Abort queued request. {req.rid=}")
         return
 
         # Delete requests in the running batch
-        if self.running_batch:
-            for req in self.running_batch.reqs:
-                if req.rid == recv_req.rid and not req.finished():
-                    logger.debug(f"Abort running request. {req.rid=}")
-                    req.to_abort = True
-                    break
+        for req in self.running_batch.reqs:
+            if req.rid.startswith(recv_req.rid) and not req.finished():
+                logger.debug(f"Abort running request. {req.rid=}")
+                req.to_abort = True
+                return
 
     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()
diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
index 8728b9e7e..e83dc0646 100644
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -204,8 +204,17 @@ class SchedulerOutputProcessorMixin:
                 continue
 
             if self.enable_overlap and req.finished():
-                # Free the one delayed token
-                self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
+                # Free the one extra delayed token
+                if self.page_size == 1:
+                    self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
+                else:
+                    # Only free when the extra token is in a new page
+                    if (
+                        len(req.origin_input_ids) + len(req.output_ids) - 1
+                    ) % self.page_size == 0:
+                        self.token_to_kv_pool_allocator.free(
+                            batch.out_cache_loc[i : i + 1]
+                        )
                 continue
 
             if batch.spec_algorithm.is_none():
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index aaaa28e22..4a1f2d5c1 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -103,6 +103,9 @@ class TpModelWorkerClient:
             self.worker.model_runner.token_to_kv_pool_allocator,
         )
 
+    def get_kv_cache(self):
+        return self.worker.model_runner.token_to_kv_pool
+
     def forward_thread_func(self):
         try:
             with torch.get_device_module(self.device).stream(self.forward_stream):
@@ -203,7 +206,7 @@ class TpModelWorkerClient:
             -(self.future_token_ids_ct + 1),
             -(self.future_token_ids_ct + 1 + bs),
             -1,
-            dtype=torch.int32,
+            dtype=torch.int64,
             device=self.device,
         )
         self.future_token_ids_ct = (
diff --git a/python/sglang/srt/mem_cache/base_prefix_cache.py b/python/sglang/srt/mem_cache/base_prefix_cache.py
index 9386595a8..f370346e1 100644
--- a/python/sglang/srt/mem_cache/base_prefix_cache.py
+++ b/python/sglang/srt/mem_cache/base_prefix_cache.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, List, Tuple
+from typing import Any, List, Tuple
 
 
 class BasePrefixCache(ABC):
@@ -26,24 +26,22 @@ class BasePrefixCache(ABC):
         pass
 
     @abstractmethod
-    def evict(self, num_tokens: int, evict_callback: Callable):
+    def evict(self, num_tokens: int):
         pass
 
     @abstractmethod
-    def inc_lock_ref(self, node):
+    def inc_lock_ref(self, node: Any):
         pass
 
     @abstractmethod
-    def dec_lock_ref(self, node):
+    def dec_lock_ref(self, node: Any):
         pass
 
-    @abstractmethod
     def evictable_size(self):
-        pass
+        return 0
 
-    @abstractmethod
     def protected_size(self):
-        raise NotImplementedError()
+        return 0
 
     def total_size(self):
         raise NotImplementedError()
diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py
index a89fa93a1..3cb540fc6 100644
--- a/python/sglang/srt/mem_cache/chunk_cache.py
+++ b/python/sglang/srt/mem_cache/chunk_cache.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
 """Cache for chunked prefill, used when RadixCache is disabled."""
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
+
+from typing import TYPE_CHECKING, Any, Callable, List, Tuple
 
 import torch
 
@@ -24,73 +25,40 @@ class ChunkCache(BasePrefixCache):
         req_to_token_pool: ReqToTokenPool,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
     ):
-        self.disable = True
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
-        self.entries: Dict[str, ChunkCacheEntry] = {}
-
-        self.reset()
 
     def reset(self):
-        self.entries = {}
+        pass
 
-    def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
-        if rid not in self.entries:
-            return [], None
-
-        entry = self.entries[rid]
-        max_prefix_len = len(key)
-        return entry.value[:max_prefix_len], entry
-
-    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
-        if token_ids is None:
-            token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
-        else:
-            token_id_len = len(token_ids)
+    def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
+        return [], None
 
+    def cache_finished_req(self, req: Req):
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, :token_id_len
+            req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool_allocator.free(kv_indices)
 
-        if req.rid in self.entries:
-            del self.entries[req.rid]
-
     def cache_unfinished_req(self, req: Req):
-        token_id_len = len(req.fill_ids)
-
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, :token_id_len
+            req.req_pool_idx, : len(req.fill_ids)
         ]
 
-        if req.rid not in self.entries:
-            self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices)
-
-        entry = self.entries[req.rid]
-        entry.value = kv_indices
+        # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
         req.prefix_indices = kv_indices
-        req.last_node = entry
 
     def insert(self):
         raise NotImplementedError()
 
-    def evict(self, num_tokens: int, evict_callback: Callable):
+    def evict(self, num_tokens: int):
         pass
 
-    def inc_lock_ref(self, node):
+    def inc_lock_ref(self, node: Any):
         return 0
 
-    def dec_lock_ref(self, node):
-        return 0
-
-    def evictable_size(self):
-        return 0
-
-    def pretty_print(self):
-        return ""
-
-    def protected_size(self):
+    def dec_lock_ref(self, node: Any):
         return 0
 
     def pretty_print(self):
diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py
index f629bb751..6b4825994 100644
--- a/python/sglang/srt/mem_cache/hiradix_cache.py
+++ b/python/sglang/srt/mem_cache/hiradix_cache.py
@@ -7,13 +7,13 @@ from typing import List, Optional
 import torch
 
 from sglang.srt.managers.cache_controller import HiCacheController
-from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPoolHost,
     ReqToTokenPool,
     TokenToKVPoolAllocator,
 )
-from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
+from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
+from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match
 
 logger = logging.getLogger(__name__)
 
@@ -122,7 +122,7 @@ class HiRadixCache(RadixCache):
     def evictable_size(self):
         return self.evictable_size_
 
-    def evict(self, num_tokens: int, evict_callback=None):
+    def evict(self, num_tokens: int):
         leaves = self._collect_leaves_device()
         heapq.heapify(leaves)
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 4dfb72bca..d8ea694c5 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -129,6 +129,7 @@ class TokenToKVPoolAllocator:
         self.size = size
         self.dtype = dtype
         self.device = device
+        self.page_size = 1
 
         self.free_slots = None
         self.is_not_in_free_group = True
@@ -149,15 +150,14 @@ class TokenToKVPoolAllocator:
 
         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]
-
-        return select_index.to(self.device, non_blocking=True)
+        return select_index
 
     def free(self, free_index: torch.Tensor):
         if free_index.numel() == 0:
             return
 
         if self.is_not_in_free_group:
-            self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
+            self.free_slots = torch.concat((self.free_slots, free_index))
         else:
             self.free_group.append(free_index)
 
@@ -172,7 +172,9 @@ class TokenToKVPoolAllocator:
 
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.free_slots = torch.arange(1, self.size + 1, dtype=torch.int32)
+        self.free_slots = torch.arange(
+            1, self.size + 1, dtype=torch.int64, device=self.device
+        )
         self.is_in_free_group = False
         self.free_group = []
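The free list is now a device-resident int64 tensor, so alloc() is a pure slice and free() a pure concat, with no host-device copies on the hot path. A minimal sketch of the same mechanics (assuming a CUDA device is available):

    import torch

    free_slots = torch.arange(1, 9, dtype=torch.int64, device="cuda")
    need = 3
    selected, free_slots = free_slots[:need], free_slots[need:]
    # selected == tensor([1, 2, 3]); freeing is just a concat, still on device:
    free_slots = torch.concat((free_slots, selected))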
@@ -182,6 +184,7 @@ class MHATokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
+        page_size: int,
         dtype: torch.dtype,
         head_num: int,
         head_dim: int,
@@ -190,6 +193,7 @@ class MHATokenToKVPool(KVCache):
         enable_memory_saver: bool,
     ):
         self.size = size
+        self.page_size = page_size
         self.dtype = dtype
         self.device = device
         if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
@@ -207,6 +211,8 @@ class MHATokenToKVPool(KVCache):
             self._create_buffers()
 
         self.layer_transfer_counter = None
+        self.capture_mode = False
+        self.alt_stream = torch.cuda.Stream()
 
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
@@ -218,16 +224,16 @@ class MHATokenToKVPool(KVCache):
         # [size, head_num, head_dim] for each layer
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.k_buffer = [
-            torch.empty(
-                (self.size + 1, self.head_num, self.head_dim),
+            torch.zeros(
+                (self.size + self.page_size, self.head_num, self.head_dim),
                 dtype=self.store_dtype,
                 device=self.device,
             )
             for _ in range(self.layer_num)
         ]
         self.v_buffer = [
-            torch.empty(
-                (self.size + 1, self.head_num, self.head_dim),
+            torch.zeros(
+                (self.size + self.page_size, self.head_num, self.head_dim),
                 dtype=self.store_dtype,
                 device=self.device,
             )
             for _ in range(self.layer_num)
@@ -315,14 +321,44 @@ class MHATokenToKVPool(KVCache):
             cache_v.div_(v_scale)
             cache_k = cache_k.to(self.dtype)
             cache_v = cache_v.to(self.dtype)
+
         if self.store_dtype != self.dtype:
-            self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
-            self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
+            cache_k = cache_k.view(self.store_dtype)
+            cache_v = cache_v.view(self.store_dtype)
+
+        if self.capture_mode:
+            self.alt_stream.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(self.alt_stream):
+                self.k_buffer[layer_id][loc] = cache_k
+                self.v_buffer[layer_id][loc] = cache_v
+            torch.cuda.current_stream().wait_stream(self.alt_stream)
         else:
             self.k_buffer[layer_id][loc] = cache_k
             self.v_buffer[layer_id][loc] = cache_v
 
 
+@torch.compile
+def fused_downcast(
+    cache_k: torch.Tensor,
+    cache_v: torch.Tensor,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+    dtype: torch.dtype,
+    store_dtype: torch.dtype,
+    max_fp8: float,
+    min_fp8: float,
+):
+    cache_k = cache_k / k_scale
+    cache_k = torch.clamp(cache_k, min_fp8, max_fp8)
+    cache_v = cache_v / v_scale
+    cache_v = torch.clamp(cache_v, min_fp8, max_fp8)
+    cache_k = cache_k.to(dtype)
+    cache_v = cache_v.to(dtype)
+    cache_k = cache_k.view(store_dtype)
+    cache_v = cache_v.view(store_dtype)
+    return cache_k, cache_v
+
+
 # This compiled version is slower in the unit test
 # python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
 @torch.compile(dynamic=True, backend=get_compiler_backend())
@@ -335,6 +371,7 @@ class MLATokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
+        page_size: int,
         dtype: torch.dtype,
         kv_lora_rank: int,
         qk_rope_head_dim: int,
@@ -359,8 +396,8 @@ class MLATokenToKVPool(KVCache):
 
         with memory_saver_adapter.region():
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
             self.kv_buffer = [
-                torch.empty(
-                    (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                torch.zeros(
+                    (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
                     dtype=self.store_dtype,
                     device=device,
                 )
                 for _ in range(layer_num)
@@ -400,6 +437,7 @@ class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
         self,
         size: int,
+        page_size: int,
         dtype: torch.dtype,
         head_num: int,
         head_dim: int,
@@ -409,6 +447,7 @@ class DoubleSparseTokenToKVPool(KVCache):
         enable_memory_saver: bool,
     ):
         self.size = size
+        self.page_size = page_size
         self.dtype = dtype
         self.device = device
         if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
@@ -423,17 +462,21 @@ class DoubleSparseTokenToKVPool(KVCache):
         with memory_saver_adapter.region():
             # [size, head_num, head_dim] for each layer
             self.k_buffer = [
-                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                torch.zeros(
+                    (size + page_size, head_num, head_dim), dtype=dtype, device=device
+                )
                 for _ in range(layer_num)
             ]
             self.v_buffer = [
-                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                torch.zeros(
+                    (size + page_size, head_num, head_dim), dtype=dtype, device=device
+                )
                 for _ in range(layer_num)
             ]
 
             # [size, head_num, heavy_channel_num] for each layer
             self.label_buffer = [
-                torch.empty(
+                torch.zeros(
                     (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
                 )
                 for _ in range(layer_num)
@@ -528,7 +571,7 @@ class MHATokenToKVPoolHost:
             f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
         )
 
-        self.kv_buffer = torch.empty(
+        self.kv_buffer = torch.zeros(
             (2, self.layer_num, self.size, self.head_num, self.head_dim),
             dtype=self.dtype,
             device=self.device,
@@ -548,9 +591,6 @@ class MHATokenToKVPoolHost:
     def get_flat_data(self, indices):
         return self.kv_buffer[:, :, indices]
 
-    def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[:, layer_id, indices]
-
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, :, indices] = flat_data
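The KV buffers grow from size + 1 to size + page_size slots: with paging, the reserved dummy region at the front must cover a whole page so that every page handed out by the allocator is fully backed. Layout arithmetic, assuming size is a multiple of page_size:

    page_size = 4
    size = 1024  # usable token slots

    # Page 0 (slots 0..page_size-1) is never handed out; it contains the
    # padded slot 0 used for dummy writes. Usable pages are 1..size//page_size.
    num_pages = size // page_size
    last_usable_slot = num_pages * page_size + page_size - 1  # == size + page_size - 1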
diff --git a/python/sglang/srt/mem_cache/paged_allocator.py b/python/sglang/srt/mem_cache/paged_allocator.py
new file mode 100644
index 000000000..3b07aa2a3
--- /dev/null
+++ b/python/sglang/srt/mem_cache/paged_allocator.py
@@ -0,0 +1,283 @@
+"""
+Copyright 2025 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""
+Page-aligned memory pool.
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.mem_cache.memory_pool import KVCache
+from sglang.srt.utils import get_bool_env_var, next_power_of_2
+
+
+@triton.jit
+def alloc_extend_kernel(
+    pre_lens_ptr,
+    seq_lens_ptr,
+    last_loc_ptr,
+    free_page_ptr,
+    out_indices,
+    ret_values,
+    bs_upper: tl.constexpr,
+    page_size: tl.constexpr,
+    max_num_extend_tokens: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    load_offset = tl.arange(0, bs_upper)
+    seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
+    pre_lens = tl.load(pre_lens_ptr + load_offset, mask=load_offset <= pid)
+    extend_lens = seq_lens - pre_lens
+
+    seq_len = tl.load(seq_lens_ptr + pid)
+    pre_len = tl.load(pre_lens_ptr + pid)
+    extend_len = seq_len - pre_len
+
+    sum_extend_lens = tl.sum(extend_lens)
+    output_start_loc = sum_extend_lens - extend_len
+
+    num_pages_after = (seq_lens + page_size - 1) // page_size
+    num_pages_before = (pre_lens + page_size - 1) // page_size
+    num_new_pages = num_pages_after - num_pages_before
+
+    num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
+        pre_len + page_size - 1
+    ) // page_size
+    sum_num_new_pages = tl.sum(num_new_pages)
+    new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
+
+    # Return value
+    if pid == tl.num_programs(0) - 1:
+        merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
+            tl.int64
+        )
+        tl.store(ret_values, merged_value)
+
+    # Part 1: fill the old partial page
+    last_loc = tl.load(last_loc_ptr + pid)
+    num_part1 = (
+        min(seq_len, (pre_len + page_size - 1) // page_size * page_size) - pre_len
+    )
+    offset_one_page = tl.arange(0, page_size)
+    tl.store(
+        out_indices + output_start_loc + offset_one_page,
+        last_loc + 1 + offset_one_page,
+        mask=offset_one_page < num_part1,
+    )
+    if pre_len + num_part1 == seq_len:
+        return
+
+    # Part 2: fill the new full pages
+    num_part2 = (
+        seq_len // page_size * page_size
+        - (pre_len + page_size - 1) // page_size * page_size
+    )
+
+    offset_many_page = tl.arange(0, max_num_extend_tokens)
+    page_start = tl.load(
+        free_page_ptr + new_page_start_loc + offset_many_page // page_size,
+        mask=offset_many_page < num_part2,
+    )
+    tl.store(
+        out_indices + output_start_loc + num_part1 + offset_many_page,
+        page_start * page_size + offset_many_page % page_size,
+        mask=offset_many_page < num_part2,
+    )
+    if pre_len + num_part1 + num_part2 == seq_len:
+        return
+
+    # Part 3: fill the new partial page
+    num_part3 = seq_len - seq_len // page_size * page_size
+    start_loc = tl.load(
+        free_page_ptr + new_page_start_loc + num_page_start_loc_self - 1
+    )
+    tl.store(
+        out_indices + output_start_loc + num_part1 + num_part2 + offset_one_page,
+        start_loc * page_size + offset_one_page,
+        mask=offset_one_page < num_part3,
+    )
+
+
+@triton.jit
+def alloc_decode_kernel(
+    seq_lens_ptr,
+    last_loc_ptr,
+    free_page_ptr,
+    out_indices,
+    ret_values,
+    bs_upper: tl.constexpr,
+    page_size: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    load_offset = tl.arange(0, bs_upper)
+    seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
+    pre_lens = tl.where(load_offset <= pid, seq_lens - 1, seq_lens)
+
+    seq_len = tl.load(seq_lens_ptr + pid)
+    pre_len = seq_len - 1
+
+    num_pages_after = (seq_lens + page_size - 1) // page_size
+    num_pages_before = (pre_lens + page_size - 1) // page_size
+    num_new_pages = num_pages_after - num_pages_before
+
+    num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
+        pre_len + page_size - 1
+    ) // page_size
+    sum_num_new_pages = tl.sum(num_new_pages)
+    new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
+
+    # Return value
+    if pid == tl.num_programs(0) - 1:
+        tl.store(ret_values, sum_num_new_pages)
+
+    if num_page_start_loc_self == 0:
+        last_loc = tl.load(last_loc_ptr + pid)
+        tl.store(out_indices + pid, last_loc + 1)
+    else:
+        page = tl.load(free_page_ptr + new_page_start_loc)
+        tl.store(out_indices + pid, page * page_size)
+
+
+class PagedTokenToKVPoolAllocator:
+    """
+    An allocator managing the indices to kv cache data.
+
+    This class has the same interface as `TokenToKVPoolAllocator` but the output
+    of one request is always page-aligned.
+
+    TODO: fuse last_loc into the kernel.
+    """
+
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        dtype: torch.dtype,
+        device: str,
+        kvcache: KVCache,
+    ):
+        self.size = size
+        self.dtype = dtype
+        self.device = device
+        self.page_size = page_size
+        self.num_pages = size // page_size
+
+        self.free_pages = None
+        self.is_not_in_free_group = True
+        self.free_group = []
+        self.clear()
+        self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
+
+        self._kvcache = kvcache
+        self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
+
+    def available_size(self):
+        return len(self.free_pages) * self.page_size
+
+    def alloc_extend(
+        self,
+        prefix_lens: torch.Tensor,
+        seq_lens: torch.Tensor,
+        last_loc: torch.Tensor,
+        extend_num_tokens: int,
+    ):
+        if self.debug_mode:
+            assert torch.all(
+                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
+            )
+
+        bs = len(prefix_lens)
+        out_indices = torch.empty(
+            (extend_num_tokens,), dtype=torch.int64, device=self.device
+        )
+        alloc_extend_kernel[(bs,)](
+            prefix_lens,
+            seq_lens,
+            last_loc,
+            self.free_pages,
+            out_indices,
+            self.ret_values,
+            next_power_of_2(bs),
+            self.page_size,
+            next_power_of_2(extend_num_tokens),
+        )
+
+        merged_value = self.ret_values.item()
+        num_new_pages = merged_value >> 32
+        if num_new_pages > len(self.free_pages):
+            return None
+
+        self.free_pages = self.free_pages[num_new_pages:]
+        return out_indices
+
+    def alloc_decode(
+        self,
+        seq_lens: torch.Tensor,
+        last_loc: torch.Tensor,
+    ):
+        if self.debug_mode:
+            assert torch.all(
+                (last_loc + 2) % self.page_size == seq_lens % self.page_size
+            )
+
+        bs = len(seq_lens)
+        out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
+        alloc_decode_kernel[(bs,)](
+            seq_lens,
+            last_loc,
+            self.free_pages,
+            out_indices,
+            self.ret_values,
+            next_power_of_2(bs),
+            self.page_size,
+        )
+
+        num_new_pages = self.ret_values.item()
+        if num_new_pages > len(self.free_pages):
+            return None
+
+        self.free_pages = self.free_pages[num_new_pages:]
+        return out_indices
+
+    def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+
+        if self.is_not_in_free_group:
+            free_page_indices = torch.unique(free_index // self.page_size)
+            self.free_pages = torch.cat((free_page_indices, self.free_pages))
+        else:
+            self.free_group.append(free_index)
+
+    def free_group_begin(self):
+        self.is_not_in_free_group = False
+        self.free_group = []
+
+    def free_group_end(self):
+        self.is_not_in_free_group = True
+        if self.free_group:
+            self.free(torch.concat(self.free_group))
+
+    def clear(self):
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
+        self.free_pages = torch.arange(
+            1, self.num_pages + 1, dtype=torch.int64, device=self.device
+        )
+        self.is_in_free_group = False
+        self.free_group = []
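For reference, the per-request work of alloc_extend_kernel can be restated in plain Python: finish the prefix's partial page, claim whole new pages, then open at most one trailing partial page. A simplified, unoptimized sketch (illustrative names, lists of Python ints instead of tensors):

    def alloc_extend_reference(pre_lens, seq_lens, last_locs, free_pages, page_size):
        out, cursor = [], 0  # token slots and number of free pages consumed
        for pre_len, seq_len, last_loc in zip(pre_lens, seq_lens, last_locs):
            # Part 1: fill the old partial page right after last_loc.
            part1_end = min(seq_len, -(-pre_len // page_size) * page_size)
            out.extend(range(last_loc + 1, last_loc + 1 + part1_end - pre_len))
            # Part 2: claim whole new pages.
            pos = part1_end
            while pos + page_size <= seq_len:
                page = free_pages[cursor]; cursor += 1
                out.extend(range(page * page_size, (page + 1) * page_size))
                pos += page_size
            # Part 3: open one trailing partial page if needed.
            if pos < seq_len:
                page = free_pages[cursor]; cursor += 1
                out.extend(range(page * page_size, page * page_size + seq_len - pos))
        return out, cursor

    # One request, 4-token pages, prefix 5 -> 13, last prefix token in slot 12:
    slots, used = alloc_extend_reference([5], [13], [12], [40, 41], 4)
    assert slots == [13, 14, 15] + list(range(160, 164)) + [164] and used == 2

The Triton kernel computes the same three parts in parallel across the batch and packs (pages consumed, tokens written) into one int64 return value.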
diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py
index d46ec4277..951f4d869 100644
--- a/python/sglang/srt/mem_cache/radix_cache.py
+++ b/python/sglang/srt/mem_cache/radix_cache.py
@@ -22,7 +22,8 @@ The radix tree data structure for managing the KV cache.
 import heapq
 import time
 from collections import defaultdict
-from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
+from functools import partial
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import torch
 
@@ -67,7 +68,7 @@ class TreeNode:
         return self.last_access_time < other.last_access_time
 
 
-def _key_match(key0: List, key1: List):
+def _key_match_page_size1(key0: List, key1: List):
     i = 0
     for k0, k1 in zip(key0, key1):
         if k0 != k1:
@@ -76,16 +77,42 @@ def _key_match(key0: List, key1: List):
     return i
 
 
+def _key_match_paged(key0: List, key1: List, page_size: int):
+    min_len = min(len(key0), len(key1))
+
+    i = 0
+    while i < min_len:
+        if key0[i : i + page_size] != key1[i : i + page_size]:
+            break
+        i += page_size
+
+    return i
+
+
 class RadixCache(BasePrefixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        page_size: int,
         disable: bool = False,
     ):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.page_size = page_size
         self.disable = disable
+
+        if self.token_to_kv_pool_allocator:
+            self.device = self.token_to_kv_pool_allocator.device
+        else:
+            self.device = torch.device("cpu")
+
+        if self.page_size == 1:
+            self.key_match_fn = _key_match_page_size1
+            self.get_child_key_fn = lambda key: key[0]
+        else:
+            self.key_match_fn = partial(_key_match_paged, page_size=page_size)
+            self.get_child_key_fn = lambda key: tuple(key[:page_size])
         self.reset()
 
     ##### Public API #####
@@ -109,14 +136,25 @@ class RadixCache(BasePrefixCache):
             The last node creates a new child if the prefix is shorter than the last node's value.
         """
-        if self.disable:
-            return [], self.root_node
+        if self.disable or len(key) == 0:
+            return (
+                torch.empty(
+                    (0,),
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                self.root_node,
+            )
+
+        if self.page_size != 1:
+            page_aligned_len = len(key) // self.page_size * self.page_size
+            key = key[:page_aligned_len]
 
         value, last_node = self._match_prefix_helper(self.root_node, key)
         if value:
             value = torch.concat(value)
         else:
-            value = torch.tensor([], dtype=torch.int32)
+            value = torch.empty((0,), dtype=torch.int32, device=self.device)
         return value, last_node
 
     def insert(self, key: List, value=None):
@@ -127,29 +165,33 @@ class RadixCache(BasePrefixCache):
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)
 
-    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
+    def cache_finished_req(self, req: Req):
         """Cache request when it finishes."""
         if self.disable:
-            if token_ids is None:
-                token_ids_len = len(req.origin_input_ids) + len(req.output_ids) - 1
-            else:
-                token_ids_len = len(token_ids)
-
             kv_indices = self.req_to_token_pool.req_to_token[
-                req.req_pool_idx, :token_ids_len
+                req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
             ]
             self.token_to_kv_pool_allocator.free(kv_indices)
             self.req_to_token_pool.free(req.req_pool_idx)
             return
 
-        if token_ids is None:
-            token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(token_ids)
         ]
 
+        if self.page_size != 1:
+            page_aligned_len = len(kv_indices) // self.page_size * self.page_size
+            page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
+            self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
+        else:
+            page_aligned_len = len(kv_indices)
+            page_aligned_kv_indices = kv_indices.clone()
+
         # Radix Cache takes one ref in memory pool
-        new_prefix_len = self.insert(token_ids, kv_indices.clone())
+        new_prefix_len = self.insert(
+            token_ids[:page_aligned_len], page_aligned_kv_indices
+        )
         self.token_to_kv_pool_allocator.free(
             kv_indices[len(req.prefix_indices) : new_prefix_len]
         )
slice(len(req.prefix_indices), len(new_indices))), new_indices[len(req.prefix_indices) :], @@ -186,7 +233,14 @@ class RadixCache(BasePrefixCache): self.dec_lock_ref(req.last_node) self.inc_lock_ref(new_last_node) - req.prefix_indices = new_indices + + # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later + if self.page_size != 1: + req.prefix_indices = torch.cat( + [new_indices, kv_indices[len(new_indices) :]] + ) + else: + req.prefix_indices = new_indices req.last_node = new_last_node def pretty_print(self): @@ -196,7 +250,7 @@ class RadixCache(BasePrefixCache): def total_size(self): return self._total_size_helper() - def evict(self, num_tokens: int, evict_callback: Callable): + def evict(self, num_tokens: int): if self.disable: return @@ -212,7 +266,7 @@ class RadixCache(BasePrefixCache): if x.lock_ref > 0: continue - evict_callback(x.value) + self.token_to_kv_pool_allocator.free(x.value) num_evicted += len(x.value) self._delete_leaf(x) @@ -254,15 +308,29 @@ class RadixCache(BasePrefixCache): # protected size refers to the size of the cache that is locked return self.protected_size_ + def all_values_flatten(self): + values = [] + + def _dfs_helper(node: TreeNode): + for _, child in node.children.items(): + values.append(child.value) + _dfs_helper(child) + + _dfs_helper(self.root_node) + return torch.concat(values) + ##### Internal Helper Functions ##### def _match_prefix_helper(self, node: TreeNode, key: List): node.last_access_time = time.time() + + child_key = self.get_child_key_fn(key) + value = [] - while len(key) > 0 and key[0] in node.children.keys(): - child = node.children[key[0]] + while len(key) > 0 and child_key in node.children.keys(): + child = node.children[child_key] child.last_access_time = time.time() - prefix_len = _key_match(child.key, key) + prefix_len = self.key_match_fn(child.key, key) if prefix_len < len(child.key): new_node = self._split_node(child.key, child, prefix_len) value.append(new_node.value) @@ -272,12 +340,16 @@ class RadixCache(BasePrefixCache): value.append(child.value) node = child key = key[prefix_len:] + + if len(key): + child_key = self.get_child_key_fn(key) + return value, node def _split_node(self, key, child: TreeNode, split_len: int): # new_node -> child new_node = TreeNode() - new_node.children = {key[split_len]: child} + new_node.children = {self.get_child_key_fn(key[split_len:]): child} new_node.parent = child.parent new_node.lock_ref = child.lock_ref new_node.key = child.key[:split_len] @@ -285,7 +357,7 @@ class RadixCache(BasePrefixCache): child.parent = new_node child.key = child.key[split_len:] child.value = child.value[split_len:] - new_node.parent.children[key[0]] = new_node + new_node.parent.children[self.get_child_key_fn(key)] = new_node return new_node def _insert_helper(self, node: TreeNode, key: List, value): @@ -293,11 +365,13 @@ class RadixCache(BasePrefixCache): if len(key) == 0: return 0 + child_key = self.get_child_key_fn(key) + total_prefix_length = 0 - while len(key) > 0 and key[0] in node.children.keys(): - node = node.children[key[0]] + while len(key) > 0 and child_key in node.children.keys(): + node = node.children[child_key] node.last_access_time = time.time() - prefix_len = _key_match(node.key, key) + prefix_len = self.key_match_fn(node.key, key) total_prefix_length += prefix_len key = key[prefix_len:] value = value[prefix_len:] @@ -306,12 +380,15 @@ class RadixCache(BasePrefixCache): new_node = self._split_node(node.key, node, prefix_len) node = new_node + if len(key): + child_key = 
self.get_child_key_fn(key) + if len(key): new_node = TreeNode() new_node.parent = node new_node.key = key new_node.value = value - node.children[key[0]] = new_node + node.children[child_key] = new_node self.evictable_size_ += len(value) return total_prefix_length @@ -326,9 +403,13 @@ class RadixCache(BasePrefixCache): current_node.key[:10], f"r={current_node.lock_ref}", ) - for _, child in current_node.children.items(): + for key, child in current_node.children.items(): stack.append((child, current_indent + 2)) + assert key == self.get_child_key_fn( + child.key + ), f"{key=}, {self.get_child_key_fn(child.key)=}" + def _delete_leaf(self, node): for k, v in node.parent.children.items(): if v == node: @@ -363,7 +444,7 @@ class RadixCache(BasePrefixCache): if __name__ == "__main__": - tree = RadixCache(None, None, False) + tree = RadixCache(None, None, page_size=1, disable=False) tree.insert("Hello") tree.insert("Hello") diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 5ae558056..99f54b7f9 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -264,11 +264,15 @@ class CudaGraphRunner: def model_capture_mode(self): if hasattr(self.model_runner.model, "capture_mode"): self.model_runner.model.capture_mode = True + if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"): + self.model_runner.token_to_kv_pool.capture_mode = True yield if hasattr(self.model_runner.model, "capture_mode"): self.model_runner.model.capture_mode = False + if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"): + self.model_runner.token_to_kv_pool.capture_mode = False def can_run(self, forward_batch: ForwardBatch): if self.enable_dp_attention: diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 70b8c6f46..11d90882b 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -38,12 +38,12 @@ import triton import triton.language as tl from sglang.srt.layers.rotary_embedding import MRotaryEmbedding -from sglang.srt.utils import get_compiler_backend +from sglang.srt.utils import get_compiler_backend, next_power_of_2 if TYPE_CHECKING: from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch - from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool + from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput @@ -51,9 +51,8 @@ if TYPE_CHECKING: class ForwardMode(IntEnum): - # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case. - PREFILL = auto() # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt). + # It is also called "prefill" in common terminology. EXTEND = auto() # Decode one token. 
DECODE = auto() @@ -153,6 +152,12 @@ class ForwardBatch: top_logprobs_nums: Optional[List[int]] = None token_ids_logprobs: Optional[List[List[int]]] = None + # For logits and logprobs post processing + temp_scaled_logprobs: bool = False + temperature: torch.Tensor = None + top_p_normalized_logprobs: bool = False + top_p: torch.Tensor = None + # Position information positions: torch.Tensor = None @@ -189,7 +194,7 @@ class ForwardBatch: # Attention backend req_to_token_pool: ReqToTokenPool = None - token_to_kv_pool: BaseTokenToKVPool = None + token_to_kv_pool: KVCache = None attn_backend: AttentionBackend = None # For DP attention @@ -229,7 +234,6 @@ class ForwardBatch: extend_input_logprob_token_ids_gpu = ( batch.extend_input_logprob_token_ids.to(device, non_blocking=True) ) - ret = cls( forward_mode=batch.forward_mode, batch_size=len(batch.seq_lens), @@ -417,8 +421,8 @@ def compute_position_kernel( prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0 seq_len = tl.load(extend_seq_lens + pid) - # TODO: optimize this? - cumsum_start = 0 + # NOTE: This can be slow for large bs + cumsum_start = tl.cast(0, tl.int64) for i in range(pid): cumsum_start += tl.load(extend_seq_lens + i) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 916d595ed..1b447b2b8 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -53,6 +53,7 @@ from sglang.srt.mem_cache.memory_pool import ( ReqToTokenPool, TokenToKVPoolAllocator, ) +from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader import get_model @@ -430,7 +431,7 @@ class ModelRunner: self.model_config.model_path = model_path load_config = LoadConfig(load_format=load_format) - # Only support the DefaultModelLoader for now + # Only support DefaultModelLoader for now loader = get_model_loader(load_config) if not isinstance(loader, DefaultModelLoader): message = f"Failed to get model loader: {loader}." 
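The heart of the radix-cache change above is that prefix matching and child lookup both become page-granular when page_size > 1: `_key_match_paged` only credits pages that match in full, and a node's children are keyed by the tuple of the first page's token ids rather than by a single token. A minimal standalone sketch of that behavior (the driver values below are illustrative only, not taken from this diff):

    from functools import partial

    def _key_match_paged(key0, key1, page_size):
        # Compare one page at a time; a partially matching page does not count.
        min_len = min(len(key0), len(key1))
        i = 0
        while i < min_len:
            if key0[i : i + page_size] != key1[i : i + page_size]:
                break
            i += page_size
        return i

    page_size = 4
    key_match_fn = partial(_key_match_paged, page_size=page_size)
    get_child_key_fn = lambda key: tuple(key[:page_size])

    a = [1, 2, 3, 4, 5, 6, 7, 8]
    b = [1, 2, 3, 4, 5, 6, 9, 9]
    assert key_match_fn(a, b) == 4              # first page matches; second diverges mid-page
    assert get_child_key_fn(a) == (1, 2, 3, 4)  # child-dict key is the whole first page

With page_size == 1 the old behavior is recovered: `_key_match_page_size1` compares token by token and the child key degenerates to `key[0]`.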
@@ -732,6 +733,7 @@ class ModelRunner: ): self.token_to_kv_pool = MLATokenToKVPool( self.max_total_num_tokens, + page_size=self.page_size, dtype=self.kv_cache_dtype, kv_lora_rank=self.model_config.kv_lora_rank, qk_rope_head_dim=self.model_config.qk_rope_head_dim, @@ -742,6 +744,7 @@ class ModelRunner: elif self.server_args.enable_double_sparsity: self.token_to_kv_pool = DoubleSparseTokenToKVPool( self.max_total_num_tokens, + page_size=self.page_size, dtype=self.kv_cache_dtype, head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()), head_dim=self.model_config.head_dim, @@ -753,6 +756,7 @@ class ModelRunner: else: self.token_to_kv_pool = MHATokenToKVPool( self.max_total_num_tokens, + page_size=self.page_size, dtype=self.kv_cache_dtype, head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()), head_dim=self.model_config.head_dim, @@ -762,12 +766,21 @@ class ModelRunner: ) if self.token_to_kv_pool_allocator is None: - self.token_to_kv_pool_allocator = TokenToKVPoolAllocator( - self.max_total_num_tokens, - dtype=self.kv_cache_dtype, - device=self.device, - kvcache=self.token_to_kv_pool, - ) + if self.page_size == 1: + self.token_to_kv_pool_allocator = TokenToKVPoolAllocator( + self.max_total_num_tokens, + dtype=self.kv_cache_dtype, + device=self.device, + kvcache=self.token_to_kv_pool, + ) + else: + self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator( + self.max_total_num_tokens, + page_size=self.page_size, + dtype=self.kv_cache_dtype, + device=self.device, + kvcache=self.token_to_kv_pool, + ) else: assert self.is_draft_worker diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 13e1a5cd0..dc927f096 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -220,6 +220,8 @@ class ServerArgs: else: self.chunked_prefill_size = 8192 + assert self.chunked_prefill_size % self.page_size == 0 + # Set cuda graph max batch size if self.cuda_graph_max_bs is None: # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues. diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 54909ac9d..19c06b607 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1554,6 +1554,13 @@ def set_cuda_arch(): os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}" +def next_power_of_2(n: int): + return 1 << (n - 1).bit_length() if n > 0 else 1 + + +setattr(triton, "next_power_of_2", next_power_of_2) + + def add_prefix(name: str, prefix: str) -> str: """Add a weight path prefix to a module name. 
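The `next_power_of_2` helper added to utils.py above (and installed over `triton.next_power_of_2` via `setattr`, presumably so call sites that go through `triton` pick up the same definition) rounds a positive integer up to the nearest power of two and returns 1 for n <= 0. A few sanity-check values under that definition:

    def next_power_of_2(n: int):
        # Smallest power of two >= n; defined as 1 for n <= 0.
        return 1 << (n - 1).bit_length() if n > 0 else 1

    assert [next_power_of_2(n) for n in (0, 1, 3, 8, 9)] == [1, 1, 4, 8, 16]

Note also the new invariant in server_args.py: `chunked_prefill_size % page_size == 0`, which keeps every chunked-prefill step aligned to page boundaries.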
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index d387fd710..60ea570cb 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -45,6 +45,7 @@ suites = {
         TestFile("test_no_overlap_scheduler.py", 262),
         TestFile("test_openai_server.py", 124),
+        TestFile("test_page_size.py", 60),
         TestFile("test_penalty.py", 41),
         TestFile("test_pytorch_sampling_backend.py", 66),
         TestFile("test_radix_attention.py", 167),
         TestFile("test_reasoning_content.py", 89),
diff --git a/test/srt/test_dp_attention.py b/test/srt/test_dp_attention.py
index 31c9cc71b..5de03a461 100644
--- a/test/srt/test_dp_attention.py
+++ b/test/srt/test_dp_attention.py
@@ -42,7 +42,8 @@ class TestDPAttention(unittest.TestCase):
         )
 
         metrics = run_eval(args)
-        assert metrics["score"] >= 0.5
+        print(f"{metrics=}")
+        self.assertGreater(metrics["score"], 0.5)
 
     def test_mgsm_en(self):
         args = SimpleNamespace(
@@ -54,7 +55,8 @@ class TestDPAttention(unittest.TestCase):
         )
 
         metrics = run_eval(args)
-        assert metrics["score"] >= 0.8
+        print(f"{metrics=}")
+        self.assertGreater(metrics["score"], 0.8)
 
 
 if __name__ == "__main__":
diff --git a/test/srt/test_gptqmodel_dynamic.py b/test/srt/test_gptqmodel_dynamic.py
index c9145fe6f..f0ee63e6b 100644
--- a/test/srt/test_gptqmodel_dynamic.py
+++ b/test/srt/test_gptqmodel_dynamic.py
@@ -184,6 +184,7 @@ class TestGPTQModelDynamicWithMarlin(unittest.TestCase):
                 "text": "The capital of France is",
                 "sampling_params": {
                     "max_new_tokens": max_new_tokens,
+                    "temperature": 0.001,
                 },
             },
         )
diff --git a/test/srt/test_mla_deepseek_v3.py b/test/srt/test_mla_deepseek_v3.py
index ba43c2ba1..bb304ed29 100644
--- a/test/srt/test_mla_deepseek_v3.py
+++ b/test/srt/test_mla_deepseek_v3.py
@@ -13,7 +13,7 @@ from sglang.test.test_utils import (
 )
 
 
-class TestDeepseekV3(unittest.TestCase):
+class TestMLADeepseekV3(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = "lmsys/sglang-ci-dsv3-test"
diff --git a/test/srt/test_page_size.py b/test/srt/test_page_size.py
new file mode 100644
index 000000000..0fabfbfa4
--- /dev/null
+++ b/test/srt/test_page_size.py
@@ -0,0 +1,46 @@
+import os
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestPageSize(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        os.environ["SGLANG_DEBUG_MEMORY_POOL"] = "1"
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=["--page-size", 4, "--chunked-prefill-size", 128],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        metrics = run_eval(args)
+        self.assertGreaterEqual(metrics["score"], 0.65)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/test_retract_decode.py b/test/srt/test_retract_decode.py
index 5f169cdb6..3ca8620be 100644
--- a/test/srt/test_retract_decode.py
+++ b/test/srt/test_retract_decode.py
@@ -1,3 +1,4 @@
+import os
 import unittest
 from types import SimpleNamespace
 
@@ -14,6 +15,8 @@ from sglang.test.test_utils import (
 class TestRetractDecode(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
+        os.environ["SGLANG_TEST_RETRACT"] = "1"
+
         cls.model = DEFAULT_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
@@ -37,5 +40,20 @@ class TestRetractDecode(unittest.TestCase):
         self.assertGreaterEqual(metrics["score"], 0.65)
 
 
+class TestRetractDecodeChunkCache(TestRetractDecode):
+    @classmethod
+    def setUpClass(cls):
+        os.environ["SGLANG_TEST_RETRACT"] = "1"
+
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=["--disable-radix-cache", "--chunked-prefill-size", 128],
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
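Taken together, the page-alignment idiom that recurs through this diff (`match_prefix`, `cache_finished_req`, `cache_unfinished_req`) is plain integer truncation: only whole pages are inserted into the radix tree, and the unaligned tail of KV slots is freed straight back to the allocator. A minimal illustration, with made-up index values standing in for real KV cache locations:

    import torch

    page_size = 4
    kv_indices = torch.arange(10, dtype=torch.int64)  # two full pages plus two leftover slots

    # Truncate to a whole number of pages, as cache_finished_req does.
    page_aligned_len = len(kv_indices) // page_size * page_size
    page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()  # enters the tree
    unaligned_tail = kv_indices[page_aligned_len:]                   # freed, never cached

    assert page_aligned_len == 8
    assert unaligned_tail.tolist() == [8, 9]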