refactor(hicache): Introduce generic HiCacheStorageConfig for improved configuration management (#9555)
Co-authored-by: Teng Ma <805522925@qq.com>
@@ -57,9 +57,7 @@ def test():
            )
    except Exception as e:
        raise RuntimeError(f"Failed to dump config to {config_path}: {str(e)}")

-    rank = 0
-    hicache_hf3fs = HiCacheHF3FS.from_env_config(bytes_per_page, dtype, rank)
+    hicache_hf3fs = HiCacheHF3FS.from_env_config(bytes_per_page, dtype)

    numel = 2 * tokens_per_page * layer_num * head_num * head_dim
    assert numel * dtype.itemsize == bytes_per_page
@@ -22,11 +22,21 @@ from typing import TYPE_CHECKING, List, Optional

 import torch

+from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
+
 if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache

-from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool

 logger = logging.getLogger(__name__)
@@ -231,6 +241,8 @@ class HiCacheController:
         io_backend: str = "",
         storage_backend: Optional[str] = None,
         prefetch_threshold: int = 256,
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -248,20 +260,22 @@ class HiCacheController:

         self.get_hash_str = get_hash_str

-        # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
-        is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
+        self.storage_config = self._generate_storage_config(
+            model_name, storage_backend_extra_config
+        )
         # In MLA backend, only one rank needs to backup the KV cache
         self.backup_skip = (
-            is_mla_backend
+            self.storage_config.is_mla_model
             # todo: for load balancing, decide which rank to backup the KV cache by hash value
-            and get_tensor_model_parallel_rank() != 0
+            and self.storage_config.tp_rank != 0
             # todo: support other storage backends
             and self.storage_backend_type in ["file", "mooncake"]
         )

         if storage_backend == "file":
             from sglang.srt.mem_cache.hicache_storage import HiCacheFile

-            self.storage_backend = HiCacheFile(is_mla_backend=is_mla_backend)
+            self.storage_backend = HiCacheFile(self.storage_config)
         elif storage_backend == "nixl":
             from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl

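The backup-skip decision above now reads everything from the shared config object instead of querying the parallel state directly. A minimal sketch of the same rule as a standalone function; `should_skip_backup` and the example values are illustrative only, not part of this commit:

    def should_skip_backup(is_mla_model: bool, tp_rank: int, backend_type: str) -> bool:
        # For MLA models every TP rank holds identical KV data, so only rank 0
        # needs to write it to storage; per the todo notes above, the rule is
        # currently limited to the "file" and "mooncake" backends.
        return is_mla_model and tp_rank != 0 and backend_type in ["file", "mooncake"]


    # Rank 2 of a 4-way TP group serving an MLA model skips the backup:
    assert should_skip_backup(True, 2, "mooncake") is True
    # Rank 0 still backs up:
    assert should_skip_backup(True, 0, "mooncake") is False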
@@ -271,7 +285,7 @@ class HiCacheController:
                 MooncakeStore,
             )

-            self.storage_backend = MooncakeStore(is_mla_backend=is_mla_backend)
+            self.storage_backend = MooncakeStore(self.storage_config)
             self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
             assert self.mem_pool_host.layout == "page_first"
         elif storage_backend == "hf3fs":
@@ -289,7 +303,7 @@ class HiCacheController:
             )
             dtype = mem_pool_host.dtype
             self.storage_backend = HiCacheHF3FS.from_env_config(
-                bytes_per_page, dtype
+                bytes_per_page, dtype, self.storage_config
             )
         else:
             raise NotImplementedError(
@@ -370,6 +384,40 @@ class HiCacheController:
         self.prefetch_thread.start()
         self.backup_thread.start()

+    def _generate_storage_config(
+        self,
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
+    ):
+
+        if is_dp_attention_enabled():
+            self.tp_rank = get_attention_tp_rank()
+            self.tp_size = get_attention_tp_size()
+        else:
+            self.tp_rank = get_tensor_model_parallel_rank()
+            self.tp_size = get_tensor_model_parallel_world_size()
+
+        # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
+        is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
+
+        # Parse extra config JSON if provided
+        extra_config = None
+        if storage_backend_extra_config:
+            try:
+                import json
+
+                extra_config = json.loads(storage_backend_extra_config)
+            except Exception as e:
+                logger.error(f"Invalid backend extra config JSON: {e}")
+
+        return HiCacheStorageConfig(
+            tp_rank=self.tp_rank,
+            tp_size=self.tp_size,
+            is_mla_model=is_mla_backend,
+            model_name=model_name,
+            extra_config=extra_config,
+        )
+
     def reset(self):
         self.stop_event.set()
         self.write_thread.join()
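The --hicache-storage-backend-extra-config value travels as a raw JSON string and only becomes a dict inside _generate_storage_config. A standalone sketch of that parsing path, assuming the import added in this commit; the "custom_option" key and the model name are placeholders, not real backend options:

    import json

    from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig

    raw = '{"custom_option": 42}'  # placeholder payload passed on the command line

    extra_config = None
    if raw:
        try:
            extra_config = json.loads(raw)
        except Exception as e:
            # The controller logs the error and continues with extra_config=None.
            print(f"Invalid backend extra config JSON: {e}")

    config = HiCacheStorageConfig(
        tp_rank=0,
        tp_size=1,
        is_mla_model=False,
        model_name="my-model",  # placeholder for server_args.served_model_name
        extra_config=extra_config,
    )
    print(config.extra_config)  # {'custom_option': 42}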
@@ -627,6 +627,8 @@ class Scheduler(
                 hicache_mem_layout=server_args.hicache_mem_layout,
                 hicache_storage_backend=server_args.hicache_storage_backend,
                 hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
+                model_name=server_args.served_model_name,
+                storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
@@ -2,6 +2,7 @@ import hashlib
 import logging
 import os
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import Any, List, Optional

 import torch
@@ -9,17 +10,6 @@ import torch
 logger = logging.getLogger(__name__)


-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    get_attention_tp_size,
-    is_dp_attention_enabled,
-)
-
-

 def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
     hasher = hashlib.sha256()
@@ -32,6 +22,15 @@ def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
     return hasher.hexdigest()


+@dataclass
+class HiCacheStorageConfig:
+    tp_rank: int
+    tp_size: int
+    is_mla_model: bool
+    model_name: Optional[str]
+    extra_config: Optional[dict] = None
+
+
 class HiCacheStorage(ABC):
     """
     HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
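With the fields centralized in one dataclass, each backend accepts the whole object and keeps local defaults for callers that pass nothing, which is the pattern the Mooncake and HF3FS changes below follow. A sketch of that convention with a hypothetical backend; ExampleBackend is illustrative and not part of this commit:

    from typing import Optional

    from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig


    class ExampleBackend:
        """Illustrative only: shows the storage_config handshake, not a real backend."""

        def __init__(self, storage_config: Optional[HiCacheStorageConfig] = None):
            if storage_config is not None:
                self.is_mla_backend = storage_config.is_mla_model
                self.local_rank = storage_config.tp_rank
                self.extra_config = storage_config.extra_config or {}
            else:
                # Defaults used when a caller constructs the backend directly.
                self.is_mla_backend = False
                self.local_rank = 0
                self.extra_config = {}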
@@ -117,18 +116,17 @@ class HiCacheStorage(ABC):


 class HiCacheFile(HiCacheStorage):

-    def __init__(self, file_path: str = "/tmp/hicache", is_mla_backend: bool = False):
+    def __init__(
+        self, storage_config: HiCacheStorageConfig, file_path: str = "/tmp/hicache"
+    ):
         self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
-        if is_dp_attention_enabled():
-            tp_rank = get_attention_tp_rank()
-            tp_size = get_attention_tp_size()
-        else:
-            tp_rank = get_tensor_model_parallel_rank()
-            tp_size = get_tensor_model_parallel_world_size()

-        self.tp_suffix = (
-            f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla_backend else ""
+        tp_rank, tp_size, is_mla = (
+            storage_config.tp_rank,
+            storage_config.tp_size,
+            storage_config.is_mla_model,
         )
+        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
         if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
             logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
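The per-rank file suffix is now derived from the config instead of the global parallel state. A small sketch of how the suffix shapes on-disk key names, using the same rule as the constructor above; the key value is a made-up hash prefix:

    def tp_suffix(tp_rank: int, tp_size: int, is_mla: bool) -> str:
        # MLA models share one KV copy across ranks, so their files carry no
        # suffix; otherwise each TP rank writes its own file per key.
        return f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""


    key = "3f2c"  # made-up hash prefix for illustration
    print(key + tp_suffix(tp_rank=1, tp_size=4, is_mla=False))  # 3f2c_1_4
    print(key + tp_suffix(tp_rank=1, tp_size=4, is_mla=True))   # 3f2c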
@@ -39,6 +39,8 @@ class HiRadixCache(RadixCache):
         hicache_mem_layout: str,
         hicache_storage_backend: Optional[str] = None,
         hicache_storage_prefetch_policy: Optional[str] = "best_effort",
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
     ):

         if hicache_io_backend == "direct":
@@ -87,6 +89,8 @@ class HiRadixCache(RadixCache):
             io_backend=hicache_io_backend,
             storage_backend=hicache_storage_backend,
             prefetch_threshold=self.prefetch_threshold,
+            model_name=model_name,
+            storage_backend_extra_config=storage_backend_extra_config,
         )

         # record the nodes with ongoing write through
@@ -11,12 +11,7 @@ from typing import Any, List, Optional, Tuple

 import torch

-from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    is_dp_attention_enabled,
-)
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
 from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient

 logger = logging.getLogger(__name__)
@@ -172,19 +167,16 @@ class HiCacheHF3FS(HiCacheStorage):

     @staticmethod
     def from_env_config(
-        bytes_per_page: int, dtype: torch.dtype, rank: int = None
+        bytes_per_page: int,
+        dtype: torch.dtype,
+        storage_config: HiCacheStorageConfig = None,
     ) -> "HiCacheHF3FS":
         from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
             Hf3fsGlobalMetadataClient,
             Hf3fsLocalMetadataClient,
         )

-        if rank is None:
-            rank = (
-                get_attention_tp_rank()
-                if is_dp_attention_enabled()
-                else get_tensor_model_parallel_rank()
-            )
+        rank = storage_config.tp_rank if storage_config is not None else 0

         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
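from_env_config no longer probes the distributed state itself; the rank now comes from the config, with 0 as the fallback for callers such as the standalone test at the top of this diff. A sketch of that resolution rule; resolve_rank is an illustrative helper, not part of the commit:

    from typing import Optional

    from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig


    def resolve_rank(storage_config: Optional[HiCacheStorageConfig]) -> int:
        # Same fallback as from_env_config: no config (e.g. the test harness)
        # means rank 0; otherwise use the rank recorded by the controller.
        return storage_config.tp_rank if storage_config is not None else 0


    assert resolve_rank(None) == 0
    assert resolve_rank(HiCacheStorageConfig(3, 8, False, None)) == 3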
@@ -10,7 +10,7 @@ import numpy as np
 import torch

 from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig

 DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB
 DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024  # 16 MB
@@ -84,15 +84,7 @@ class MooncakeStoreConfig:


 class MooncakeStore(HiCacheStorage):
-    def __init__(self, is_mla_backend: bool = False):
-        """
-        Initialize MooncakeStore.
-
-        Args:
-            is_mla_backend: If the backend is MLA
-        """
-        self.is_mla_backend = is_mla_backend
-
+    def __init__(self, storage_config: HiCacheStorageConfig = None):
         try:
             from mooncake.store import MooncakeDistributedStore
         except ImportError as e:
@@ -123,6 +115,13 @@ class MooncakeStore(HiCacheStorage):
             self.warmup()
             logger.info("Mooncake store warmup successfully.")

+            if storage_config is not None:
+                self.is_mla_backend = storage_config.is_mla_model
+                self.local_rank = storage_config.tp_rank
+            else:
+                self.is_mla_backend = False
+                self.local_rank = 0
+
         except ValueError as e:
             logger.error("Configuration loading failed: %s", e)
             raise
@@ -130,8 +129,6 @@ class MooncakeStore(HiCacheStorage):
             logger.error("An error occurred while loading the configuration: %s", exc)
             raise

-        self.local_rank = get_tensor_model_parallel_rank()
-
     def warmup(self):
         warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex
         warmup_value = bytes(4 * 1024)  # 4 KB
@@ -216,6 +216,7 @@ class ServerArgs:
     hicache_mem_layout: str = "layer_first"
     hicache_storage_backend: Optional[str] = None
    hicache_storage_prefetch_policy: str = "best_effort"
+    hicache_storage_backend_extra_config: Optional[str] = None

     # Double Sparsity
     enable_double_sparsity: bool = False
@@ -1641,6 +1642,12 @@ class ServerArgs:
             default=ServerArgs.hicache_storage_prefetch_policy,
             help="Control when prefetching from the storage backend should stop.",
         )
+        parser.add_argument(
+            "--hicache-storage-backend-extra-config",
+            type=str,
+            default=ServerArgs.hicache_storage_backend_extra_config,
+            help="A dictionary in JSON string format containing extra configuration for the storage backend.",
+        )

         # Double Sparsity
         parser.add_argument(
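On a real launch the new flag is passed alongside --hicache-storage-backend, and the controller parses it later in _generate_storage_config. A minimal stand-in for the argparse behavior, assuming a placeholder JSON payload whose keys depend on the chosen backend:

    import argparse
    import json

    # Minimal stand-in for the new ServerArgs flag: the value arrives as a JSON
    # string and is only decoded by HiCacheController._generate_storage_config.
    parser = argparse.ArgumentParser()
    parser.add_argument("--hicache-storage-backend-extra-config", type=str, default=None)

    args = parser.parse_args(
        ["--hicache-storage-backend-extra-config", '{"custom_option": 42}']
    )
    print(json.loads(args.hicache_storage_backend_extra_config))  # {'custom_option': 42}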