From 3d40794fcf3678a713c0054ae9d59dafab979bcf Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Wed, 24 Sep 2025 23:43:53 -0700 Subject: [PATCH] [HiCache] Cleaning the deprecated host memory state (#10778) --- .../sglang/srt/managers/cache_controller.py | 27 +---- python/sglang/srt/mem_cache/hiradix_cache.py | 6 +- .../sglang/srt/mem_cache/memory_pool_host.py | 112 ++---------------- 3 files changed, 15 insertions(+), 130 deletions(-) diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 6c96c80a3..041117753 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -462,7 +462,6 @@ class HiCacheController: host_indices = self.mem_pool_host.alloc(len(device_indices)) if host_indices is None: return None - self.mem_pool_host.protect_write(host_indices) self.write_queue.append( CacheOperation(host_indices, device_indices, node_id, priority) ) @@ -486,7 +485,6 @@ class HiCacheController: self.mem_pool_host.backup_from_device_all_layer( self.mem_pool_device, host_indices, device_indices, self.io_backend ) - self.mem_pool_host.complete_io(op.host_indices) finish_event.record() # NOTE: We must save the host indices and device indices here, # this is because we need to guarantee that these tensors are @@ -510,7 +508,6 @@ class HiCacheController: device_indices = self.mem_pool_device_allocator.alloc(len(host_indices)) if device_indices is None: return None - self.mem_pool_host.protect_load(host_indices) self.load_queue.append( CacheOperation(host_indices, device_indices, node_id, priority) ) @@ -555,7 +552,6 @@ class HiCacheController: self.io_backend, ) producer_event.complete(i) - self.mem_pool_host.complete_io(op.host_indices) # NOTE: We must save the host indices and device indices here, # this is because we need to guarantee that these tensors are # still alive when the load stream is executing. @@ -573,29 +569,16 @@ class HiCacheController: ) return producer_id - def evict_device( - self, device_indices: torch.Tensor, host_indices: torch.Tensor - ) -> int: - if self.mem_pool_host.is_synced(host_indices): - self.mem_pool_device_allocator.free(device_indices) - self.mem_pool_host.update_backup(host_indices) - return len(device_indices) - else: - raise ValueError( - f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}" - ) + def evict_device(self, device_indices: torch.Tensor) -> int: + self.mem_pool_device_allocator.free(device_indices) + return len(device_indices) def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int: if not backup_only: raise ValueError("Other eviction policies are not supported yet.") - if self.mem_pool_host.is_backup(host_indices): - self.mem_pool_host.free(host_indices) - return len(host_indices) - else: - raise ValueError( - f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}" - ) + self.mem_pool_host.free(host_indices) + return len(host_indices) def prefetch( self, diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 9dfe9aca0..f2ed1aea9 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -305,7 +305,7 @@ class HiRadixCache(RadixCache): def _evict_backuped(self, node: TreeNode): # evict a node already written to host - num_evicted = self.cache_controller.evict_device(node.value, node.host_value) + num_evicted = self.cache_controller.evict_device(node.value) assert num_evicted > 0 self.evictable_size_ -= num_evicted node.value = None @@ -576,8 +576,6 @@ class HiRadixCache(RadixCache): written_indices, hash_value[: min_completed_tokens // self.page_size], ) - if len(written_indices): - self.cache_controller.mem_pool_host.update_prefetch(written_indices) self.cache_controller.mem_pool_host.free(host_indices[:matched_length]) self.cache_controller.append_host_mem_release( @@ -775,7 +773,6 @@ class HiRadixCache(RadixCache): # change the reference if the node is evicted # this often happens in the case of KV cache recomputation node.value = value[:prefix_len] - self.token_to_kv_pool_host.update_synced(node.host_value) self.evictable_size_ += len(node.value) else: self._inc_hit_count(node, chunked) @@ -785,7 +782,6 @@ class HiRadixCache(RadixCache): new_node = self._split_node(node.key, node, prefix_len) if new_node.evicted: new_node.value = value[:prefix_len] - self.token_to_kv_pool_host.update_synced(new_node.host_value) self.evictable_size_ += len(new_node.value) else: self._inc_hit_count(new_node, chunked) diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index ab7538465..f6d655af0 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -31,27 +31,13 @@ if not (_is_npu or _is_xpu): logger = logging.getLogger(__name__) -class MemoryStateInt(IntEnum): - IDLE = 0 - RESERVED = 1 - PROTECTED = 2 - SYNCED = 3 - BACKUP = 4 +def synchronized(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + with self.lock: + return func(self, *args, **kwargs) - -def synchronized(debug_only=False): - def _decorator(func): - @wraps(func) - def wrapper(self, *args, **kwargs): - if (not debug_only) or self.debug: - with self.lock: - return func(self, *args, **kwargs) - else: - return True - - return wrapper - - return _decorator + return wrapper class HostKVCache(abc.ABC): @@ -110,7 +96,6 @@ class HostKVCache(abc.ABC): # A lock for synchronized operations on memory allocation and state transitions. self.lock = threading.RLock() - self.debug = logger.isEnabledFor(logging.DEBUG) self.clear() @abc.abstractmethod @@ -161,7 +146,7 @@ class HostKVCache(abc.ABC): """ raise NotImplementedError() - @synchronized() + @synchronized def clear(self): # Initialize memory states and tracking structures. self.mem_state = torch.zeros( @@ -172,7 +157,7 @@ class HostKVCache(abc.ABC): def available_size(self): return len(self.free_slots) - @synchronized() + @synchronized def alloc(self, need_size: int) -> Optional[torch.Tensor]: assert ( need_size % self.page_size == 0 @@ -183,92 +168,13 @@ class HostKVCache(abc.ABC): select_index = self.free_slots[:need_size] self.free_slots = self.free_slots[need_size:] - if self.debug: - self.mem_state[select_index] = MemoryStateInt.RESERVED - return select_index - @synchronized() + @synchronized def free(self, indices: torch.Tensor) -> int: self.free_slots = torch.cat([self.free_slots, indices]) - if self.debug: - self.mem_state[indices] = MemoryStateInt.IDLE return len(indices) - @synchronized(debug_only=True) - def get_state(self, indices: torch.Tensor) -> MemoryStateInt: - assert len(indices) > 0, "The indices should not be empty" - states = self.mem_state[indices] - assert ( - states == states[0] - ).all(), "The memory slots should have the same state {}".format(states) - return MemoryStateInt(states[0].item()) - - @synchronized(debug_only=True) - def is_reserved(self, indices: torch.Tensor) -> bool: - return self.get_state(indices) == MemoryStateInt.RESERVED - - @synchronized(debug_only=True) - def is_protected(self, indices: torch.Tensor) -> bool: - return self.get_state(indices) == MemoryStateInt.PROTECTED - - @synchronized(debug_only=True) - def is_synced(self, indices: torch.Tensor) -> bool: - return self.get_state(indices) == MemoryStateInt.SYNCED - - @synchronized(debug_only=True) - def is_backup(self, indices: torch.Tensor) -> bool: - return self.get_state(indices) == MemoryStateInt.BACKUP - - @synchronized(debug_only=True) - def update_backup(self, indices: torch.Tensor): - if not self.is_synced(indices): - raise ValueError( - f"The host memory slots should be in SYNCED state before turning into BACKUP. " - f"Current state: {self.get_state(indices)}" - ) - self.mem_state[indices] = MemoryStateInt.BACKUP - - @synchronized(debug_only=True) - def update_prefetch(self, indices: torch.Tensor): - if not self.is_reserved(indices): - raise ValueError( - f"The host memory slots should be in RESERVED state before turning into BACKUP. " - f"Current state: {self.get_state(indices)}" - ) - self.mem_state[indices] = MemoryStateInt.BACKUP - - @synchronized(debug_only=True) - def update_synced(self, indices: torch.Tensor): - self.mem_state[indices] = MemoryStateInt.SYNCED - - @synchronized(debug_only=True) - def protect_write(self, indices: torch.Tensor): - if not self.is_reserved(indices): - raise ValueError( - f"The host memory slots should be RESERVED before write operations. " - f"Current state: {self.get_state(indices)}" - ) - self.mem_state[indices] = MemoryStateInt.PROTECTED - - @synchronized(debug_only=True) - def protect_load(self, indices: torch.Tensor): - if not self.is_backup(indices): - raise ValueError( - f"The host memory slots should be in BACKUP state before load operations. " - f"Current state: {self.get_state(indices)}" - ) - self.mem_state[indices] = MemoryStateInt.PROTECTED - - @synchronized(debug_only=True) - def complete_io(self, indices: torch.Tensor): - if not self.is_protected(indices): - raise ValueError( - f"The host memory slots should be PROTECTED during I/O operations. " - f"Current state: {self.get_state(indices)}" - ) - self.mem_state[indices] = MemoryStateInt.SYNCED - class MHATokenToKVPoolHost(HostKVCache): device_pool: MHATokenToKVPool