feat(hicache): Supports 3fs-hicache compatibility with dp-attention (#9372)
This commit is contained in:
@@ -59,7 +59,7 @@ def test():
|
|||||||
raise RuntimeError(f"Failed to dump config to {config_path}: {str(e)}")
|
raise RuntimeError(f"Failed to dump config to {config_path}: {str(e)}")
|
||||||
|
|
||||||
rank = 0
|
rank = 0
|
||||||
hicache_hf3fs = HiCacheHF3FS.from_env_config(rank, bytes_per_page, dtype)
|
hicache_hf3fs = HiCacheHF3FS.from_env_config(bytes_per_page, dtype, rank)
|
||||||
|
|
||||||
numel = 2 * tokens_per_page * layer_num * head_num * head_dim
|
numel = 2 * tokens_per_page * layer_num * head_num * head_dim
|
||||||
assert numel * dtype.itemsize == bytes_per_page
|
assert numel * dtype.itemsize == bytes_per_page
|
||||||
|
|||||||
@@ -269,7 +269,6 @@ class HiCacheController:
|
|||||||
HiCacheHF3FS,
|
HiCacheHF3FS,
|
||||||
)
|
)
|
||||||
|
|
||||||
rank = get_tensor_model_parallel_rank()
|
|
||||||
if self.mem_pool_host.layout == "page_first":
|
if self.mem_pool_host.layout == "page_first":
|
||||||
bytes_per_page = (
|
bytes_per_page = (
|
||||||
mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
|
mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
|
||||||
@@ -280,7 +279,7 @@ class HiCacheController:
|
|||||||
)
|
)
|
||||||
dtype = mem_pool_host.dtype
|
dtype = mem_pool_host.dtype
|
||||||
self.storage_backend = HiCacheHF3FS.from_env_config(
|
self.storage_backend = HiCacheHF3FS.from_env_config(
|
||||||
rank, bytes_per_page, dtype
|
bytes_per_page, dtype
|
||||||
)
|
)
|
||||||
self.get_hash_str = get_hash_str
|
self.get_hash_str = get_hash_str
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -13,6 +13,11 @@ from sglang.srt.distributed import (
|
|||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.dp_attention import (
|
||||||
|
get_attention_tp_rank,
|
||||||
|
get_attention_tp_size,
|
||||||
|
is_dp_attention_enabled,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
|
def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
|
||||||
@@ -103,8 +108,13 @@ class HiCacheFile(HiCacheStorage):
|
|||||||
|
|
||||||
def __init__(self, file_path: str = "/tmp/hicache", is_mla: bool = False):
|
def __init__(self, file_path: str = "/tmp/hicache", is_mla: bool = False):
|
||||||
self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
|
self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
if is_dp_attention_enabled():
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_rank = get_attention_tp_rank()
|
||||||
|
tp_size = get_attention_tp_size()
|
||||||
|
else:
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
|
self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
|
||||||
if not os.path.exists(self.file_path) and tp_rank == 0:
|
if not os.path.exists(self.file_path) and tp_rank == 0:
|
||||||
os.makedirs(self.file_path)
|
os.makedirs(self.file_path)
|
||||||
|
|||||||
@@ -11,6 +11,11 @@ from typing import Any, List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||||
|
from sglang.srt.layers.dp_attention import (
|
||||||
|
get_attention_tp_rank,
|
||||||
|
is_dp_attention_enabled,
|
||||||
|
)
|
||||||
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
|
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
|
||||||
from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
|
from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
|
||||||
|
|
||||||
@@ -167,13 +172,20 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_env_config(
|
def from_env_config(
|
||||||
rank: int, bytes_per_page: int, dtype: torch.dtype
|
bytes_per_page: int, dtype: torch.dtype, rank: int = None
|
||||||
) -> "HiCacheHF3FS":
|
) -> "HiCacheHF3FS":
|
||||||
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
|
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
|
||||||
Hf3fsGlobalMetadataClient,
|
Hf3fsGlobalMetadataClient,
|
||||||
Hf3fsLocalMetadataClient,
|
Hf3fsLocalMetadataClient,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if rank is None:
|
||||||
|
rank = (
|
||||||
|
get_attention_tp_rank()
|
||||||
|
if is_dp_attention_enabled()
|
||||||
|
else get_tensor_model_parallel_rank()
|
||||||
|
)
|
||||||
|
|
||||||
config_path = os.getenv(HiCacheHF3FS.default_env_var)
|
config_path = os.getenv(HiCacheHF3FS.default_env_var)
|
||||||
if not config_path:
|
if not config_path:
|
||||||
return HiCacheHF3FS(
|
return HiCacheHF3FS(
|
||||||
|
|||||||
Reference in New Issue
Block a user