feat(hicache): Supports 3fs-hicache compatibility with dp-attention (#9372)
This commit is contained in:
@@ -269,7 +269,6 @@ class HiCacheController:
|
||||
HiCacheHF3FS,
|
||||
)
|
||||
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
if self.mem_pool_host.layout == "page_first":
|
||||
bytes_per_page = (
|
||||
mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
|
||||
@@ -280,7 +279,7 @@ class HiCacheController:
|
||||
)
|
||||
dtype = mem_pool_host.dtype
|
||||
self.storage_backend = HiCacheHF3FS.from_env_config(
|
||||
rank, bytes_per_page, dtype
|
||||
bytes_per_page, dtype
|
||||
)
|
||||
self.get_hash_str = get_hash_str
|
||||
else:
|
||||
|
||||
@@ -13,6 +13,11 @@ from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
is_dp_attention_enabled,
|
||||
)
|
||||
|
||||
|
||||
def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
|
||||
@@ -103,8 +108,13 @@ class HiCacheFile(HiCacheStorage):
|
||||
|
||||
def __init__(self, file_path: str = "/tmp/hicache", is_mla: bool = False):
|
||||
self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if is_dp_attention_enabled():
|
||||
tp_rank = get_attention_tp_rank()
|
||||
tp_size = get_attention_tp_size()
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
|
||||
if not os.path.exists(self.file_path) and tp_rank == 0:
|
||||
os.makedirs(self.file_path)
|
||||
|
||||
@@ -11,6 +11,11 @@ from typing import Any, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_rank,
|
||||
is_dp_attention_enabled,
|
||||
)
|
||||
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
|
||||
from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
|
||||
|
||||
@@ -167,13 +172,20 @@ class HiCacheHF3FS(HiCacheStorage):
|
||||
|
||||
@staticmethod
|
||||
def from_env_config(
|
||||
rank: int, bytes_per_page: int, dtype: torch.dtype
|
||||
bytes_per_page: int, dtype: torch.dtype, rank: int = None
|
||||
) -> "HiCacheHF3FS":
|
||||
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
|
||||
Hf3fsGlobalMetadataClient,
|
||||
Hf3fsLocalMetadataClient,
|
||||
)
|
||||
|
||||
if rank is None:
|
||||
rank = (
|
||||
get_attention_tp_rank()
|
||||
if is_dp_attention_enabled()
|
||||
else get_tensor_model_parallel_rank()
|
||||
)
|
||||
|
||||
config_path = os.getenv(HiCacheHF3FS.default_env_var)
|
||||
if not config_path:
|
||||
return HiCacheHF3FS(
|
||||
|
||||
Reference in New Issue
Block a user