From 83871aa12df1fcd478ab341c76b1cf043facab1d Mon Sep 17 00:00:00 2001 From: hzh0425 Date: Sat, 23 Aug 2025 17:08:32 +0800 Subject: [PATCH] feat(hicache): Supports 3fs-hicache compatibility with dp-attention (#9372) --- benchmark/hf3fs/bench_storage.py | 2 +- python/sglang/srt/managers/cache_controller.py | 3 +-- python/sglang/srt/mem_cache/hicache_storage.py | 14 ++++++++++++-- .../srt/mem_cache/storage/hf3fs/storage_hf3fs.py | 14 +++++++++++++- 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/benchmark/hf3fs/bench_storage.py b/benchmark/hf3fs/bench_storage.py index 30702b635..c3f514e0e 100644 --- a/benchmark/hf3fs/bench_storage.py +++ b/benchmark/hf3fs/bench_storage.py @@ -59,7 +59,7 @@ def test(): raise RuntimeError(f"Failed to dump config to {config_path}: {str(e)}") rank = 0 - hicache_hf3fs = HiCacheHF3FS.from_env_config(rank, bytes_per_page, dtype) + hicache_hf3fs = HiCacheHF3FS.from_env_config(bytes_per_page, dtype, rank) numel = 2 * tokens_per_page * layer_num * head_num * head_dim assert numel * dtype.itemsize == bytes_per_page diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 8fa8ab00c..e031c3ada 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -269,7 +269,6 @@ class HiCacheController: HiCacheHF3FS, ) - rank = get_tensor_model_parallel_rank() if self.mem_pool_host.layout == "page_first": bytes_per_page = ( mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size @@ -280,7 +279,7 @@ class HiCacheController: ) dtype = mem_pool_host.dtype self.storage_backend = HiCacheHF3FS.from_env_config( - rank, bytes_per_page, dtype + bytes_per_page, dtype ) self.get_hash_str = get_hash_str else: diff --git a/python/sglang/srt/mem_cache/hicache_storage.py b/python/sglang/srt/mem_cache/hicache_storage.py index ed5908bd9..a391b8acc 100644 --- a/python/sglang/srt/mem_cache/hicache_storage.py +++ b/python/sglang/srt/mem_cache/hicache_storage.py @@ -13,6 +13,11 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) +from sglang.srt.layers.dp_attention import ( + get_attention_tp_rank, + get_attention_tp_size, + is_dp_attention_enabled, +) def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str: @@ -103,8 +108,13 @@ class HiCacheFile(HiCacheStorage): def __init__(self, file_path: str = "/tmp/hicache", is_mla: bool = False): self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path) - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() + if is_dp_attention_enabled(): + tp_rank = get_attention_tp_rank() + tp_size = get_attention_tp_size() + else: + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else "" if not os.path.exists(self.file_path) and tp_rank == 0: os.makedirs(self.file_path) diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py index b301ee0c8..f5d5a5344 100644 --- a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +++ b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py @@ -11,6 +11,11 @@ from typing import Any, List, Optional, Tuple import torch +from sglang.srt.distributed import get_tensor_model_parallel_rank +from sglang.srt.layers.dp_attention import ( + get_attention_tp_rank, + is_dp_attention_enabled, +) from sglang.srt.mem_cache.hicache_storage import HiCacheStorage from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient @@ -167,13 +172,20 @@ class HiCacheHF3FS(HiCacheStorage): @staticmethod def from_env_config( - rank: int, bytes_per_page: int, dtype: torch.dtype + bytes_per_page: int, dtype: torch.dtype, rank: int = None ) -> "HiCacheHF3FS": from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import ( Hf3fsGlobalMetadataClient, Hf3fsLocalMetadataClient, ) + if rank is None: + rank = ( + get_attention_tp_rank() + if is_dp_attention_enabled() + else get_tensor_model_parallel_rank() + ) + config_path = os.getenv(HiCacheHF3FS.default_env_var) if not config_path: return HiCacheHF3FS(