From 9379da77de43bf410eeeff73362f6c425e56a6f8 Mon Sep 17 00:00:00 2001 From: Hanming Lu <69857889+hanming-lu@users.noreply.github.com> Date: Sun, 13 Jul 2025 12:31:07 -0700 Subject: [PATCH] SWA Prefix Cache (#7367) Co-authored-by: Ying Sheng --- python/sglang/srt/configs/model_config.py | 5 +- python/sglang/srt/disaggregation/decode.py | 10 +- .../layers/attention/flashinfer_backend.py | 30 + python/sglang/srt/managers/schedule_batch.py | 158 ++- python/sglang/srt/managers/schedule_policy.py | 97 +- python/sglang/srt/managers/scheduler.py | 228 +++- python/sglang/srt/managers/tp_worker.py | 14 + .../srt/managers/tp_worker_overlap_thread.py | 11 + python/sglang/srt/mem_cache/allocator.py | 17 +- .../sglang/srt/mem_cache/base_prefix_cache.py | 16 +- python/sglang/srt/mem_cache/chunk_cache.py | 7 +- .../sglang/srt/mem_cache/swa_radix_cache.py | 1025 +++++++++++++++++ .../sglang/srt/model_executor/model_runner.py | 77 +- python/sglang/srt/models/gemma2.py | 1 + python/sglang/srt/server_args.py | 28 +- test/srt/test_swa_unittest.py | 176 +++ 16 files changed, 1742 insertions(+), 158 deletions(-) create mode 100644 python/sglang/srt/mem_cache/swa_radix_cache.py create mode 100644 test/srt/test_swa_unittest.py diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index ecff10244..1a62178b9 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -711,7 +711,6 @@ def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int) i for i in range(num_hidden_layers) if (i + 1) % 4 == 0 ] else: - raise ValueError( - "get_hybrid_layer_ids is only implemented for Llama4ForConditionalGeneration" - ) + swa_attention_layer_ids = None + full_attention_layer_ids = None return swa_attention_layer_ids, full_attention_layer_ids diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index a31afabe5..ddc405c48 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -439,7 +439,15 @@ class DecodePreallocQueue: else 0 ) - allocatable_tokens = self.token_to_kv_pool_allocator.available_size() - max( + if self.scheduler.model_config.is_hybrid: + available_size = min( + self.token_to_kv_pool_allocator.full_available_size(), + self.token_to_kv_pool_allocator.swa_available_size(), + ) + else: + available_size = self.token_to_kv_pool_allocator.available_size() + + allocatable_tokens = available_size - max( # preserve some space for future decode self.num_reserved_decode_tokens * ( diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 7d62f7821..f65e533d9 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -26,6 +26,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.utils import is_sm100_supported +from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput from sglang.srt.utils import is_flashinfer_available, next_power_of_2 @@ -589,6 +590,7 @@ class FlashInferIndicesUpdaterDecode: 
self.kv_indptr = attn_backend.kv_indptr self.kv_last_page_len = attn_backend.kv_last_page_len self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator # Dispatch the update function if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: @@ -655,6 +657,10 @@ class FlashInferIndicesUpdaterDecode: paged_kernel_lens_sum_tmp = seq_lens_sum kv_start_idx_tmp = None + use_sliding_window_kv_pool = wrapper_id == 0 and isinstance( + self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator + ) + self.call_begin_forward( decode_wrappers[wrapper_id], req_pool_indices, @@ -663,6 +669,7 @@ class FlashInferIndicesUpdaterDecode: self.kv_indptr[wrapper_id], kv_start_idx_tmp, spec_info, + use_sliding_window_kv_pool=use_sliding_window_kv_pool, ) def update_cross_attention( @@ -704,6 +711,7 @@ class FlashInferIndicesUpdaterDecode: kv_indptr: torch.Tensor, kv_start_idx: torch.Tensor, spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + use_sliding_window_kv_pool: bool = False, ): if spec_info is None: bs = len(req_pool_indices) @@ -731,6 +739,14 @@ class FlashInferIndicesUpdaterDecode: kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices bs = kv_indptr.shape[0] - 1 + if use_sliding_window_kv_pool: + kv_last_index = kv_indptr[-1] + kv_indices[:kv_last_index] = ( + self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa( + kv_indices[:kv_last_index] + ) + ) + wrapper.begin_forward( kv_indptr, kv_indices, @@ -765,6 +781,7 @@ class FlashInferIndicesUpdaterPrefill: self.kv_last_page_len = attn_backend.kv_last_page_len self.qo_indptr = attn_backend.qo_indptr self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged # Dispatch the update function @@ -848,6 +865,9 @@ class FlashInferIndicesUpdaterPrefill: paged_kernel_lens_sum = seq_lens_sum kv_start_idx = seq_lens - paged_kernel_lens + use_sliding_window_kv_pool = wrapper_id == 0 and isinstance( + self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator + ) self.call_begin_forward( self.prefill_wrapper_ragged, @@ -862,6 +882,7 @@ class FlashInferIndicesUpdaterPrefill: self.qo_indptr[wrapper_id], use_ragged, spec_info, + use_sliding_window_kv_pool=use_sliding_window_kv_pool, ) def update_cross_attention( @@ -916,6 +937,7 @@ class FlashInferIndicesUpdaterPrefill: qo_indptr: torch.Tensor, use_ragged: bool, spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + use_sliding_window_kv_pool: bool = False, ): bs = len(seq_lens) if spec_info is None: @@ -964,6 +986,14 @@ class FlashInferIndicesUpdaterPrefill: q_data_type=self.q_data_type, ) + if use_sliding_window_kv_pool: + kv_last_index = kv_indptr[-1] + kv_indices[:kv_last_index] = ( + self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa( + kv_indices[:kv_last_index] + ) + ) + # cached part wrapper_paged.begin_forward( qo_indptr, diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 3d3c177ca..1a48b0553 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -52,10 +52,14 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import ( ScheduleBatchDisaggregationDecodeMixin, ) from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank -from sglang.srt.mem_cache.allocator import 
BaseTokenToKVPoolAllocator +from sglang.srt.mem_cache.allocator import ( + BaseTokenToKVPoolAllocator, + SWATokenToKVPoolAllocator, +) from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.memory_pool import ReqToTokenPool +from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.metrics.collector import TimeStats from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo @@ -527,6 +531,8 @@ class Req: self.last_node: Any = None self.last_host_node: Any = None self.host_hit_length = 0 + # The node to lock until for swa radix tree lock ref + self.swa_uuid_for_lock: Optional[int] = None # Whether or not if it is chunked. It increments whenever # it is chunked, and decrement whenever chunked request is @@ -745,6 +751,7 @@ class Req: def reset_for_retract(self): self.prefix_indices = [] self.last_node = None + self.swa_uuid_for_lock = None self.extend_input_len = 0 self.is_retracted = True self.input_token_logprobs = None @@ -813,6 +820,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): req_to_token_pool: ReqToTokenPool = None token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None tree_cache: BasePrefixCache = None + is_hybrid: bool = False # Batch configs model_config: ModelConfig = None @@ -918,11 +926,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ): return_logprob = any(req.return_logprob for req in reqs) + is_hybrid = False + if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator): + assert isinstance(tree_cache, SWARadixCache) or isinstance( + tree_cache, SWAChunkCache + ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator" + is_hybrid = True + return cls( reqs=reqs, req_to_token_pool=req_to_token_pool, token_to_kv_pool_allocator=token_to_kv_pool_allocator, tree_cache=tree_cache, + is_hybrid=is_hybrid, model_config=model_config, enable_overlap=enable_overlap, return_logprob=return_logprob, @@ -953,9 +969,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): return req_pool_indices def alloc_token_slots(self, num_tokens: int, backup_state: bool = False): - if self.token_to_kv_pool_allocator.available_size() < num_tokens: - if self.tree_cache is not None: - self.tree_cache.evict(num_tokens) + self._evict_tree_cache_if_needed(num_tokens) if backup_state: state = self.token_to_kv_pool_allocator.backup_state() @@ -966,7 +980,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): error_msg = ( f"{phase_str} out of memory. 
Try to lower your batch size.\n" f"Try to allocate {num_tokens} tokens.\n" - f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n" + f"{self._available_and_evictable_str()}" ) logger.error(error_msg) if self.tree_cache is not None: @@ -986,16 +1000,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): extend_num_tokens: int, backup_state: bool = False, ): - if ( - self.token_to_kv_pool_allocator.available_size() - < extend_num_tokens + num_tokens = ( + extend_num_tokens + len(seq_lens) * self.token_to_kv_pool_allocator.page_size - ): - if self.tree_cache is not None: - self.tree_cache.evict( - extend_num_tokens - + len(seq_lens) * self.token_to_kv_pool_allocator.page_size, - ) + ) + self._evict_tree_cache_if_needed(num_tokens) if backup_state: state = self.token_to_kv_pool_allocator.backup_state() @@ -1007,9 +1016,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): error_msg = ( f"Prefill out of memory. Try to lower your batch size.\n" f"Try to allocate {extend_num_tokens} tokens.\n" - f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n" - f"{self.token_to_kv_pool_allocator.available_size()=}\n" - f"{self.tree_cache.evictable_size()=}\n" + f"{self._available_and_evictable_str()}" ) logger.error(error_msg) raise RuntimeError(error_msg) @@ -1025,14 +1032,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): last_loc: torch.Tensor, backup_state: bool = False, ): - if self.tree_cache is not None: - if ( - self.token_to_kv_pool_allocator.available_size() - < len(seq_lens) * self.token_to_kv_pool_allocator.page_size - ): - self.tree_cache.evict( - len(seq_lens) * self.token_to_kv_pool_allocator.page_size, - ) + num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size + + self._evict_tree_cache_if_needed(num_tokens) if backup_state: state = self.token_to_kv_pool_allocator.backup_state() @@ -1042,9 +1044,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): error_msg = ( f"Decode out of memory. 
Try to lower your batch size.\n" f"Try to allocate {len(seq_lens)} tokens.\n" - f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n" - f"{self.token_to_kv_pool_allocator.available_size()=}\n" - f"{self.tree_cache.evictable_size()=}\n" + f"{self._available_and_evictable_str()}" ) logger.error(error_msg) raise RuntimeError(error_msg) @@ -1181,7 +1181,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices ) if isinstance(self.tree_cache, SWAChunkCache): - self.tree_cache.evict( + self.tree_cache.evict_swa( req, pre_len, self.model_config.attention_chunk_size ) @@ -1371,17 +1371,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ) def check_decode_mem(self, buf_multiplier=1): - tokens_required = ( + num_tokens = ( self.new_page_count_next_decode() * buf_multiplier * self.token_to_kv_pool_allocator.page_size ) - if self.token_to_kv_pool_allocator.available_size() >= tokens_required: - return True - self.tree_cache.evict(tokens_required) - - return self.token_to_kv_pool_allocator.available_size() >= tokens_required + self._evict_tree_cache_if_needed(num_tokens) + return self._is_available_size_sufficient(num_tokens) def retract_decode(self, server_args: ServerArgs): """Retract the decoding requests when there is not enough memory.""" @@ -1414,19 +1411,38 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode ) + def _get_available_size(): + if self.is_hybrid: + return min( + self.token_to_kv_pool_allocator.full_available_size(), + self.token_to_kv_pool_allocator.swa_available_size(), + ) + else: + return self.token_to_kv_pool_allocator.available_size() + retracted_reqs = [] seq_lens_cpu = self.seq_lens.cpu().numpy() first_iter = True while ( - self.token_to_kv_pool_allocator.available_size() - < get_required_tokens(len(sorted_indices)) + _get_available_size() < get_required_tokens(len(sorted_indices)) or first_iter ): if len(sorted_indices) == 1: # Corner case: only one request left - assert ( - self.token_to_kv_pool_allocator.available_size() > 0 - ), "No space left for only one request" + if self.is_hybrid: + full_available_size = ( + self.token_to_kv_pool_allocator.full_available_size() + ) + swa_available_size = ( + self.token_to_kv_pool_allocator.swa_available_size() + ) + assert ( + full_available_size > 0 and swa_available_size > 0 + ), f"No space left for only one request in SWA mode {full_available_size=}, {swa_available_size=}" + else: + assert ( + self.token_to_kv_pool_allocator.available_size() > 0 + ), f"No space left for only one request, {self.token_to_kv_pool_allocator.available_size()=}" break first_iter = False @@ -1458,15 +1474,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.req_to_token_pool.free(req.req_pool_idx) # release the last node - self.tree_cache.dec_lock_ref(req.last_node) + if self.is_hybrid: + self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock) + else: + self.tree_cache.dec_lock_ref(req.last_node) # NOTE(lsyin): we should use the newly evictable memory instantly. 
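
In the hybrid branch above, releasing the last node takes two arguments because locks exist at two granularities: `inc_lock_ref` pins full-attention KV for the whole path back to the root, but pins SWA KV only for roughly the last `sliding_window_size` tokens of that path, and it returns a `swa_uuid_for_lock` recording where the SWA locks stop. The matching `dec_lock_ref` needs that uuid to unwind exactly the same span. A minimal sketch of the pairing, where `tree_cache` and `req` stand in for an `SWARadixCache` and a scheduled request:

    # Sketch only: lock on admission, release on retract/finish.
    req.swa_uuid_for_lock = tree_cache.inc_lock_ref(req.last_node)
    # ... request runs, possibly for many decode steps ...
    # Passing the uuid back makes the SWA unlock stop at the same node
    # where inc_lock_ref stopped locking.
    tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
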
- residual_size = ( - len(sorted_indices) * global_config.retract_decode_steps - - self.token_to_kv_pool_allocator.available_size() - ) - residual_size = max(0, residual_size) - self.tree_cache.evict(residual_size) + num_tokens = len(sorted_indices) * global_config.retract_decode_steps + self._evict_tree_cache_if_needed(num_tokens) req.reset_for_retract() @@ -1559,7 +1574,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # free memory if isinstance(self.tree_cache, SWAChunkCache): for req in self.reqs: - self.tree_cache.evict( + self.tree_cache.evict_swa( req, req.seqlen - 1, self.model_config.attention_chunk_size ) @@ -1778,6 +1793,53 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): is_extend_in_batch=self.is_extend_in_batch, ) + def _evict_tree_cache_if_needed( + self, + num_tokens: int, + ) -> None: + if isinstance(self.tree_cache, SWAChunkCache): + return + + if self.is_hybrid: + full_available_size = self.token_to_kv_pool_allocator.full_available_size() + swa_available_size = self.token_to_kv_pool_allocator.swa_available_size() + + if full_available_size < num_tokens or swa_available_size < num_tokens: + if self.tree_cache is not None: + full_num_tokens = max(0, num_tokens - full_available_size) + swa_num_tokens = max(0, num_tokens - swa_available_size) + self.tree_cache.evict(full_num_tokens, swa_num_tokens) + else: + if self.token_to_kv_pool_allocator.available_size() < num_tokens: + if self.tree_cache is not None: + self.tree_cache.evict(num_tokens) + + def _is_available_size_sufficient(self, num_tokens: int) -> bool: + if self.is_hybrid: + return ( + self.token_to_kv_pool_allocator.full_available_size() >= num_tokens + and self.token_to_kv_pool_allocator.swa_available_size() >= num_tokens + ) + else: + return self.token_to_kv_pool_allocator.available_size() >= num_tokens + + def _available_and_evictable_str(self) -> str: + if self.is_hybrid: + full_available_size = self.token_to_kv_pool_allocator.full_available_size() + swa_available_size = self.token_to_kv_pool_allocator.swa_available_size() + full_evictable_size = self.tree_cache.full_evictable_size() + swa_evictable_size = self.tree_cache.swa_evictable_size() + return ( + f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n" + f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n" + f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n" + f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n" + ) + else: + available_size = self.token_to_kv_pool_allocator.available_size() + evictable_size = self.tree_cache.evictable_size() + return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n" + def __str__(self): return ( f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, " diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index ba3dd8d4e..c07df2150 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union import torch from sglang.srt.managers.schedule_batch import Req, ScheduleBatch +from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from 
sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode @@ -311,21 +312,43 @@ class PrefillAdder: ] ) - @property - def rem_total_tokens(self): - return ( - self.token_to_kv_pool_allocator.available_size() - + self.tree_cache.evictable_size() - - self.rem_total_token_offset + self.is_hybrid = isinstance( + self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator ) + @property + def rem_total_tokens(self): + if self.is_hybrid: + available_and_evictable = min( + self.token_to_kv_pool_allocator.full_available_size() + + self.tree_cache.full_evictable_size(), + self.token_to_kv_pool_allocator.swa_available_size() + + self.tree_cache.swa_evictable_size(), + ) + else: + available_and_evictable = ( + self.token_to_kv_pool_allocator.available_size() + + self.tree_cache.evictable_size() + ) + + return available_and_evictable - self.rem_total_token_offset + @property def cur_rem_tokens(self): - return ( - self.token_to_kv_pool_allocator.available_size() - + self.tree_cache.evictable_size() - - self.cur_rem_token_offset - ) + if self.is_hybrid: + available_and_evictable = min( + self.token_to_kv_pool_allocator.full_available_size() + + self.tree_cache.full_evictable_size(), + self.token_to_kv_pool_allocator.swa_available_size() + + self.tree_cache.swa_evictable_size(), + ) + else: + available_and_evictable = ( + self.token_to_kv_pool_allocator.available_size() + + self.tree_cache.evictable_size() + ) + + return available_and_evictable - self.cur_rem_token_offset def ceil_paged_tokens(self, tokens: int) -> int: return -(-tokens // self.page_size) * self.page_size @@ -376,11 +399,18 @@ class PrefillAdder: @contextmanager def _lock_node(self, last_node: TreeNode): - try: - self.tree_cache.inc_lock_ref(last_node) - yield None - finally: - self.tree_cache.dec_lock_ref(last_node) + if self.is_hybrid: + try: + swa_uuid_for_lock = self.tree_cache.inc_lock_ref(last_node) + yield None + finally: + self.tree_cache.dec_lock_ref(last_node, swa_uuid_for_lock) + else: + try: + self.tree_cache.inc_lock_ref(last_node) + yield None + finally: + self.tree_cache.dec_lock_ref(last_node) def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool): # Early exit if no enough tokens for the input tokens @@ -422,16 +452,19 @@ class PrefillAdder: else: add_req_state(req, insert_sort=True) - cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids) - tokens_freed = 0 - for i, (tokens_left, tokens_occupied) in enumerate(self.req_states): - # tokens_left gives a reservative calculation as the last token is not stored - bs = len(self.req_states) - i - min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs - # reserve tokens for corner cases - if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs: - return AddReqResult.NO_TOKEN - tokens_freed += tokens_occupied + if not self.is_hybrid: + # Skip this logic for swa. The SWA has different memory management, and + # this mechanism is underestimating the memory usage. 
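
For the hybrid path, the budget enforced instead is `rem_total_tokens` above: every scheduled token needs a slot in both the full-attention pool and the SWA pool, so the budget is the minimum across the two pools rather than a single pool's size. A worked example with made-up pool sizes:

    # Illustrative numbers only, not taken from the patch.
    full_available, full_evictable = 10_000, 2_000  # full-attention pool
    swa_available, swa_evictable = 30_000, 1_000    # sliding-window pool
    # min() picks the binding constraint: the full pool caps the budget
    # at 12_000 tokens even though 31_000 SWA slots are reclaimable.
    rem_total_tokens = min(
        full_available + full_evictable,
        swa_available + swa_evictable,
    )
    assert rem_total_tokens == 12_000
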
+ cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids) + tokens_freed = 0 + for i, (tokens_left, tokens_occupied) in enumerate(self.req_states): + # tokens_left gives a reservative calculation as the last token is not stored + bs = len(self.req_states) - i + min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs + # reserve tokens for corner cases + if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs: + return AddReqResult.NO_TOKEN + tokens_freed += tokens_occupied if ( self.rem_chunk_tokens is None # chunked prefill is disabled @@ -499,7 +532,11 @@ class PrefillAdder: if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens: # Non-chunked prefill self.can_run_list.append(req) - self.tree_cache.inc_lock_ref(req.last_node) + if self.is_hybrid: + swa_uuid_for_lock = self.tree_cache.inc_lock_ref(req.last_node) + req.swa_uuid_for_lock = swa_uuid_for_lock + else: + self.tree_cache.inc_lock_ref(req.last_node) self._update_prefill_budget( prefix_len, input_tokens, @@ -520,7 +557,11 @@ class PrefillAdder: self.can_run_list.append(req) self.new_chunked_req = req - self.tree_cache.inc_lock_ref(req.last_node) + if self.is_hybrid: + swa_uuid_for_lock = self.tree_cache.inc_lock_ref(req.last_node) + req.swa_uuid_for_lock = swa_uuid_for_lock + else: + self.tree_cache.inc_lock_ref(req.last_node) self._update_prefill_budget(prefix_len, trunc_len, 0) return self.budget_state() diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 452a6d5ab..afb4b870d 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -129,10 +129,10 @@ from sglang.srt.managers.session_controller import Session from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient from sglang.srt.managers.utils import validate_input_length -from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.hiradix_cache import HiRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache +from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors from sglang.srt.reasoning_parser import ReasoningParser @@ -390,6 +390,14 @@ class Scheduler( global_server_args_dict.update(worker_global_server_args_dict) set_random_seed(self.random_seed) + # Hybrid + self.is_hybrid = self.tp_worker.is_hybrid + if self.is_hybrid: + self.sliding_window_size = self.tp_worker.sliding_window_size + self.full_tokens_per_layer, self.swa_tokens_per_layer = ( + self.tp_worker.get_tokens_per_layer_info() + ) + # Print debug info if tp_rank == 0: avail_mem = get_available_gpu_memory( @@ -570,7 +578,7 @@ class Scheduler( server_args.chunked_prefill_size is not None and server_args.disable_radix_cache ): - if self.model_config.is_hybrid: + if self.is_hybrid: ChunkCacheClass = SWAChunkCache else: ChunkCacheClass = ChunkCache @@ -603,6 +611,17 @@ class Scheduler( self.tp_worker.register_hicache_layer_transfer_counter( self.tree_cache.cache_controller.layer_done_counter ) + elif self.is_hybrid: + assert ( + self.server_args.disaggregation_mode == "null" + ), "Hybrid mode does not support disaggregation yet" + self.tree_cache = SWARadixCache( + req_to_token_pool=self.req_to_token_pool, + 
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + sliding_window_size=self.sliding_window_size, + page_size=self.page_size, + disable=server_args.disable_radix_cache, + ) else: self.tree_cache = RadixCache( @@ -774,6 +793,7 @@ class Scheduler( else: # When the server is idle, do self-check and re-init some states self.check_memory() + self.check_tree_cache() self.new_token_ratio = self.init_new_token_ratio self.maybe_sleep_on_idle() @@ -819,6 +839,7 @@ class Scheduler( elif batch is None: # When the server is idle, do self-check and re-init some states self.check_memory() + self.check_tree_cache() self.new_token_ratio = self.init_new_token_ratio self.maybe_sleep_on_idle() @@ -955,6 +976,7 @@ class Scheduler( # When the server is idle, self-check and re-init some states if server_is_idle: self.check_memory() + self.check_tree_cache() self.new_token_ratio = self.init_new_token_ratio self.maybe_sleep_on_idle() @@ -1306,9 +1328,26 @@ class Scheduler( self.last_input_throughput = self.last_prefill_tokens / gap_latency self.last_prefill_tokens = adder.log_input_tokens - usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage( - self.tree_cache.evictable_size() - ) + if self.is_hybrid: + ( + full_num_used, + swa_num_used, + full_token_usage, + swa_token_usage, + _, + _, + _, + _, + ) = self._get_swa_token_info() + num_used = max(full_num_used, swa_num_used) + token_usage = max(full_token_usage, swa_token_usage) + token_msg = ( + f"full token usage: {full_token_usage:.2f}, " + f"swa token usage: {swa_token_usage:.2f}, " + ) + else: + num_used, token_usage, _, _ = self._get_token_info() + token_msg = f"token usage: {token_usage:.2f}, " num_new_seq = len(can_run_list) f = ( @@ -1316,7 +1355,7 @@ class Scheduler( f"#new-seq: {num_new_seq}, " f"#new-token: {adder.log_input_tokens}, " f"#cached-token: {adder.log_hit_tokens}, " - f"{usage_msg}" + f"{token_msg}" ) if self.disaggregation_mode == DisaggregationMode.PREFILL: @@ -1338,7 +1377,7 @@ class Scheduler( ) self.stats.num_running_reqs = running_bs self.stats.num_used_tokens = num_used - self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2) + self.stats.token_usage = round(token_usage, 2) self.stats.num_queue_reqs = len(self.waiting_queue) self.stats.cache_hit_rate = cache_hit_rate @@ -1361,16 +1400,35 @@ class Scheduler( self.last_gen_throughput = self.num_generated_tokens / gap_latency self.num_generated_tokens = 0 num_running_reqs = len(batch.reqs) - usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage( - self.tree_cache.evictable_size() - ) + if self.is_hybrid: + ( + full_num_used, + swa_num_used, + full_token_usage, + swa_token_usage, + _, + _, + _, + _, + ) = self._get_swa_token_info() + num_used = max(full_num_used, swa_num_used) + token_usage = max(full_token_usage, swa_token_usage) + token_msg = ( + f"#full token: {full_num_used}, " + f"full token usage: {full_token_usage:.2f}, " + f"#swa token: {swa_num_used}, " + f"swa token usage: {swa_token_usage:.2f}, " + ) + else: + num_used, token_usage, _, _ = self._get_token_info() + token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, " if RECORD_STEP_TIME: self.step_time_dict[num_running_reqs].append( gap_latency / self.server_args.decode_log_interval ) - msg = f"Decode batch. " f"#running-req: {num_running_reqs}, " f"{usage_msg}" + msg = f"Decode batch. 
#running-req: {num_running_reqs}, {token_msg}" if self.spec_algorithm.is_none(): spec_accept_length = 0 @@ -1398,7 +1456,7 @@ class Scheduler( if self.enable_metrics: self.stats.num_running_reqs = num_running_reqs self.stats.num_used_tokens = num_used - self.stats.token_usage = num_used / self.max_total_num_tokens + self.stats.token_usage = round(token_usage, 2) self.stats.cache_hit_rate = 0.0 self.stats.gen_throughput = self.last_gen_throughput self.stats.num_queue_reqs = len(self.waiting_queue) @@ -1409,24 +1467,34 @@ class Scheduler( self._publish_kv_events() def check_memory(self): - if isinstance(self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator): - available_token_size = self.token_to_kv_pool_allocator.full_available_size() - else: - available_token_size = self.token_to_kv_pool_allocator.available_size() - available_size = available_token_size + self.tree_cache.evictable_size() - protected_size = self.tree_cache.protected_size() - memory_leak = available_size != ( - self.max_total_num_tokens - if not self.enable_hierarchical_cache - else self.max_total_num_tokens - protected_size - ) - if memory_leak: - msg = ( - "token_to_kv_pool_allocator memory leak detected! " - f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n" - f"{available_token_size=}\n" - f"{self.tree_cache.evictable_size()=}\n" + if self.is_hybrid: + ( + full_num_used, + swa_num_used, + _, + _, + full_available_size, + full_evictable_size, + swa_available_size, + swa_evictable_size, + ) = self._get_swa_token_info() + memory_leak = full_num_used != 0 or swa_num_used != 0 + token_msg = ( + f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n" + f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n" ) + else: + _, _, available_size, evictable_size = self._get_token_info() + protected_size = self.tree_cache.protected_size() + memory_leak = (available_size + evictable_size) != ( + self.max_total_num_tokens + if not self.enable_hierarchical_cache + else self.max_total_num_tokens - protected_size + ) + token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n" + + if memory_leak: + msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}" raise ValueError(msg) if self.disaggregation_mode == DisaggregationMode.DECODE: @@ -1450,20 +1518,66 @@ class Scheduler( and time.perf_counter() > self.metrics_collector.last_log_time + 30 ): # During idle time, also collect metrics every 30 seconds. 
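
For usage reporting the reduction flips from min to max: a hybrid server is as full as its fuller pool, so the log lines and metrics take the larger of the two per-pool usage ratios returned by `_get_swa_token_info`. A compact sketch with hypothetical counts:

    # Illustrative only; mirrors the max() reduction used for hybrid logging.
    full_used, full_capacity = 9_000, 12_000  # 75% of full-attention pool
    swa_used, swa_capacity = 8_000, 32_000    # 25% of SWA pool
    token_usage = max(full_used / full_capacity, swa_used / swa_capacity)  # 0.75
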
- num_used = self.max_total_num_tokens - ( - self.token_to_kv_pool_allocator.available_size() - + self.tree_cache.evictable_size() - ) + if self.is_hybrid: + ( + full_num_used, + swa_num_used, + full_token_usage, + swa_token_usage, + _, + _, + _, + _, + ) = self._get_swa_token_info() + num_used = max(full_num_used, swa_num_used) + token_usage = max(full_token_usage, swa_token_usage) + else: + num_used, token_usage, _, _ = self._get_token_info() num_running_reqs = len(self.running_batch.reqs) self.stats.num_running_reqs = num_running_reqs self.stats.num_used_tokens = num_used - self.stats.token_usage = num_used / self.max_total_num_tokens + self.stats.token_usage = round(token_usage, 2) self.stats.gen_throughput = 0 self.stats.num_queue_reqs = len(self.waiting_queue) self.stats.num_grammar_queue_reqs = len(self.grammar_queue) self.metrics_collector.log_stats(self.stats) self._publish_kv_events() + def check_tree_cache(self): + if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache): + self.tree_cache.sanity_check() + + def _get_token_info(self): + available_size = self.token_to_kv_pool_allocator.available_size() + evictable_size = self.tree_cache.evictable_size() + num_used = self.max_total_num_tokens - (available_size + evictable_size) + token_usage = num_used / self.max_total_num_tokens + return num_used, token_usage, available_size, evictable_size + + def _get_swa_token_info(self): + full_available_size = self.token_to_kv_pool_allocator.full_available_size() + full_evictable_size = self.tree_cache.full_evictable_size() + swa_available_size = self.token_to_kv_pool_allocator.swa_available_size() + swa_evictable_size = self.tree_cache.swa_evictable_size() + full_num_used = self.full_tokens_per_layer - ( + full_available_size + full_evictable_size + ) + swa_num_used = self.swa_tokens_per_layer - ( + swa_available_size + swa_evictable_size + ) + full_token_usage = full_num_used / self.full_tokens_per_layer + swa_token_usage = swa_num_used / self.swa_tokens_per_layer + return ( + full_num_used, + swa_num_used, + full_token_usage, + swa_token_usage, + full_available_size, + full_evictable_size, + swa_available_size, + swa_evictable_size, + ) + def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: # Merge the prefill batch into the running batch chunked_req_to_exclude = set() @@ -2042,11 +2156,30 @@ class Scheduler( if not disable_request_logging(): # Print batch size and memory pool info to check whether there are de-sync issues. 
+ if self.is_hybrid: + ( + _, + _, + _, + _, + full_available_size, + full_evictable_size, + swa_available_size, + swa_evictable_size, + ) = self._get_swa_token_info() + info_msg = ( + f"{full_available_size=}, " + f"{full_evictable_size=}, " + f"{swa_available_size=}, " + f"{swa_evictable_size=}, " + ) + else: + _, _, available_size, evictable_size = self._get_token_info() + info_msg = f"{available_size=}, " f"{evictable_size=}, " logger.error( f"{self.cur_batch.batch_size()=}, " f"{self.cur_batch.reqs=}, " - f"{self.token_to_kv_pool_allocator.available_size()=}, " - f"{self.tree_cache.evictable_size()=}, " + f"{info_msg}" ) pyspy_dump_schedulers() @@ -2101,11 +2234,24 @@ class Scheduler( def get_load(self): # TODO(lsyin): use dynamically maintained num_waiting_tokens - load = ( - self.max_total_num_tokens - - self.token_to_kv_pool_allocator.available_size() - - self.tree_cache.evictable_size() - ) + if self.is_hybrid: + load_full = ( + self.full_tokens_per_layer + - self.token_to_kv_pool_allocator.full_available_size() + - self.tree_cache.full_evictable_size() + ) + load_swa = ( + self.swa_tokens_per_layer + - self.token_to_kv_pool_allocator.swa_available_size() + - self.tree_cache.swa_evictable_size() + ) + load = max(load_full, load_swa) + else: + load = ( + self.max_total_num_tokens + - self.token_to_kv_pool_allocator.available_size() + - self.tree_cache.evictable_size() + ) load += sum(len(req.origin_input_ids) for req in self.waiting_queue) if self.disaggregation_mode == DisaggregationMode.PREFILL: load += sum( diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index a0a33741d..ff20ea01e 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -174,6 +174,20 @@ class TpModelWorker: self.model_runner.token_to_kv_pool.size, ) + @property + def sliding_window_size(self) -> Optional[int]: + return self.model_runner.sliding_window_size + + @property + def is_hybrid(self) -> bool: + return self.model_runner.is_hybrid is not None + + def get_tokens_per_layer_info(self): + return ( + self.model_runner.full_max_total_num_tokens, + self.model_runner.swa_max_total_num_tokens, + ) + def get_pad_input_ids_func(self): return getattr(self.model_runner.model, "pad_input_ids", None) diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 3bd699976..08d2dd477 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -102,6 +102,17 @@ class TpModelWorkerClient: def get_worker_info(self): return self.worker.get_worker_info() + def get_tokens_per_layer_info(self): + return self.worker.get_tokens_per_layer_info() + + @property + def sliding_window_size(self) -> Optional[int]: + return self.worker.sliding_window_size + + @property + def is_hybrid(self) -> bool: + return self.worker.is_hybrid + def get_pad_input_ids_func(self): return self.worker.get_pad_input_ids_func() diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py index 6d06fa103..d086535f4 100644 --- a/python/sglang/srt/mem_cache/allocator.py +++ b/python/sglang/srt/mem_cache/allocator.py @@ -57,11 +57,6 @@ class BaseTokenToKVPoolAllocator(abc.ABC): def debug_print(self) -> str: return "" - def log_usage(self, evictable_size: int = 0): - num_used = self.size - (self.available_size() + evictable_size) - msg = f"#token: {num_used}, token usage: {num_used / self.size:.2f}, 
" - return msg, num_used - def available_size(self): return len(self.free_pages) * self.page_size @@ -190,7 +185,7 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): self._kvcache.full_to_swa_index_mapping = self.full_to_swa_index_mapping def available_size(self): - return min(self.full_available_size(), self.swa_available_size()) + raise NotImplementedError() def full_available_size(self): return self.full_attn_allocator.available_size() @@ -214,16 +209,6 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ) return msg - def log_usage(self, swa_evictable_size: int = 0, full_evictable_size: int = 0): - used_full = self.size_full - (self.full_available_size() + full_evictable_size) - used_swa = self.size_swa - (self.swa_available_size() + swa_evictable_size) - msg = ( - f"#token: full={used_full}, swa={used_swa}, " - f"token usage: full={used_full / self.size_full:.2f}, " - f"swa={used_swa / self.size_swa:.2f}, " - ) - return msg, used_full - def get_kvcache(self): return self._kvcache diff --git a/python/sglang/srt/mem_cache/base_prefix_cache.py b/python/sglang/srt/mem_cache/base_prefix_cache.py index 1129226c3..4fdd04b72 100644 --- a/python/sglang/srt/mem_cache/base_prefix_cache.py +++ b/python/sglang/srt/mem_cache/base_prefix_cache.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple +from typing import TYPE_CHECKING, Any, List, NamedTuple, Optional, Tuple import torch @@ -56,15 +56,27 @@ class BasePrefixCache(ABC): pass @abstractmethod - def dec_lock_ref(self, node: Any): + def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None): pass def evictable_size(self): return 0 + def full_evictable_size(self): + return 0 + + def swa_evictable_size(self): + return 0 + def protected_size(self): return 0 + def full_protected_size(self): + return 0 + + def swa_protected_size(self): + return 0 + def total_size(self): raise NotImplementedError() diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index a1e58aa3a..1cec3d21b 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -61,7 +61,7 @@ class ChunkCache(BasePrefixCache): def inc_lock_ref(self, node: Any): return 0 - def dec_lock_ref(self, node: Any): + def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None): return 0 def pretty_print(self): @@ -80,7 +80,7 @@ class SWAChunkCache(ChunkCache): super().__init__(req_to_token_pool, token_to_kv_pool_allocator, page_size) assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator) - def evict( + def evict_swa( self, req: Req, prelen: int, @@ -95,3 +95,6 @@ class SWAChunkCache(ChunkCache): ] self.token_to_kv_pool_allocator.free_swa(free_slots) req.evicted_seqlen_local = new_evicted_seqlen_local + + def evict(self, num_tokens: int): + pass diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py new file mode 100644 index 000000000..7a23eb856 --- /dev/null +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -0,0 +1,1025 @@ +from __future__ import annotations + +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" +The radix tree data structure for managing the hybrid (full and SWA) KV cache. +""" + +import heapq +import time +from collections import defaultdict +from functools import partial +from typing import TYPE_CHECKING, List, Optional, Tuple + +import torch + +from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator +from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult +from sglang.srt.mem_cache.memory_pool import ReqToTokenPool + +if TYPE_CHECKING: + from sglang.srt.managers.schedule_batch import Req + +import logging + +logger = logging.getLogger(__name__) + + +class TreeNode: + + counter = 0 + swa_uuid_counter = 1 + + def __init__(self, id: Optional[int] = None): + self.children = defaultdict(TreeNode) + self.parent: TreeNode = None + self.key: List[int] = None + self.value: Optional[torch.Tensor] = None + # swa_tombstone is used to indicate the kv indices have been freed for swa layers + self.swa_tombstone = False + # invariant: for any node, if swa_lock_ref is locked, full_lock_ref must be locked; + # if full_lock_ref is locked, swa_lock_ref doesn't need to be locked. So, + # full_lock_ref is always >= swa_lock_ref. + self.full_lock_ref = 0 + self.swa_lock_ref = 0 + # last access time is only used for sanity check. LRU is maintained by the lru list. + self.last_access_time = time.monotonic() + + self.hit_count = 0 + # indicating the node is loading KV cache from host + self.loading = False + # store the host indices of KV cache + self.host_value = None + + # for lru list, invariant: + # 1. prev has greater last_access_time + # 2. 
next has smaller last_access_time + self.prev = None + self.next = None + self.swa_prev = None + self.swa_next = None + + self.id = TreeNode.counter if id is None else id + TreeNode.counter += 1 + self.swa_uuid = None + + @property + def evicted(self): + return self.value is None + + @property + def backuped(self): + return self.host_value is not None + + def __lt__(self, other: "TreeNode"): + return self.last_access_time < other.last_access_time + + +def _key_match_page_size1(key0: List, key1: List): + i = 0 + for k0, k1 in zip(key0, key1): + if k0 != k1: + break + i += 1 + return i + + +def _key_match_paged(key0: List, key1: List, page_size: int): + min_len = min(len(key0), len(key1)) + + i = 0 + while i < min_len: + if key0[i : i + page_size] != key1[i : i + page_size]: + break + i += page_size + + return i + + +def gen_swa_uuid() -> int: + TreeNode.swa_uuid_counter += 1 + return TreeNode.swa_uuid_counter + + +class LRUList: + def __init__(self, swa: bool = False): + self.swa = swa + if self.swa: + self.prv = "swa_prev" + self.nxt = "swa_next" + self.lock_ref = "swa_lock_ref" + else: + self.prv = "prev" + self.nxt = "next" + self.lock_ref = "full_lock_ref" + # Initialize dummy head and tail nodes + self.head = TreeNode() # Most recently used side + self.tail = TreeNode() # Least recently used side + setattr(self.head, self.nxt, self.tail) # self.head.next = self.tail + setattr(self.tail, self.prv, self.head) # self.tail.prev = self.head + self.cache = {} + + def _add_node(self, node): + """Helper to add node right after head (most recently used)""" + self._add_node_after(self.head, node) + + def _add_node_after(self, old_node, new_node): + """Helper to add node right after old_node""" + setattr(new_node, self.prv, old_node) # new_node.prev = old_node + setattr( + new_node, self.nxt, getattr(old_node, self.nxt) + ) # new_node.next = old_node.next + setattr( + getattr(old_node, self.nxt), self.prv, new_node + ) # old_node.next.prev = new_node + setattr(old_node, self.nxt, new_node) # old_node.next = new_node + + def _remove_node(self, node): + """Helper to remove node from linked list""" + setattr( + getattr(node, self.prv), self.nxt, getattr(node, self.nxt) + ) # node.prev.next = node.next + setattr( + getattr(node, self.nxt), self.prv, getattr(node, self.prv) + ) # node.next.prev = node.prev + + def _get_lru(self) -> Optional[TreeNode]: + """ + Get the least recently used node + """ + if len(self.cache) == 0: + return None + return getattr(self.tail, self.prv) + + def reset_node_mru(self, node): + """ + Move a (existing) node to most recently used position + """ + assert node.id in self.cache, f"Resetting node {node.id=} not in lru list" + assert ( + not self.swa or not node.swa_tombstone + ), f"Resetting swa tombstone node in swa lru list: {node.id=}" + self._remove_node(node) + self._add_node(node) + + def reset_node_and_parents_mru(self, node, root_node): + """ + Move an (existing) node and its parents to most recently used position. Child node is + more recently used than parent node. 
+ """ + prev_node = self.head + while node != root_node: + # for swa lru list, only reset non-tombstone nodes + if not self.swa or not node.swa_tombstone: + assert ( + node.id in self.cache + ), f"Resetting node {node.id=} not in lru list when resetting node and parents mru" + self._remove_node(node) + self._add_node_after(prev_node, node) + prev_node = node + node = node.parent + + def insert_mru(self, node): + """ + Insert a (new) node as most recently used + """ + assert ( + not self.swa or not node.swa_tombstone + ), f"Inserting swa tombstone node in swa lru list: {node.id=}" + assert ( + node.id not in self.cache + ), f"Inserting node {node.id=} already in lru list, existing node: {self.cache[node.id].id=}" + self.cache[node.id] = node + self._add_node(node) + + def remove_node(self, node: TreeNode): + """ + Remove node from lru list + """ + assert node.id in self.cache, f"Removing node {node.id=} not in lru list" + assert ( + not self.swa or not node.swa_tombstone + ), f"Removing swa tombstone node from swa lru list: {node.id=}" + del self.cache[node.id] + self._remove_node(node) + + def get_lru_no_lock(self) -> Optional[TreeNode]: + """ + Get the least recently used node that is not locked + """ + return self.get_prev_no_lock(self.tail, check_id=False) + + def get_leaf_lru_no_lock(self) -> Optional[TreeNode]: + """ + Get the least recently used leaf node that is not locked + """ + return self.get_prev_leaf_no_lock(self.tail, check_id=False) + + def get_prev_no_lock( + self, node: TreeNode, check_id: bool = True + ) -> Optional[TreeNode]: + """ + Get the previous (i.e. more recently used) node that is not locked + """ + if check_id: + assert ( + node.id in self.cache + ), f"Getting prev of node {node.id=} not in lru list" + x = getattr(node, self.prv) # x = node.prev + while getattr(x, self.lock_ref) > 0: + x = getattr(x, self.prv) # x = x.prev + # if x is the head, it means there is no node in the lru list without lock + if x == self.head: + return None + return x + + def get_prev_leaf_no_lock(self, node: TreeNode, check_id: bool = True): + """ + Get the previous (i.e. more recently used) leaf node that is not locked + """ + if check_id: + assert ( + node.id in self.cache + ), f"Getting prev of node {node.id=} not in lru list" + x = getattr(node, self.prv) # x = node.prev + while getattr(x, self.lock_ref) > 0 or len(x.children) > 0: + x = getattr(x, self.prv) # x = x.prev + # if x is the head, it means there is no leaf node in the lru list without lock + if x == self.head: + return None + return x + + def in_list(self, node: Optional[TreeNode]): + """ + Check if the node is in the lru list + """ + if not node: + return False + return node.id in self.cache + + # Note: this is expensive, only use for debug + def sanity_check_evictable_size(self): + """ + Check the evictable size (i.e. the size of the nodes that are not locked) + """ + node = self.get_lru_no_lock() + evictable_size = 0 + while self.in_list(node): + evictable_size += len(node.value) + node = self.get_prev_no_lock(node) + return evictable_size + + # Note: this is expensive, only use for debug or idle check + def sanity_check(self, tree_cache: "SWARadixCache"): + """ + Check if the lru list is valid by rebuilding the lru list from the tree, heapifying it, and + checking if the lru list is valid. 
+ """ + try: + if self.swa: + nodes = tree_cache._collect_nontombstone_nodes() + else: + nodes = tree_cache._collect_all_nodes() + total_nodes = len(nodes) + total_lru_plus_1 = len(self.cache) + 1 + # heapify based on last_access_time + heapq.heapify(nodes) + # the root node is not in the lru list + assert ( + len(nodes) == len(self.cache) + 1 + ), f"len(nodes): {len(nodes)} != len(self.cache) + 1: {len(self.cache) + 1}" + + x_lru = self._get_lru() + while len(nodes): + x = heapq.heappop(nodes) + if x == tree_cache.root_node: + # root node is not in the lru list + continue + assert ( + x == x_lru + ), f"Incorrect LRU list, {self.swa=}, x: {x.id=} != x_lru: {x_lru.id=}" + assert ( + x_lru.full_lock_ref == 0 + ), f"x_lru should not be locked when idle, {x_lru.full_lock_ref=}, {x_lru.swa_uuid=}, {x_lru.id=}" + assert ( + x_lru.swa_lock_ref == 0 + ), f"x_lru should not be locked when idle, {x_lru.swa_lock_ref=}, {x_lru.swa_uuid=}, {x_lru.id=}" + x_lru = getattr(x, self.prv) + + if self.swa: + evictable_size = tree_cache.swa_evictable_size() + lru_list_evictable_size = tree_cache.swa_lru_list_evictable_size() + else: + evictable_size = tree_cache.full_evictable_size() + lru_list_evictable_size = tree_cache.full_lru_list_evictable_size() + + assert ( + evictable_size == lru_list_evictable_size + ), f"{self.swa=}, total nodes: {total_nodes}, total lru plus 1: {total_lru_plus_1}, evictable size: {evictable_size} != lru list evictable size: {lru_list_evictable_size}" + except Exception as e: + msg = f"SWA Radix tree sanity check failed, ping @hanming-lu: {e}" + logger.error(msg) + raise Exception(msg) + + +class SWARadixCache(BasePrefixCache): + def __init__( + self, + req_to_token_pool: ReqToTokenPool, + token_to_kv_pool_allocator: SWATokenToKVPoolAllocator, + sliding_window_size: int, + page_size: int, + disable: bool = False, + ): + assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator) + self.req_to_token_pool = req_to_token_pool + self.token_to_kv_pool_allocator = token_to_kv_pool_allocator + self.page_size = page_size + self.disable = disable + + if self.token_to_kv_pool_allocator: + self.device = self.token_to_kv_pool_allocator.device + else: + self.device = torch.device("cpu") + + if self.page_size == 1: + self.key_match_fn = _key_match_page_size1 + self.get_child_key_fn = lambda key: key[0] + else: + self.key_match_fn = partial(_key_match_paged, page_size=page_size) + self.get_child_key_fn = lambda key: tuple(key[:page_size]) + + self.sliding_window_size = sliding_window_size + self.reset() + + ##### Public API ##### + + def reset(self) -> None: + self.root_node = TreeNode() + self.root_node.key = [] + self.root_node.value = [] + self.root_node.full_lock_ref = 1 + self.root_node.swa_lock_ref = 1 + self.full_evictable_size_ = 0 + self.swa_evictable_size_ = 0 + self.full_protected_size_ = 0 + self.swa_protected_size_ = 0 + # LRU lists are used to maintain the order of eviction of the nodes in the tree + self.full_lru_list = LRUList(swa=False) + self.swa_lru_list = LRUList(swa=True) + + def match_prefix(self, key: List[int], **kwargs) -> MatchResult: + """Find the matching prefix from the radix tree. + Args: + key: A list of token IDs to find a matching prefix. + Returns: + A tuple of a tensor of matching prefix token IDs and + the last node that contains the prefix values. Note that + this API can modify the internal state of the Radix tree. + The last node create a new child if the prefix is shorter + than the last node's value. 
+ """ + if self.disable or len(key) == 0: + return MatchResult( + device_indices=torch.empty( + (0,), + dtype=torch.int64, + device=self.device, + ), + last_device_node=self.root_node, + last_host_node=self.root_node, + ) + + if self.page_size != 1: + page_aligned_len = len(key) // self.page_size * self.page_size + key = key[:page_aligned_len] + + value, last_node = self._match_prefix_helper(key) + if value: + value = torch.cat(value) + else: + value = torch.empty((0,), dtype=torch.int64, device=self.device) + return MatchResult( + device_indices=value, + last_device_node=last_node, + last_host_node=last_node, + ) + + def insert(self, key: List, value=None, prev_prefix_len: int = 0) -> int: + if self.disable: + return 0 + + if value is None: + value = [x for x in key] + return self._insert_helper(self.root_node, key, value, prev_prefix_len) + + def cache_finished_req(self, req: Req) -> None: + """Cache request when it finishes.""" + if self.disable: + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, + : len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0), + ] + self.token_to_kv_pool_allocator.free(kv_indices) + self.req_to_token_pool.free(req.req_pool_idx) + return + + token_ids = (req.origin_input_ids + req.output_ids)[:-1] + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : len(token_ids) + ] + + if self.page_size != 1: + page_aligned_len = len(kv_indices) // self.page_size * self.page_size + page_aligned_kv_indices = kv_indices[:page_aligned_len].clone() + self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:]) + else: + page_aligned_len = len(kv_indices) + page_aligned_kv_indices = kv_indices.clone() + + # Radix Cache takes one ref in memory pool + # insert the token_ids and kv_indices into the radix tree + # Note: the insert function already frees the overlapped kv_indices + new_prefix_len = self.insert( + token_ids[:page_aligned_len], + page_aligned_kv_indices, + len(req.prefix_indices), + ) + + # Remove req slot release the cache lock + self.req_to_token_pool.free(req.req_pool_idx) + self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock) + + def cache_unfinished_req(self, req: Req) -> None: + """Cache request when it is unfinished.""" + if self.disable: + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : len(req.fill_ids) + ] + + # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later + req.prefix_indices = kv_indices + return + + token_ids = req.fill_ids + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : len(token_ids) + ] + + if self.page_size != 1: + page_aligned_len = len(kv_indices) // self.page_size * self.page_size + page_aligned_kv_indices = kv_indices[:page_aligned_len].clone() + else: + page_aligned_len = len(kv_indices) + page_aligned_kv_indices = kv_indices.clone() + page_aligned_token_ids = token_ids[:page_aligned_len] + + # Radix Cache takes one ref in memory pool + # Note: the insert function already frees the overlapped kv_indices + new_prefix_len = self.insert( + page_aligned_token_ids, page_aligned_kv_indices, len(req.prefix_indices) + ) + + # The prefix indices could be updated, reuse it + new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids) + assert len(req.prefix_indices) <= len( + new_indices + ), f"{req.prefix_indices=}, {new_indices=}" + assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}" + self.req_to_token_pool.write( + (req.req_pool_idx, slice(len(req.prefix_indices), 
len(new_indices))), + new_indices[len(req.prefix_indices) :], + ) + + self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock) + swa_uuid_for_lock = self.inc_lock_ref(new_last_node) + + # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later + if self.page_size != 1: + req.prefix_indices = torch.cat( + [new_indices, kv_indices[len(new_indices) :]] + ) + else: + req.prefix_indices = new_indices + req.last_node = new_last_node + req.swa_uuid_for_lock = swa_uuid_for_lock + + def pretty_print(self) -> None: + self._print_helper(self.root_node, 0) + total_size, total_swa_size = self._total_size_helper() + print(f"#full_tokens: {total_size}, #swa_tokens: {total_swa_size}") + + def total_size(self) -> Tuple[int, int]: + return self._total_size_helper() + + def evict(self, full_num_tokens: int, swa_num_tokens: int = 0) -> None: + if self.disable: + return + + full_num_evicted = 0 + swa_num_evicted = 0 + if full_num_tokens > 0: + # get the least recently used leaf node that is not locked + x = self.full_lru_list.get_leaf_lru_no_lock() + + while full_num_evicted < full_num_tokens and self.full_lru_list.in_list(x): + assert ( + x != self.root_node + ), f"root node should not exist in full lru list, {x.id=}" + assert x.full_lock_ref == 0, f"node is in use, {x.id=}" + + # 1. free node kv indices, evict full and swa tokens + self.token_to_kv_pool_allocator.free(x.value) + full_num_evicted += len(x.value) + swa_num_evicted += len(x.value) + + # 2. get the next leaf, update the lru lists + x_next = self.full_lru_list.get_prev_leaf_no_lock(x) + self.full_lru_list.remove_node(x) + self.swa_lru_list.remove_node(x) + + # 3. delete the leaf node + self._delete_leaf(x) + + # 4. Iteratively delete tombstone leaves to maintain invariant that leaf nodes are not tombstone + x, leaf_full_num_evicted = self._iteratively_delete_tombstone_leaf(x) + full_num_evicted += leaf_full_num_evicted + + # 5. if parent has no more children, it is a leaf. It is possible that this node is lru, so + # we need to get the first leaf node in the lru list + if len(x.parent.children) == 0: + x_next = self.full_lru_list.get_leaf_lru_no_lock() + + x = x_next + + if swa_num_evicted < swa_num_tokens: + # get the least recently used node that is not locked, doesn't have to be a leaf + x = self.swa_lru_list.get_lru_no_lock() + + # evict lru leaf nodes until swa_num_tokens is reached + while swa_num_evicted < swa_num_tokens and (self.swa_lru_list.in_list(x)): + assert not x.swa_tombstone, f"duplicate swa tombstone node, {x.id=}" + assert x != self.root_node, f"root node is not evictable, {x.id=}" + assert x.swa_lock_ref == 0, f"node is in use by swa kv indices, {x.id=}" + + if len(x.children) > 0: + # 1. an internal node, free swa tokens. + self.token_to_kv_pool_allocator.free_swa(x.value) + swa_num_evicted += len(x.value) + + # 2. get the next node, update the lru lists + x_next = self.swa_lru_list.get_prev_no_lock(x) + self.swa_lru_list.remove_node(x) + + # 3. tombstone the node + self._tombstone_internal_node(x) + else: + assert ( + x.full_lock_ref == 0 + ), f"leaf node with full lock must also have swa lock, {x.id=}" + # 1. a leaf node, free full and swa tokens + self.token_to_kv_pool_allocator.free(x.value) + full_num_evicted += len(x.value) + swa_num_evicted += len(x.value) + + # 2. get the next node, update the lru lists + x_next = self.swa_lru_list.get_prev_no_lock(x) + self.full_lru_list.remove_node(x) + self.swa_lru_list.remove_node(x) + + # 3. delete the leaf node + self._delete_leaf(x) + + # 4. 
Iteratively delete tombstone leaves to maintain invariant that leaf nodes are not tombstone + self._iteratively_delete_tombstone_leaf(x) + + x = x_next + + def inc_lock_ref(self, node: TreeNode) -> Optional[int]: + """ + Increment the lock reference count for the node. Returns the swa_uuid_for_lock, which needs + to be passed to dec_lock_ref. + It locks the full_lock_ref for nodes between the [last node, root), exclusive. + It locks the swa_lock_ref for nodes between the [last node, swa_uuid_for_lock], inclusive. + """ + if self.disable: + return None + + swa_lock_size = 0 + swa_uuid_for_lock = None + while node != self.root_node: + # lock full from node to root + assert ( + node.full_lock_ref >= 0 + ), f"inc_lock_ref on node with {node.full_lock_ref=}, {node.id=}" + if node.full_lock_ref == 0: + self.full_evictable_size_ -= len(node.value) + self.full_protected_size_ += len(node.value) + node.full_lock_ref += 1 + + # lock swa if we have not reached the sliding window size. + # When we reach the sliding window size, we will set the swa_uuid_for_lock. + # caller needs to pass the swa_uuid_for_lock to dec_lock_ref + if swa_lock_size < self.sliding_window_size: + assert ( + not node.swa_tombstone + ), f"inc_lock_swa on swa_tombstone node, {node.id=}" + if node.swa_lock_ref == 0: + self.swa_evictable_size_ -= len(node.value) + self.swa_protected_size_ += len(node.value) + node.swa_lock_ref += 1 + swa_lock_size += len(node.value) + if swa_lock_size >= self.sliding_window_size: + if node.swa_uuid is None: + node.swa_uuid = gen_swa_uuid() + swa_uuid_for_lock = node.swa_uuid + node = node.parent + return swa_uuid_for_lock + + def dec_lock_ref(self, node: TreeNode, swa_uuid_for_lock: Optional[int] = None): + """ + Decrement the lock reference count for the node. + It unlocks the full_lock_ref for nodes between the [last node, root), exclusive. + It unlocks the swa_lock_ref for nodes between the [last node, swa_uuid_for_lock], inclusive. + If swa_uuid_for_lock is None, it unlocks to the root, exclusive. + """ + if self.disable: + return + + dec_lock_swa = True + while node != self.root_node: + assert ( + node.full_lock_ref > 0 + ), f"dec_lock_ref on node with {node.full_lock_ref=}, {node.id=}" + if node.full_lock_ref == 1: + self.full_evictable_size_ += len(node.value) + self.full_protected_size_ -= len(node.value) + node.full_lock_ref -= 1 + + if dec_lock_swa: + assert ( + not node.swa_tombstone + ), f"dec_lock_ref on swa_tombstone node, {node.id=}" + assert ( + node.swa_lock_ref > 0 + ), f"dec_lock_ref on node with {node.swa_lock_ref=}, {node.id=}" + + if node.swa_lock_ref == 1: + self.swa_evictable_size_ += len(node.value) + self.swa_protected_size_ -= len(node.value) + node.swa_lock_ref -= 1 + if swa_uuid_for_lock and node.swa_uuid == swa_uuid_for_lock: + dec_lock_swa = False + + node = node.parent + + def sanity_check(self): + self.full_lru_list.sanity_check(self) + self.swa_lru_list.sanity_check(self) + + def evictable_size(self) -> Tuple[int, int]: + # Note: use full_evictable_size() and swa_evictable_size() instead. 
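+ # A single scalar would be ambiguous here: the full and swa pools keep independent evictable budgets.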
+ raise NotImplementedError + + def full_evictable_size(self) -> int: + return self.full_evictable_size_ + + def swa_evictable_size(self) -> int: + return self.swa_evictable_size_ + + # Note: this is expensive, only use for debugging + def full_lru_list_evictable_size(self) -> int: + return self.full_lru_list.sanity_check_evictable_size() + + # Note: this is expensive, only use for debugging + def swa_lru_list_evictable_size(self) -> int: + return self.swa_lru_list.sanity_check_evictable_size() + + def protected_size(self) -> Tuple[int, int]: + # Note: use full_protected_size() and swa_protected_size() instead. + raise NotImplementedError + + def full_protected_size(self) -> int: + # protected size refers to the size of the full cache that is locked + return self.full_protected_size_ + + def swa_protected_size(self) -> int: + # protected size refers to the size of the swa cache that is locked + return self.swa_protected_size_ + + def all_values_flatten(self) -> torch.Tensor: + values = [] + + def _dfs_helper(node: TreeNode): + for _, child in node.children.items(): + values.append(child.value) + _dfs_helper(child) + + _dfs_helper(self.root_node) + return torch.cat(values) + + ##### Internal Helper Functions ##### + + def _match_prefix_helper(self, key: List) -> Tuple[List[torch.Tensor], TreeNode]: + """ + SWA prefix matching helper. It factors in the sliding window size such that + the matched node is guaranteed to either 1. be connected to the root without + any swa tombstone, or 2. have at least sliding-window-size matching tokens + between it and the last swa tombstone node. + """ + node = self.root_node + child_key = self.get_child_key_fn(key) + + value = [] + # for path connected to root without tombstone, always match, so set to inf + match_len_since_tombstone = float("inf") + best_value_len = 0 + best_last_node = node + while len(key) > 0 and child_key in node.children.keys(): + child = node.children[child_key] + + # update best_value_len and best_last_node if needed + if ( + child.swa_tombstone + and match_len_since_tombstone >= self.sliding_window_size + ): + best_value_len = len(value) + best_last_node = node + match_len_since_tombstone = 0 + + prefix_len = self.key_match_fn(child.key, key) + if prefix_len < len(child.key): + new_node = self._split_node(child.key, child, prefix_len) + value.append(new_node.value) + if not new_node.swa_tombstone: + match_len_since_tombstone += len(new_node.value) + node = new_node + break + else: + value.append(child.value) + if not child.swa_tombstone: + match_len_since_tombstone += len(child.value) + node = child + key = key[prefix_len:] + + if len(key): + child_key = self.get_child_key_fn(key) + + # handle best_value_len and best_last_node, for the case where the last node is fully matched + if match_len_since_tombstone >= self.sliding_window_size: + best_value_len = len(value) + best_last_node = node + + # update time for matched nodes, making nodes closer to the root the least recently used + # this allows swa to evict nodes closer to root first + self.full_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node) + self.swa_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node) + + # This last_access_time update is only for sanity checks; it can be deleted after validation in production + cur_time = time.monotonic() + while node: + node.last_access_time = cur_time + cur_time -= 0.0001 + node = node.parent + + return value[:best_value_len], best_last_node + + def _split_node(self, key: List[int], child: TreeNode, 
split_len: int) -> TreeNode: + # new_node -> child + new_node = TreeNode() + new_node.children = {self.get_child_key_fn(key[split_len:]): child} + new_node.parent = child.parent + new_node.swa_tombstone = child.swa_tombstone + new_node.full_lock_ref = child.full_lock_ref + new_node.swa_lock_ref = child.swa_lock_ref + new_node.key = child.key[:split_len] + new_node.value = child.value[:split_len] + # parent inherits the swa_uuid from child for swa lock ref + new_node.swa_uuid = child.swa_uuid + child.swa_uuid = None + # the child's time should be later than the parent's time for swa tombstoning + child.last_access_time = time.monotonic() + + # remove the child from the lru lists because it is being split + self.full_lru_list.remove_node(child) + if not new_node.swa_tombstone: + self.swa_lru_list.remove_node(child) + child.parent = new_node + child.key = child.key[split_len:] + child.value = child.value[split_len:] + new_node.parent.children[self.get_child_key_fn(key)] = new_node + + # insert the new node and child into the lru lists, insert + # parent first so that parent is after child in the lru list + self.full_lru_list.insert_mru(new_node) + self.full_lru_list.insert_mru(child) + if not new_node.swa_tombstone: + self.swa_lru_list.insert_mru(new_node) + self.swa_lru_list.insert_mru(child) + return new_node + + def _insert_helper( + self, node: TreeNode, key: List, value, update_kv_after_len: int + ) -> int: + # Update the last access time from root to leaf, so that + # swa will tombstone the node closer to root first + node.last_access_time = time.monotonic() + if node != self.root_node: + self.full_lru_list.reset_node_mru(node) + if not node.swa_tombstone: + self.swa_lru_list.reset_node_mru(node) + if len(key) == 0: + return 0 + + child_key = self.get_child_key_fn(key) + + total_prefix_length = 0 + while len(key) > 0 and child_key in node.children.keys(): + node = node.children[child_key] + node.last_access_time = time.monotonic() + self.full_lru_list.reset_node_mru(node) + if not node.swa_tombstone: + self.swa_lru_list.reset_node_mru(node) + prefix_len = self.key_match_fn(node.key, key) + + if prefix_len < len(node.key): + new_node = self._split_node(node.key, node, prefix_len) + node = new_node + + # If tombstoned after update_kv_after_len, update node.value to the input value. + # This is needed because the last sliding-window-size tokens may contain a + # tombstone. If this is the case and we don't update the kv value, then + # the prefill prefix matching would get stuck. 
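+ # Illustration (hypothetical numbers): with sliding_window_size=4, a cached path + # [t0..t7] whose node for [t4, t5] is tombstoned has only 2 non-tombstoned tokens + # behind the tombstone, so match_prefix keeps stopping in front of it; reviving the + # node with the fresh kv indices below unblocks the prefill.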
+ if update_kv_after_len < total_prefix_length + prefix_len: + first_diff_idx = max(0, update_kv_after_len - total_prefix_length) + if node.swa_tombstone: + assert ( + node.swa_lock_ref == 0 + ), f"tombstone swa_lock_ref should always be 0, {node.full_lock_ref=}, {node.swa_lock_ref=}, {node.id=}" + self.token_to_kv_pool_allocator.free(node.value[first_diff_idx:]) + node.value = value[:prefix_len] + node.swa_tombstone = False + + # insert the node into the lru lists + self.swa_lru_list.insert_mru(node) + + self.swa_evictable_size_ += len(node.value) + else: + self.token_to_kv_pool_allocator.free( + value[first_diff_idx:prefix_len] + ) + + total_prefix_length += prefix_len + key = key[prefix_len:] + value = value[prefix_len:] + + if len(key): + child_key = self.get_child_key_fn(key) + + if len(key): + new_node = TreeNode() + new_node.parent = node + new_node.key = key + new_node.value = value + self.full_lru_list.insert_mru(new_node) + self.swa_lru_list.insert_mru(new_node) + node.children[child_key] = new_node + self.full_evictable_size_ += len(value) + self.swa_evictable_size_ += len(value) + return total_prefix_length + + def _iteratively_delete_tombstone_leaf( + self, node: TreeNode + ) -> Tuple[TreeNode, int]: + full_num_evicted = 0 + while node.parent.swa_tombstone and len(node.parent.children) == 0: + # root node is not evictable + if node.parent == self.root_node: + break + # if locked, the node is in use, skip + if node.parent.full_lock_ref > 0: + break + assert ( + node.parent.swa_lock_ref == 0 + ), f"tombstone swa_lock_ref should always be 0, {node.parent.full_lock_ref=}, {node.parent.swa_lock_ref=}, {node.parent.id=}" + # deleting a tombstone node evicts full tokens + self.token_to_kv_pool_allocator.free(node.parent.value) + full_num_evicted += len(node.parent.value) + self.full_lru_list.remove_node(node.parent) + self._delete_tombstone_leaf(node.parent) + node = node.parent + + return node, full_num_evicted + + def _delete_leaf(self, node: TreeNode) -> None: + assert ( + not node.swa_tombstone + ), f"Invariant violated: leaf node is a tombstone, {node.id=}" + assert len(node.children) == 0, f"leaf node has children, {node.id=}" + for k, v in node.parent.children.items(): + if v == node: + break + del node.parent.children[k] + self.full_evictable_size_ -= len(node.key) + self.swa_evictable_size_ -= len(node.key) + + def _tombstone_internal_node(self, node: TreeNode) -> None: + assert len(node.children) != 0, f"Cannot tombstone a leaf node, {node.id=}" + node.swa_tombstone = True + self.swa_evictable_size_ -= len(node.key) + + def _delete_tombstone_leaf(self, node: TreeNode) -> None: + assert ( + node.swa_tombstone + ), f"Deleting an unexpected non-tombstone leaf node, {node.id=}" + assert len(node.children) == 0, f"leaf node has children, {node.id=}" + for k, v in node.parent.children.items(): + if v == node: + break + del node.parent.children[k] + self.full_evictable_size_ -= len(node.key) + + def _collect_leaves(self) -> List[TreeNode]: + ret_list = [] + stack = [self.root_node] + + while stack: + cur_node = stack.pop() + if len(cur_node.children) == 0: + ret_list.append(cur_node) + else: + stack.extend(cur_node.children.values()) + + return ret_list + + def _collect_nontombstone_nodes(self) -> List[TreeNode]: + ret_list = [] + stack = [self.root_node] + + while stack: + cur_node = stack.pop() + if not cur_node.swa_tombstone: + ret_list.append(cur_node) + stack.extend(cur_node.children.values()) + + return ret_list + + def _collect_all_nodes(self) -> List[TreeNode]: + ret_list = 
[] + stack = [self.root_node] + while stack: + cur_node = stack.pop() + ret_list.append(cur_node) + stack.extend(cur_node.children.values()) + return ret_list + + def _print_helper(self, node: TreeNode, indent: int) -> None: + """Prints the radix tree in a human-readable format.""" + stack = [(node, indent)] + while stack: + current_node, current_indent = stack.pop() + print( + " " * current_indent, + current_node.id, + len(current_node.key), + f"fr={current_node.full_lock_ref}", + f"sr={current_node.swa_lock_ref}", + f"fll={self.full_lru_list.in_list(current_node)}", + f"sll={self.swa_lru_list.in_list(current_node)}", + f"ts={current_node.swa_tombstone}", + ) + for key, child in current_node.children.items(): + stack.append((child, current_indent + 2)) + + assert key == self.get_child_key_fn( + child.key + ), f"{key=}, {self.get_child_key_fn(child.key)=}" + + def _total_size_helper(self) -> Tuple[int, int]: + total_size = 0 + total_swa_size = 0 + stack = [self.root_node] + while stack: + current_node = stack.pop() + total_size += len(current_node.value) + if not current_node.swa_tombstone: + total_swa_size += len(current_node.value) + for child in current_node.children.values(): + if child.evicted: + continue + stack.append(child) + return total_size, total_swa_size diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 051f2b75e..fe9560497 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -275,6 +275,15 @@ class ModelRunner: self.sampler = Sampler() self.load_model() + if ( + not self.server_args.disable_hybrid_swa_memory + and self.sliding_window_size is not None + and self.sliding_window_size > 0 + ): + architectures = self.model_config.hf_config.architectures + if architectures and not any("Llama4" in arch for arch in architectures): + self.is_hybrid = self.model_config.is_hybrid = True + self.start_layer = getattr(self.model, "start_layer", 0) self.end_layer = getattr( self.model, "end_layer", self.model_config.num_hidden_layers @@ -471,10 +480,6 @@ class ModelRunner: if self.model_config.context_len > 8192: self.mem_fraction_static *= 0.85 - if self.is_hybrid and not server_args.disable_radix_cache: - logger.info("Automatically disable radix cache for hybrid cache.") - server_args.disable_radix_cache = True - def init_torch_distributed(self): logger.info("Init torch distributed begin.") @@ -645,11 +650,15 @@ class ModelRunner: ) # Parse other args - self.sliding_window_size = ( - self.model.get_attention_sliding_window_size() - if hasattr(self.model, "get_attention_sliding_window_size") - else None - ) + self.sliding_window_size = None + if hasattr(self.model, "get_attention_sliding_window_size"): + self.sliding_window_size = self.model.get_attention_sliding_window_size() + elif self.model_config.attention_chunk_size is not None: + self.sliding_window_size = self.model_config.attention_chunk_size + print( + f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}" + ) + self.dtype = self.model_config.dtype after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id) @@ -992,8 +1001,53 @@ class ModelRunner: ) self.max_total_num_tokens = self.full_max_total_num_tokens else: - raise ValueError( - f"Unsupported model for hybrid cache: {self.model_config.hf_config.architectures}." 
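+ # Generic hybrid path: instead of rejecting the architecture, classify each layer + # by its attention sliding_window_size and size the two KV budgets below.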
+ assert self.sliding_window_size is not None and self.sliding_window_size > 0 + full_attention_layer_ids = [] + swa_attention_layer_ids = [] + + try: + layers = self.model.model.layers + except AttributeError: + try: + layers = self.model.language_model.model.layers + except AttributeError: + self.is_hybrid = False + return + + for layer in layers: + if ( + layer.self_attn.attn.sliding_window_size is None + or layer.self_attn.attn.sliding_window_size == -1 + ): + full_attention_layer_ids.append(layer.layer_id) + else: + swa_attention_layer_ids.append(layer.layer_id) + self.model_config.swa_attention_layer_ids = swa_attention_layer_ids + self.model_config.full_attention_layer_ids = full_attention_layer_ids + + # Algorithm: + # Existing max_total_num_tokens is per layer and assumes all layers have the same number of tokens. + # - Find total # of tokens available across layers. + # - Calculate full_max_total_num_tokens and swa_max_total_num_tokens based on the given swa_full_tokens_ratio. + total_tokens = ( + self.max_total_num_tokens * self.model_config.num_hidden_layers + ) + full_layers_num = len(full_attention_layer_ids) + swa_layers_num = len(swa_attention_layer_ids) + swa_full_tokens_ratio = self.server_args.swa_full_tokens_ratio + + # Solve the equations: + # 1. swa_max_total_num_tokens * swa_layers_num + full_max_total_num_tokens * full_layers_num == total_tokens + # 2. full_max_total_num_tokens * swa_full_tokens_ratio == swa_max_total_num_tokens + denominator = swa_full_tokens_ratio * swa_layers_num + full_layers_num + self.full_max_total_num_tokens = int(total_tokens / denominator) + self.swa_max_total_num_tokens = int( + self.full_max_total_num_tokens * swa_full_tokens_ratio + ) + self.max_total_num_tokens = self.full_max_total_num_tokens + + logger.info( + f"Use Sliding window memory pool. 
full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}" ) def init_memory_pool( @@ -1072,7 +1126,6 @@ class ModelRunner: // self.server_args.page_size * self.server_args.page_size ) - # create token size for hybrid cache if self.is_hybrid: self.set_num_token_hybrid() diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 9056b0b0c..9ee892bb7 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -190,6 +190,7 @@ class Gemma2DecoderLayer(nn.Module): prefix: str = "", ) -> None: super().__init__() + self.layer_id = layer_id self.hidden_size = config.hidden_size self.self_attn = Gemma2Attention( layer_id=layer_id, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index a59bf815d..30191ee08 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -63,6 +63,7 @@ class ServerArgs: enable_multimodal: Optional[bool] = None revision: Optional[str] = None hybrid_kvcache_ratio: Optional[float] = None + swa_full_tokens_ratio: float = 0.8 impl: str = "auto" # Port for the HTTP server @@ -225,6 +226,7 @@ class ServerArgs: enable_return_hidden_states: bool = False enable_triton_kernel_moe: bool = False warmups: Optional[str] = None + disable_hybrid_swa_memory: bool = False # Debug tensor dumps debug_tensor_dump_output_folder: Optional[str] = None @@ -481,14 +483,22 @@ class ServerArgs: model_arch = get_model_arch(self) - # Auto set draft_model_path DeepSeek-V3/R1 if model_arch == "DeepseekV3ForCausalLM": + # Auto set draft_model_path for DeepSeek-V3/R1 if self.speculative_draft_model_path is None: self.speculative_draft_model_path = self.model_path else: logger.warning( "DeepSeek MTP does not require setting speculative_draft_model_path." ) + elif "Llama4" in model_arch: + # TODO: remove this after Llama4 is supported on other backends + if self.attention_backend != "fa3": + self.attention_backend = "fa3" + logger.warning( + "Llama4 requires using the fa3 attention backend. " + "Attention backend is automatically set to fa3." + ) # Auto choose parameters if self.speculative_num_steps is None: @@ -852,6 +862,18 @@ class ServerArgs: "(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)" ), ) + parser.add_argument( + "--swa-full-tokens-ratio", + type=float, + default=ServerArgs.swa_full_tokens_ratio, + help="The ratio of SWA layer KV tokens / full layer KV tokens, regardless of the number of swa:full layers. It should be between 0 and 1. " + "E.g. 
0.5 means if each swa layer has 50 tokens, then each full layer has 100 tokens.", + ) + parser.add_argument( + "--disable-hybrid-swa-memory", + action="store_true", + help="Disable the hybrid SWA memory.", + ) # Other runtime options parser.add_argument( @@ -1730,10 +1752,6 @@ class ServerArgs: else: self.lora_paths[lora_path] = lora_path - model_arch = get_model_arch(self) - if "Llama4" in model_arch and self.hybrid_kvcache_ratio is not None: - assert self.attention_backend == "fa3" - def prepare_server_args(argv: List[str]) -> ServerArgs: """ diff --git a/test/srt/test_swa_unittest.py b/test/srt/test_swa_unittest.py new file mode 100644 index 000000000..e026d70af --- /dev/null +++ b/test/srt/test_swa_unittest.py @@ -0,0 +1,176 @@ +import unittest + +import torch + +from sglang.srt.mem_cache.allocator import SWAKVPool, SWATokenToKVPoolAllocator +from sglang.srt.mem_cache.memory_pool import ReqToTokenPool +from sglang.srt.mem_cache.radix_cache import SWARadixCache + + +class TestSWA(unittest.TestCase): + @classmethod + def setUpClass(cls): + pass + + @classmethod + def tearDownClass(cls): + pass + + def test_swa_memory_pool(self): + size = 16 + size_swa = 16 + num_head = 8 + head_dim = 128 + num_layers = 48 + global_interval = 4 + dtype = torch.bfloat16 + device = "cuda" + full_attention_layer_ids = [i for i in range(0, num_layers, global_interval)] + full_attention_layer_ids_set = set(full_attention_layer_ids) + swa_attention_layer_ids = [ + i for i in range(num_layers) if i not in full_attention_layer_ids_set + ] + pool = SWAKVPool( + size, + size_swa, + dtype, + num_head, + head_dim, + swa_attention_layer_ids, + full_attention_layer_ids, + device, + ) + alloc = SWATokenToKVPoolAllocator(size, size_swa, dtype, device, pool) + assert alloc.available_size() == size + size_swa + index = alloc.alloc(1) + assert alloc.available_size() == size_swa + size_swa - 2 + alloc.free_swa(index) + result = alloc.translate_loc_from_full_to_swa(index) + print(result) + + def test_swa_radix_cache_1(self): + # args + req_size = 10 + max_context_len = 128 + kv_size = 128 + kv_size_swa = 64 + sliding_window_size = 4 + num_head = 8 + head_dim = 128 + num_layers = 48 + global_interval = 4 + dtype = torch.bfloat16 + device = "cuda" + full_attention_layer_ids = [i for i in range(0, num_layers, global_interval)] + full_attention_layer_ids_set = set(full_attention_layer_ids) + swa_attention_layer_ids = [ + i for i in range(num_layers) if i not in full_attention_layer_ids_set + ] + # setup req to token pool + req_to_token_pool = ReqToTokenPool( + size=req_size, + max_context_len=max_context_len, + device=device, + enable_memory_saver=False, + ) + # setup kv pool + kv_pool = SWAKVPool( + kv_size, + kv_size_swa, + dtype, + num_head, + head_dim, + swa_attention_layer_ids, + full_attention_layer_ids, + device, + ) + # setup token to kv pool allocator + allocator = SWATokenToKVPoolAllocator( + kv_size, kv_size_swa, dtype, device, kv_pool + ) + # setup radix cache + tree = SWARadixCache( + req_to_token_pool=req_to_token_pool, + token_to_kv_pool_allocator=allocator, + sliding_window_size=sliding_window_size, + page_size=1, + disable=False, + ) + + # test + print( + f"[Start] allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" + ) + req1_token_ids, req1_kv_indices = [1, 2, 3], allocator.alloc(3) + assert len(req1_token_ids) == len(req1_kv_indices) + print( + f"req1: inserting, req1_token_ids: {req1_token_ids}, req1_kv_indices: {req1_kv_indices}" + ) + 
prefix_len = tree.insert(req1_token_ids, req1_kv_indices) + print( + f"req1: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" + ) + req2_token_ids, req2_kv_indices = [1, 2, 3, 4, 5, 6, 7], allocator.alloc(7) + assert len(req2_token_ids) == len(req2_kv_indices) + print( + f"req2: inserting, req2_token_ids: {req2_token_ids}, req2_kv_indices: {req2_kv_indices}" + ) + prefix_len = tree.insert(req2_token_ids, req2_kv_indices) + print( + f"req2: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" + ) + req3_token_ids, req3_kv_indices = [10, 11, 12], allocator.alloc(3) + assert len(req3_token_ids) == len(req3_kv_indices) + print( + f"req3: inserting, req3_token_ids: {req3_token_ids}, req3_kv_indices: {req3_kv_indices}" + ) + prefix_len = tree.insert(req3_token_ids, req3_kv_indices) + print( + f"req3: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" + ) + req4_token_ids, req4_kv_indices = [1, 2, 3, 4, 5, 60, 70], allocator.alloc(7) + assert len(req4_token_ids) == len(req4_kv_indices) + print( + f"req4: inserting, req4_token_ids: {req4_token_ids}, req4_kv_indices: {req4_kv_indices}" + ) + prefix_len = tree.insert(req4_token_ids, req4_kv_indices) + print( + f"req4: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" + ) + + tree.pretty_print() + full_num_tokens, swa_num_tokens = 1, 0 + print(f"evicting {full_num_tokens} full token and {swa_num_tokens} swa token") + tree.evict(full_num_tokens=full_num_tokens, swa_num_tokens=swa_num_tokens) + tree.pretty_print() + + full_num_tokens, swa_num_tokens = 0, 1 + print(f"evicting {full_num_tokens} full token and {swa_num_tokens} swa token") + tree.evict(full_num_tokens=full_num_tokens, swa_num_tokens=swa_num_tokens) + tree.pretty_print() + + full_num_tokens, swa_num_tokens = 1, 2 + print(f"evicting {full_num_tokens} full token and {swa_num_tokens} swa token") + tree.evict(full_num_tokens=full_num_tokens, swa_num_tokens=swa_num_tokens) + tree.pretty_print() + + req5_token_ids = [1, 2, 3, 4, 5] + kv_indices, last_node = tree.match_prefix(req5_token_ids) + print( + f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" + ) + assert len(kv_indices) == 0 + + req6_token_ids = [1, 2, 3, 4, 5, 60, 70] + kv_indices, last_node = tree.match_prefix(req6_token_ids) + print( + f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" + ) + assert len(kv_indices) == 7 + assert len(last_node.key) == 2 + assert last_node.key[0] == 60 + assert last_node.key[1] == 70 + + +if __name__ == "__main__": + unittest.main()
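For reference, the token-budget split in set_num_token_hybrid above simply solves the two equations in its comment. Below is a minimal standalone sketch with illustrative numbers only (none of these values come from the patch; the layer counts mirror the 1-in-4 full-attention pattern used in the test):

    # Hypothetical sizes: 48 layers, every 4th layer uses full attention.
    num_layers = 48
    full_layers_num = len(range(0, num_layers, 4))  # 12 full-attention layers
    swa_layers_num = num_layers - full_layers_num   # 36 sliding-window layers
    max_total_num_tokens = 1200                     # per-layer token budget before the split
    swa_full_tokens_ratio = 0.8                     # --swa-full-tokens-ratio

    total_tokens = max_total_num_tokens * num_layers  # 57600 KV slots across all layers

    # Solve:
    #   1. swa_tokens * swa_layers_num + full_tokens * full_layers_num == total_tokens
    #   2. swa_tokens == full_tokens * swa_full_tokens_ratio
    denominator = swa_full_tokens_ratio * swa_layers_num + full_layers_num
    full_tokens = int(total_tokens / denominator)          # 1411
    swa_tokens = int(full_tokens * swa_full_tokens_ratio)  # 1128

    # Truncation keeps the combined pools within the original budget.
    assert full_tokens * full_layers_num + swa_tokens * swa_layers_num <= total_tokens
    print(full_tokens, swa_tokens)  # -> 1411 1128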