From 47367b768daf1b302a72f03a362fe8a100938847 Mon Sep 17 00:00:00 2001 From: DarkSharpness <76582120+DarkSharpness@users.noreply.github.com> Date: Thu, 19 Jun 2025 09:58:48 -0700 Subject: [PATCH] [Refactor] Clean up radix cache related API (#7303) Co-authored-by: Zhiqiang Xie --- python/sglang/srt/managers/schedule_batch.py | 33 +++++----- python/sglang/srt/managers/schedule_policy.py | 58 ++++++++---------- python/sglang/srt/managers/scheduler.py | 26 +++----- .../sglang/srt/mem_cache/base_prefix_cache.py | 60 ++++++++++++++++--- python/sglang/srt/mem_cache/chunk_cache.py | 20 +++---- python/sglang/srt/mem_cache/hiradix_cache.py | 50 +++++++++------- python/sglang/srt/mem_cache/radix_cache.py | 28 +++++---- 7 files changed, 153 insertions(+), 122 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 28e1e33b8..74a396aa3 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -38,7 +38,7 @@ import logging import threading from enum import Enum, auto from http import HTTPStatus -from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union import numpy as np import torch @@ -436,7 +436,7 @@ class Req: self, rid: str, origin_input_text: str, - origin_input_ids: Tuple[int], + origin_input_ids: List[int], sampling_params: SamplingParams, return_logprob: bool = False, top_logprobs_num: int = 0, @@ -467,7 +467,7 @@ class Req: # Each decode stage's output ids self.output_ids = [] # fill_ids = origin_input_ids + output_ids. Updated if chunked. - self.fill_ids = None + self.fill_ids = [] self.session_id = session_id self.input_embeds = input_embeds @@ -519,13 +519,14 @@ class Req: # Prefix info # The indices to kv cache for the shared prefix. - self.prefix_indices = [] + self.prefix_indices: torch.Tensor = [] # Number of tokens to run prefill. self.extend_input_len = 0 # The relative logprob_start_len in an extend batch self.extend_logprob_start_len = 0 - self.last_node = None - self.last_node_global = None + self.last_node: Any = None + self.last_host_node: Any = None + self.host_hit_length = 0 # Whether or not if it is chunked. It increments whenever # it is chunked, and decrement whenever chunked request is @@ -644,21 +645,17 @@ class Req: def init_next_round_input( self, tree_cache: Optional[BasePrefixCache] = None, - enable_hierarchical_cache=False, ): self.fill_ids = self.origin_input_ids + self.output_ids if tree_cache is not None: - # tree cache is None if the prefix is not computed with tree cache. - if enable_hierarchical_cache: - self.prefix_indices, self.last_node, self.last_node_global = ( - tree_cache.match_prefix( - key=self.adjust_max_prefix_ids(), include_evicted=True - ) - ) - else: - self.prefix_indices, self.last_node = tree_cache.match_prefix( - rid=self.rid, key=self.adjust_max_prefix_ids() - ) + ( + self.prefix_indices, + self.last_node, + self.last_host_node, + self.host_hit_length, + ) = tree_cache.match_prefix( + key=self.adjust_max_prefix_ids(), + ) self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) def adjust_max_prefix_ids(self): diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 1dd07881e..08e8b2a8a 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -90,7 +90,7 @@ class SchedulePolicy: def calc_priority(self, waiting_queue: List[Req]) -> bool: if self.policy == CacheAgnosticPolicy.FCFS: # A shortcut for FCFS - return + return False policy = self._determine_active_policy(waiting_queue) @@ -134,7 +134,7 @@ class SchedulePolicy: """ try: policy_enum = CacheAwarePolicy(policy) - if tree_cache.disable: + if getattr(tree_cache, "disable", True): # If tree_cache is disabled, using CacheAgnosticPolicy policy return CacheAgnosticPolicy.FCFS return policy_enum @@ -158,14 +158,9 @@ class SchedulePolicy: prefix_ids = r.adjust_max_prefix_ids() # NOTE: the prefix_indices must always be aligned with last_node - if self.enable_hierarchical_cache: - r.prefix_indices, r.last_node, r.last_node_global = ( - self.tree_cache.match_prefix(key=prefix_ids, include_evicted=True) - ) - else: - r.prefix_indices, r.last_node = self.tree_cache.match_prefix( - rid=r.rid, key=prefix_ids - ) + r.prefix_indices, r.last_node, r.last_host_node, r.host_hit_length = ( + self.tree_cache.match_prefix(rid=r.rid, key=prefix_ids) + ) # NOTE(sang): This logic is for in-batch prefix caching; # If there are more than 1 request that have small matching prefix from @@ -175,7 +170,7 @@ class SchedulePolicy: # threshold means we cannot use in-batch prefix caching for short prefixes. # It is kind of common when the engine is long running (e.g., imagine the prefix "the"). if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD: - in_batch_matching_prefixes, _ = ( + in_batch_matching_prefixes, _, _, _ = ( self.waiting_queue_radix_tree.match_prefix( rid=r.rid, key=prefix_ids ) @@ -268,6 +263,7 @@ class AddReqResult(Enum): class PrefillAdder: def __init__( self, + page_size: int, tree_cache: BasePrefixCache, token_to_kv_pool_allocator: TokenToKVPoolAllocator, running_batch: ScheduleBatch, @@ -276,6 +272,7 @@ class PrefillAdder: rem_chunk_tokens: Optional[int], mixed_with_decode_tokens: int = 0, ): + self.page_size = page_size self.tree_cache = tree_cache self.token_to_kv_pool_allocator = token_to_kv_pool_allocator self.running_batch = running_batch @@ -442,46 +439,43 @@ class PrefillAdder: return self.budget_state() - def add_one_req( - self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False - ): + def add_one_req(self, req: Req, has_chunked_req: bool): if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True): return self.add_one_req_ignore_eos(req, has_chunked_req) total_tokens = req.extend_input_len + min( req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION ) - input_tokens = ( - -(-req.extend_input_len // self.tree_cache.page_size) - * self.tree_cache.page_size - ) + + # adjusting the input_tokens based on host_hit_length and page_size + real_input_tokens = req.extend_input_len - req.host_hit_length + real_input_tokens = -(-real_input_tokens // self.page_size) * self.page_size prefix_len = len(req.prefix_indices) if total_tokens >= self.rem_total_tokens: return AddReqResult.NO_TOKEN - if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0: + if real_input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0: return AddReqResult.OTHER with self._lock_node(req.last_node): - if total_tokens > self.rem_total_tokens: + # self.rem_total_tokens may decrease after the lock acquisition + if total_tokens >= self.rem_total_tokens: return AddReqResult.NO_TOKEN - if ( - enable_hierarchical_cache - and req.last_node_global is not None - and req.last_node_global.evicted - ): - req.last_node, req.prefix_indices = self.tree_cache.init_load_back( - req.last_node_global, req.prefix_indices + if req.host_hit_length > 0: + new_indices, req.last_node = self.tree_cache.init_load_back( + req.last_host_node, req.host_hit_length ) + req.prefix_indices = torch.cat([req.prefix_indices, new_indices]) req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) - input_tokens = ( - -(-req.extend_input_len // self.tree_cache.page_size) - * self.tree_cache.page_size - ) prefix_len = len(req.prefix_indices) + input_tokens = -(-req.extend_input_len // self.page_size) * self.page_size + + if input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0: + return AddReqResult.OTHER + if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens: # Non-chunked prefill self.can_run_list.append(req) @@ -496,7 +490,7 @@ class PrefillAdder: ) else: # Make sure at least one page is available - trunc_len = self.rem_chunk_tokens - self.tree_cache.page_size + 1 + trunc_len = self.rem_chunk_tokens - self.page_size + 1 if trunc_len <= 0: return AddReqResult.OTHER diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index eae5a61fc..b2484f090 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1467,15 +1467,14 @@ class Scheduler( return None if self.enable_hierarchical_cache: - # check for completion of hierarchical cache activities to release memory - self.tree_cache.writing_check() - self.tree_cache.loading_check() + self.tree_cache.check_hicache_events() # Get priority queue - prefix_computed = self.policy.calc_priority(self.waiting_queue) + self.policy.calc_priority(self.waiting_queue) # Prefill policy adder = PrefillAdder( + self.page_size, self.tree_cache, self.token_to_kv_pool_allocator, self.running_batch, @@ -1517,19 +1516,8 @@ class Scheduler( self.running_batch.batch_is_full = True break - # bypass prefix_computed if enable_hierarchical_cache - req.init_next_round_input( - ( - None - if (prefix_computed and not self.enable_hierarchical_cache) - else self.tree_cache - ), - self.enable_hierarchical_cache, - ) - - res = adder.add_one_req( - req, self.chunked_req, self.enable_hierarchical_cache - ) + req.init_next_round_input(self.tree_cache) + res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None)) if res != AddReqResult.CONTINUE: if res == AddReqResult.NO_TOKEN: @@ -1581,7 +1569,9 @@ class Scheduler( ) if self.enable_hierarchical_cache: # todo (zhiqiang): disable cuda graph execution if hicache loading triggered - new_batch.hicache_consumer_index = self.tree_cache.ready_to_load_cache() + new_batch.hicache_consumer_index = ( + self.tree_cache.ready_to_load_host_cache() + ) new_batch.prepare_for_extend() diff --git a/python/sglang/srt/mem_cache/base_prefix_cache.py b/python/sglang/srt/mem_cache/base_prefix_cache.py index 2035bbdbf..1129226c3 100644 --- a/python/sglang/srt/mem_cache/base_prefix_cache.py +++ b/python/sglang/srt/mem_cache/base_prefix_cache.py @@ -1,5 +1,31 @@ from abc import ABC, abstractmethod -from typing import Any, List, Tuple +from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple + +import torch + +if TYPE_CHECKING: + from sglang.srt.managers.schedule_batch import Req +else: + Req = Any # Placeholder for Req type when not type checking + + +class MatchResult(NamedTuple): + """Result of a prefix match operation. + + Attributes: + device_indices : Indices of the KV cache on the device matched by common prefix. + last_device_node: The last TreeNode on the device that was matched. + last_host_node : The last TreeNode on the host that was matched. + Note that if HiCache is not enabled, + this **must** be the same as `last_device_node`. + host_hit_length : Length of the KV cache hit on the host, if applicable. + 0 if HiCache is not enabled. + """ + + device_indices: torch.Tensor + last_device_node: Any + last_host_node: Any + host_hit_length: int = 0 class BasePrefixCache(ABC): @@ -10,19 +36,15 @@ class BasePrefixCache(ABC): pass @abstractmethod - def match_prefix(self, **kwargs) -> Tuple[List[int], int]: + def match_prefix(self, key: List[int], **kwargs) -> MatchResult: pass @abstractmethod - def insert(self, **kwargs): + def cache_finished_req(self, req: Req, **kwargs): pass @abstractmethod - def cache_finished_req(self, **kwargs): - pass - - @abstractmethod - def cache_unfinished_req(self, **kwargs): + def cache_unfinished_req(self, req: Req, **kwargs): pass @abstractmethod @@ -49,5 +71,27 @@ class BasePrefixCache(ABC): def pretty_print(self): raise NotImplementedError() + def init_load_back( + self, + last_host_node: Any, + host_hit_length: int, + ) -> Tuple[torch.Tensor, Any]: + """ + Preparing KV cache loading from host to device. + """ + raise NotImplementedError() + + def ready_to_load_host_cache(self) -> Any: + """ + Notify the cache controller to start the KV cache loading + """ + raise NotImplementedError() + + def check_hicache_events(self) -> Any: + """ + Check HiCache related activities to update radix tree and synchronize across TP workers if needed + """ + raise NotImplementedError() + def take_events(self): return [] diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index 1fe23c4e3..80bdb9690 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -6,19 +6,13 @@ from typing import TYPE_CHECKING, Any, Callable, List, Tuple import torch -from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache +from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req -class ChunkCacheEntry: - def __init__(self, rid: str, value: torch.Tensor): - self.rid = rid - self.value = value - - class ChunkCache(BasePrefixCache): def __init__( self, @@ -29,13 +23,16 @@ class ChunkCache(BasePrefixCache): self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator self.page_size = page_size - self.disable = True def reset(self): pass - def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]: - return [], None + def match_prefix(self, **unused_kwargs) -> MatchResult: + return MatchResult( + device_indices=torch.empty((0,), dtype=torch.int64), + last_device_node=None, + last_host_node=None, + ) def cache_finished_req(self, req: Req): kv_indices = self.req_to_token_pool.req_to_token[ @@ -54,9 +51,6 @@ class ChunkCache(BasePrefixCache): # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later req.prefix_indices = kv_indices - def insert(self): - raise NotImplementedError() - def evict(self, num_tokens: int): pass diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 4d6c0ae11..cf7357bc0 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -7,6 +7,7 @@ from typing import List, Optional import torch from sglang.srt.managers.cache_controller import HiCacheController +from sglang.srt.mem_cache.base_prefix_cache import MatchResult from sglang.srt.mem_cache.memory_pool import ( MHATokenToKVPool, MLATokenToKVPool, @@ -283,41 +284,44 @@ class HiRadixCache(RadixCache): def init_load_back( self, last_node: TreeNode, - prefix_indices: torch.Tensor, + host_hit_length: int, mem_quota: Optional[int] = None, ): - assert ( - len(prefix_indices) == 0 or prefix_indices.is_cuda - ), "indices of device kV caches should be on GPU" + _ = host_hit_length # unused, but kept for compatibility if last_node.evicted: loading_values = self.load_back(last_node, mem_quota) if loading_values is not None: - prefix_indices = ( - loading_values - if len(prefix_indices) == 0 - else torch.cat([prefix_indices, loading_values]) - ) logger.debug( f"loading back {len(loading_values)} tokens for node {last_node.id}" ) + return loading_values, last_node while last_node.evicted: last_node = last_node.parent - return last_node, prefix_indices + return ( + torch.empty((0,), dtype=torch.int64, device=self.device), + last_node, + ) - def ready_to_load_cache(self): + def ready_to_load_host_cache(self): producer_index = self.cache_controller.layer_done_counter.next_producer() self.load_cache_event.set() return producer_index - def match_prefix(self, key: List[int], include_evicted=False, **kwargs): + def check_hicache_events(self): + self.writing_check() + self.loading_check() + + def match_prefix(self, key: List[int], **kwargs): empty_value = torch.empty((0,), dtype=torch.int64, device=self.device) if self.disable or len(key) == 0: - if include_evicted: - return empty_value, self.root_node, self.root_node - else: - return empty_value, self.root_node + return MatchResult( + device_indices=empty_value, + last_device_node=self.root_node, + last_host_node=self.root_node, + host_hit_length=0, + ) if self.page_size != 1: page_aligned_len = len(key) // self.page_size * self.page_size @@ -329,14 +333,18 @@ class HiRadixCache(RadixCache): else: value = empty_value - last_node_global = last_node + host_hit_length = 0 + last_host_node = last_node while last_node.evicted: + host_hit_length += len(last_node.host_value) last_node = last_node.parent - if include_evicted: - return value, last_node, last_node_global - else: - return value, last_node + return MatchResult( + device_indices=value, + last_device_node=last_node, + last_host_node=last_host_node, + host_hit_length=host_hit_length, + ) def _match_prefix_helper(self, node: TreeNode, key: List): node.last_access_time = time.monotonic() diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 377784302..256595b7a 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -33,8 +33,7 @@ from sglang.srt.disaggregation.kv_events import ( BlockStored, KVCacheEvent, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache +from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator if TYPE_CHECKING: @@ -47,9 +46,9 @@ class TreeNode: def __init__(self, id: Optional[int] = None): self.children = defaultdict(TreeNode) - self.parent = None - self.key = None - self.value = None + self.parent: TreeNode = None + self.key: List[int] = None + self.value: Optional[torch.Tensor] = None self.lock_ref = 0 self.last_access_time = time.monotonic() @@ -57,7 +56,7 @@ class TreeNode: # indicating the node is loading KV cache from host self.loading = False # store the host indices of KV cache - self.host_value = None + self.host_value: Optional[torch.Tensor] = None self.id = TreeNode.counter if id is None else id TreeNode.counter += 1 @@ -135,7 +134,7 @@ class RadixCache(BasePrefixCache): self.protected_size_ = 0 self._record_all_cleared_event() - def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]: + def match_prefix(self, key: List[int], **kwargs) -> MatchResult: """Find the matching prefix from the radix tree. Args: key: A list of token IDs to find a matching prefix. @@ -147,13 +146,14 @@ class RadixCache(BasePrefixCache): than the last node's value. """ if self.disable or len(key) == 0: - return ( - torch.empty( + return MatchResult( + device_indices=torch.empty( (0,), dtype=torch.int64, device=self.device, ), - self.root_node, + last_device_node=self.root_node, + last_host_node=self.root_node, ) if self.page_size != 1: @@ -165,7 +165,11 @@ class RadixCache(BasePrefixCache): value = torch.cat(value) else: value = torch.empty((0,), dtype=torch.int64, device=self.device) - return value, last_node + return MatchResult( + device_indices=value, + last_device_node=last_node, + last_host_node=last_node, + ) def insert(self, key: List, value=None): if self.disable: @@ -235,7 +239,7 @@ class RadixCache(BasePrefixCache): ) # The prefix indices could be updated, reuse it - new_indices, new_last_node = self.match_prefix(page_aligned_token_ids) + new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids) self.req_to_token_pool.write( (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))), new_indices[len(req.prefix_indices) :],