diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py
index a94fdec78..9ef860f63 100644
--- a/python/sglang/srt/managers/cache_controller.py
+++ b/python/sglang/srt/managers/cache_controller.py
@@ -219,6 +219,7 @@ class HiCacheController:
         token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
         page_size: int,
+        tp_group: torch.distributed.ProcessGroup,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
@@ -244,11 +245,17 @@ class HiCacheController:
         self.enable_storage = False
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
+            # create a new communication group for synchronizing storage operations across TP workers
+            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
+            if self.tp_world_size > 1:
+                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
+                self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")
+
             if storage_backend == "file":
                 self.storage_backend = HiCacheFile()
                 self.enable_storage = True
                 # todo: threshold policy for prefetching
-                self.prefetch_threshold = prefetch_threshold
+                self.prefetch_threshold = max(prefetch_threshold, self.page_size)
             else:
                 raise NotImplementedError(
                     f"Unsupported storage backend: {storage_backend}"
                 )
@@ -568,13 +575,32 @@ class HiCacheController:
                     else:
                         break
 
+                if self.tp_world_size > 1:
+                    storage_hit_count_tensor = torch.tensor(
+                        storage_hit_count, dtype=torch.int
+                    )
+                    torch.distributed.all_reduce(
+                        storage_hit_count_tensor,
+                        op=torch.distributed.ReduceOp.MIN,
+                        group=self.tp_group,
+                    )
+                    storage_hit_count = storage_hit_count_tensor.item()
+
                 if storage_hit_count < self.prefetch_threshold:
                     # not to prefetch if not enough benefits
                     self.prefetch_revoke_queue.put(operation.request_id)
-                else:
-                    operation.hash_value = hash_value
                     logger.debug(
-                        f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
+                        f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
+                    )
+                else:
+                    operation.hash_value = hash_value[
+                        : (storage_hit_count // self.page_size)
+                    ]
+                    # free the pre-allocated memory for pages that are not hit
+                    self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
+                    operation.host_indices = operation.host_indices[:storage_hit_count]
+                    logger.debug(
+                        f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
                     )
 
                 self.prefetch_buffer.put(operation)
@@ -611,17 +637,37 @@ class HiCacheController:
                     last_hash = get_hash_str(
                         tokens_to_backup[i : i + self.page_size], last_hash
                     )
-                    # todo, handle failures in storage backend
-                    self.storage_backend.set(
+                    success = self.storage_backend.set(
                         last_hash,
                         self.mem_pool_host.get_flat_data_page(
                             operation.host_indices[i]
                         ),
                     )
+                    if not success:
+                        logger.warning(f"Failed to write page {last_hash} to storage.")
+                        break
                     operation.completed_tokens += self.page_size
                     operation.hash_value.append(last_hash)
 
-                self.ack_backup_queue.put((operation.id, operation.hash_value))
+                min_completed_tokens = operation.completed_tokens
+                if self.tp_world_size > 1:
+                    completed_tokens_tensor = torch.tensor(
+                        min_completed_tokens, dtype=torch.int
+                    )
+                    torch.distributed.all_reduce(
+                        completed_tokens_tensor,
+                        op=torch.distributed.ReduceOp.MIN,
+                        group=self.tp_group,
+                    )
+                    min_completed_tokens = completed_tokens_tensor.item()
+
+                self.ack_backup_queue.put(
+                    (
+                        operation.id,
+                        operation.hash_value[: min_completed_tokens // self.page_size],
+                        min_completed_tokens,
+                    )
+                )
 
             except Empty:
                 continue
diff --git a/python/sglang/srt/mem_cache/hicache_storage.py b/python/sglang/srt/mem_cache/hicache_storage.py
index 1dfe661ab..45b26d100 100644
--- a/python/sglang/srt/mem_cache/hicache_storage.py
+++ b/python/sglang/srt/mem_cache/hicache_storage.py
@@ -9,6 +9,12 @@ import torch
 
 logger = logging.getLogger(__name__)
 
 
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+
+
 def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
     hasher = hashlib.sha256()
@@ -80,13 +86,20 @@ class HiCacheFile(HiCacheStorage):
 
     def __init__(self, file_path: str = "/tmp/hicache"):
         self.file_path = file_path
-        if not os.path.exists(self.file_path):
+        tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
+        if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
             logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
 
+    def _get_suffixed_key(self, key: str) -> str:
+        return key + self.tp_suffix
+
     def get(
         self, key: str, target_location: Optional[torch.Tensor] = None
     ) -> torch.Tensor | None:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
             # todo: fixing the target_location logic to enable in-place loading
@@ -112,6 +125,7 @@ class HiCacheFile(HiCacheStorage):
         ]
 
     def set(self, key: str, value: torch.Tensor) -> bool:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         if self.exists(key):
             logger.debug(f"Key {key} already exists. Skipped.")
@@ -130,10 +144,12 @@ class HiCacheFile(HiCacheStorage):
         return True
 
     def exists(self, key: str) -> bool:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         return os.path.exists(tensor_path)
 
     def delete(self, key: str) -> None:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
             os.remove(tensor_path)
diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py
index 796f0553c..05248a1de 100644
--- a/python/sglang/srt/mem_cache/hiradix_cache.py
+++ b/python/sglang/srt/mem_cache/hiradix_cache.py
@@ -50,6 +50,7 @@ class HiRadixCache(RadixCache):
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
 
         self.tp_group = tp_cache_group
+        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
         self.enable_storage = hicache_storage_backend is not None
         # todo: customizable storage prefetch threshold
         self.prefetch_threshold = 256
@@ -59,6 +60,7 @@ class HiRadixCache(RadixCache):
             token_to_kv_pool_allocator,
             self.token_to_kv_pool_host,
             page_size,
+            self.tp_group,
             load_cache_event=self.load_cache_event,
             write_policy=hicache_write_policy,
             io_backend=hicache_io_backend,
@@ -153,7 +155,7 @@ class HiRadixCache(RadixCache):
         queue_size = torch.tensor(
             self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
         )
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        if self.tp_world_size > 1:
             # synchrnoize TP workers to make the same update to radix cache
             torch.distributed.all_reduce(
                 queue_size,
@@ -353,7 +355,7 @@ class HiRadixCache(RadixCache):
         queue_size = torch.tensor(
             self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
         )
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        if self.tp_world_size > 1:
             # synchrnoize TP workers to make the same update to hiradix cache
             torch.distributed.all_reduce(
                 queue_size,
@@ -372,7 +374,7 @@ class HiRadixCache(RadixCache):
         queue_size = torch.tensor(
             self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
         )
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        if self.tp_world_size > 1:
             # synchrnoize TP workers to make the same update to hiradix cache
             torch.distributed.all_reduce(
                 queue_size,
@@ -380,9 +382,15 @@ class HiRadixCache(RadixCache):
                 op=torch.distributed.ReduceOp.MIN,
                 group=self.tp_group,
             )
         for _ in range(queue_size.item()):
-            ack_id, hash_value = self.cache_controller.ack_backup_queue.get()
-            self.ongoing_backup[ack_id].hash_value = hash_value
-            self.ongoing_backup[ack_id].release_host()
+            ack_id, hash_value, completed_tokens = (
+                self.cache_controller.ack_backup_queue.get()
+            )
+            host_node = self.ongoing_backup[ack_id]
+            if completed_tokens < len(host_node.key):
+                # backup is only partially successful, split the node
+                new_node = self._split_node(host_node.key, host_node, completed_tokens)
+                new_node.hash_value = hash_value
+            host_node.release_host()
             del self.ongoing_backup[ack_id]
 
     def check_prefetch_progress(self, req_id: str):
@@ -400,15 +408,18 @@ class HiRadixCache(RadixCache):
         )
         logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
-        min_completed_tokens = torch.tensor(completed_tokens, dtype=torch.int)
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        min_completed_tokens = completed_tokens
+        if self.tp_world_size > 1:
             # synchrnoize TP workers to make the same update to hiradix cache
+            completed_tokens_tensor = torch.tensor(
+                min_completed_tokens, dtype=torch.int
+            )
             torch.distributed.all_reduce(
-                min_completed_tokens,
+                completed_tokens_tensor,
                 op=torch.distributed.ReduceOp.MIN,
                 group=self.tp_group,
             )
-            min_completed_tokens = min_completed_tokens.item()
+            min_completed_tokens = completed_tokens_tensor.item()
 
         fetched_token_ids = token_ids[:min_completed_tokens]
         written_indices = host_indices[:min_completed_tokens]
         matched_length = self._insert_helper_host(
@@ -465,16 +476,19 @@ class HiRadixCache(RadixCache):
         new_input_tokens: List[int],
         last_hash: Optional[str] = None,
     ):
-        if not self.enable_storage or len(new_input_tokens) < self.prefetch_threshold:
+        # align the number of fetching tokens to the page size
+        prefetch_length = len(new_input_tokens) - (
+            len(new_input_tokens) % self.page_size
+        )
+        new_input_tokens = new_input_tokens[:prefetch_length]
+        if not self.enable_storage or prefetch_length < self.prefetch_threshold:
             return
 
         last_host_node.protect_host()
-        host_indices = self.cache_controller.mem_pool_host.alloc(len(new_input_tokens))
+        host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
         if host_indices is None:
-            self.evict_host(len(new_input_tokens))
-            host_indices = self.cache_controller.mem_pool_host.alloc(
-                len(new_input_tokens)
-            )
+            self.evict_host(prefetch_length)
+            host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
         if host_indices is None:
             last_host_node.release_host()
             # no sufficient host memory to prefetch
diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py
index f50347962..0116e7141 100644
--- a/python/sglang/srt/mem_cache/memory_pool_host.py
+++ b/python/sglang/srt/mem_cache/memory_pool_host.py
@@ -126,6 +126,9 @@ class HostKVCache(abc.ABC):
 
     @synchronized()
     def alloc(self, need_size: int) -> torch.Tensor:
+        assert (
+            need_size % self.page_size == 0
+        ), "The requested size should be a multiple of the page size."
        if need_size > self.available_size():
             return None
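
Reviewer note: the recurring pattern in this diff is that each TP rank computes a local count (storage hit tokens during prefetch, completed tokens during backup) and the ranks then agree on the minimum via a MIN all-reduce over a dedicated gloo group, so every worker commits exactly the same number of pages. The standalone sketch below illustrates only that pattern, outside of SGLang; the port, world size, and per-rank counts are made up for the example and are not taken from this PR.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size, local_counts):
    # CPU-only gloo group, mirroring the dedicated group the controller
    # creates for storage-related synchronization.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29501",  # arbitrary free port for the example
        rank=rank,
        world_size=world_size,
    )

    # Each rank starts from its own local count, e.g. storage hit tokens.
    count = torch.tensor(local_counts[rank], dtype=torch.int)

    # MIN all-reduce: every rank ends up holding the smallest local count,
    # so all TP workers truncate to the same number of tokens/pages.
    dist.all_reduce(count, op=dist.ReduceOp.MIN)
    print(f"rank {rank}: local={local_counts[rank]}, agreed={count.item()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    local_counts = [384, 256]  # hypothetical per-rank hit counts
    mp.spawn(worker, args=(world_size, local_counts), nprocs=world_size)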