diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
index f2c5ec0fa..bf82dcd15 100644
--- a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
+++ b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
@@ -125,6 +125,7 @@ class HiCacheHF3FS(HiCacheStorage):
         entries: int,
         dtype: torch.dtype,
         metadata_client: Hf3fsMetadataInterface,
+        is_mla_model: bool = False,
     ):
         self.rank = rank
         self.file_path = file_path
@@ -134,9 +135,13 @@ class HiCacheHF3FS(HiCacheStorage):
         self.entries = entries
         self.dtype = dtype
         self.metadata_client = metadata_client
-
+        self.is_mla_model = is_mla_model
         self.numel = self.bytes_per_page // self.dtype.itemsize
         self.num_pages = self.file_size // self.bytes_per_page
+        self.skip_backup = False
+        if self.is_mla_model and self.rank != 0:
+            self.skip_backup = True
+            self.rank = 0
 
         logger.info(
             f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
@@ -209,10 +214,14 @@ class HiCacheHF3FS(HiCacheStorage):
             raise ValueError(f"Missing required keys in config: {missing_keys}")
 
         # Choose metadata client based on configuration
+        is_mla_model = False
        if "metadata_server_url" in config and config["metadata_server_url"]:
             # Use global metadata client to connect to metadata server
             metadata_server_url = config["metadata_server_url"]
             metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
+
+            # Enable MLA optimization only when using the global metadata client
+            is_mla_model = storage_config.is_mla_model if storage_config else False
             logger.info(
                 f"Using global metadata client with server url: {metadata_server_url}"
             )
@@ -222,13 +231,15 @@ class HiCacheHF3FS(HiCacheStorage):
 
         return HiCacheHF3FS(
             rank=rank,
-            file_path=f"{config['file_path_prefix']}.{rank}.bin",
+            # Let all ranks use the same file path for MLA model
+            file_path=f"{config['file_path_prefix']}.{rank if not is_mla_model else 0}.bin",
             file_size=int(config["file_size"]),
             numjobs=int(config["numjobs"]),
             bytes_per_page=bytes_per_page,
             entries=int(config["entries"]),
             dtype=dtype,
             metadata_client=metadata_client,
+            is_mla_model=is_mla_model,
         )
 
     def get(
@@ -312,6 +323,10 @@ class HiCacheHF3FS(HiCacheStorage):
         target_locations: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
     ) -> bool:
+        # In MLA backend, only one rank needs to backup the KV cache
+        if self.skip_backup:
+            return True
+
         # Todo: Add prefix block's hash key
         key_with_prefix = [(key, "") for key in keys]
         indices = self.metadata_client.reserve_and_allocate_page_indices(
@@ -363,16 +378,21 @@ class HiCacheHF3FS(HiCacheStorage):
 
         return all(results)
 
-    @synchronized()
     def delete(self, key: str) -> None:
         self.metadata_client.delete_keys(self.rank, [key])
 
-    @synchronized()
     def exists(self, key: str) -> bool:
         result = self.metadata_client.exists(self.rank, [key])
         return result[0] if result else False
 
-    @synchronized()
+    def batch_exists(self, keys: List[str]) -> int:
+        results = self.metadata_client.exists(self.rank, keys)
+        for i in range(len(keys)):
+            if not results[i]:
+                return i
+
+        return len(keys)
+
     def clear(self) -> None:
         self.metadata_client.clear(self.rank)
 
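
Not part of the patch: a minimal sketch of the rank-remapping rule the constructor change encodes, assuming (as the added comments state) that for MLA models every rank holds the same KV cache, so all ranks can share rank 0's file and metadata namespace and only rank 0 needs to write. `resolve_mla_rank` is a hypothetical helper written for illustration, not an sglang API.

```python
# Illustration only: mirrors the patch's constructor logic, where non-zero
# ranks of an MLA model are redirected to rank 0's file and skip backups.
from typing import Tuple


def resolve_mla_rank(rank: int, is_mla_model: bool) -> Tuple[int, bool]:
    """Return (effective_rank, skip_backup) for a HiCacheHF3FS-style client."""
    if is_mla_model and rank != 0:
        # MLA: reuse rank 0's namespace and let rank 0 own all writes.
        return 0, True
    # Non-MLA model (or rank 0): keep per-rank files and write normally.
    return rank, False


if __name__ == "__main__":
    for r in range(4):
        effective_rank, skip_backup = resolve_mla_rank(r, is_mla_model=True)
        print(f"rank={r} -> effective_rank={effective_rank}, skip_backup={skip_backup}")
    # rank 0 -> (0, False); ranks 1..3 -> (0, True)
```

In the patch itself, `skip_backup` is what makes `batch_set` return early on non-zero ranks, while reads still resolve against rank 0's metadata because `self.rank` has been rewritten to 0.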