diff --git a/benchmark/hf3fs/bench_storage.py b/benchmark/hf3fs/bench_storage.py
index c3f514e0e..f0ce171bf 100644
--- a/benchmark/hf3fs/bench_storage.py
+++ b/benchmark/hf3fs/bench_storage.py
@@ -57,9 +57,7 @@ def test():
         )
     except Exception as e:
         raise RuntimeError(f"Failed to dump config to {config_path}: {str(e)}")
-
-    rank = 0
-    hicache_hf3fs = HiCacheHF3FS.from_env_config(bytes_per_page, dtype, rank)
+    hicache_hf3fs = HiCacheHF3FS.from_env_config(bytes_per_page, dtype)

     numel = 2 * tokens_per_page * layer_num * head_num * head_dim
     assert numel * dtype.itemsize == bytes_per_page
diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py
index bcd7940ac..d05433339 100644
--- a/python/sglang/srt/managers/cache_controller.py
+++ b/python/sglang/srt/managers/cache_controller.py
@@ -22,11 +22,21 @@ from typing import TYPE_CHECKING, List, Optional

 import torch

+from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
+
 if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache

-from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool

 logger = logging.getLogger(__name__)
@@ -231,6 +241,8 @@ class HiCacheController:
         io_backend: str = "",
         storage_backend: Optional[str] = None,
         prefetch_threshold: int = 256,
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -248,20 +260,22 @@ class HiCacheController:

         self.get_hash_str = get_hash_str

-        # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
-        is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
+        self.storage_config = self._generate_storage_config(
+            model_name, storage_backend_extra_config
+        )
         # In MLA backend, only one rank needs to backup the KV cache
         self.backup_skip = (
-            is_mla_backend
+            self.storage_config.is_mla_model
             # todo: for load balancing, decide which rank to backup the KV cache by hash value
-            and get_tensor_model_parallel_rank() != 0
+            and self.storage_config.tp_rank != 0
             # todo: support other storage backends
             and self.storage_backend_type in ["file", "mooncake"]
         )
+
         if storage_backend == "file":
             from sglang.srt.mem_cache.hicache_storage import HiCacheFile

-            self.storage_backend = HiCacheFile(is_mla_backend=is_mla_backend)
+            self.storage_backend = HiCacheFile(self.storage_config)
         elif storage_backend == "nixl":
             from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl

@@ -271,7 +285,7 @@ class HiCacheController:
                 MooncakeStore,
             )

-            self.storage_backend = MooncakeStore(is_mla_backend=is_mla_backend)
+            self.storage_backend = MooncakeStore(self.storage_config)
             self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
             assert self.mem_pool_host.layout == "page_first"
         elif storage_backend == "hf3fs":
@@ -289,7 +303,7 @@ class HiCacheController:
             )
             dtype = mem_pool_host.dtype
             self.storage_backend = HiCacheHF3FS.from_env_config(
-                bytes_per_page, dtype
+                bytes_per_page, dtype, self.storage_config
             )
         else:
             raise NotImplementedError(
@@ -370,6 +384,40 @@ class HiCacheController:
         self.prefetch_thread.start()
         self.backup_thread.start()

+    def _generate_storage_config(
+        self,
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
+    ):
+
+        if is_dp_attention_enabled():
+            self.tp_rank = get_attention_tp_rank()
+            self.tp_size = get_attention_tp_size()
+        else:
+            self.tp_rank = get_tensor_model_parallel_rank()
+            self.tp_size = get_tensor_model_parallel_world_size()
+
+        # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
+        is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
+
+        # Parse extra config JSON if provided
+        extra_config = None
+        if storage_backend_extra_config:
+            try:
+                import json
+
+                extra_config = json.loads(storage_backend_extra_config)
+            except Exception as e:
+                logger.error(f"Invalid backend extra config JSON: {e}")
+
+        return HiCacheStorageConfig(
+            tp_rank=self.tp_rank,
+            tp_size=self.tp_size,
+            is_mla_model=is_mla_backend,
+            model_name=model_name,
+            extra_config=extra_config,
+        )
+
     def reset(self):
         self.stop_event.set()
         self.write_thread.join()
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index f897a5dd4..1feb7c0dd 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -627,6 +627,8 @@ class Scheduler(
                 hicache_mem_layout=server_args.hicache_mem_layout,
                 hicache_storage_backend=server_args.hicache_storage_backend,
                 hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
+                model_name=server_args.served_model_name,
+                storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
diff --git a/python/sglang/srt/mem_cache/hicache_storage.py b/python/sglang/srt/mem_cache/hicache_storage.py
index 907d1b4b8..c142a59bd 100644
--- a/python/sglang/srt/mem_cache/hicache_storage.py
+++ b/python/sglang/srt/mem_cache/hicache_storage.py
@@ -2,6 +2,7 @@ import hashlib
 import logging
 import os
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import Any, List, Optional

 import torch
@@ -9,17 +10,6 @@ import torch

 logger = logging.getLogger(__name__)

-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    get_attention_tp_size,
-    is_dp_attention_enabled,
-)
-
-
 def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
     hasher = hashlib.sha256()
@@ -32,6 +22,15 @@ def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
     return hasher.hexdigest()


+@dataclass
+class HiCacheStorageConfig:
+    tp_rank: int
+    tp_size: int
+    is_mla_model: bool
+    model_name: Optional[str]
+    extra_config: Optional[dict] = None
+
+
 class HiCacheStorage(ABC):
     """
     HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
@@ -117,18 +116,17 @@ class HiCacheStorage(ABC):


 class HiCacheFile(HiCacheStorage):
-    def __init__(self, file_path: str = "/tmp/hicache", is_mla_backend: bool = False):
+    def __init__(
+        self, storage_config: HiCacheStorageConfig, file_path: str = "/tmp/hicache"
+    ):
         self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
-        if is_dp_attention_enabled():
-            tp_rank = get_attention_tp_rank()
-            tp_size = get_attention_tp_size()
-        else:
-            tp_rank = get_tensor_model_parallel_rank()
-            tp_size = get_tensor_model_parallel_world_size()
-        self.tp_suffix = (
-            f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla_backend else ""
+        tp_rank, tp_size, is_mla = (
+            storage_config.tp_rank,
+            storage_config.tp_size,
+            storage_config.is_mla_model,
         )
+        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
         if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
             logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py
index 0df7fb537..c0bd0a3f8 100644
--- a/python/sglang/srt/mem_cache/hiradix_cache.py
+++ b/python/sglang/srt/mem_cache/hiradix_cache.py
@@ -39,6 +39,8 @@ class HiRadixCache(RadixCache):
         hicache_mem_layout: str,
         hicache_storage_backend: Optional[str] = None,
         hicache_storage_prefetch_policy: Optional[str] = "best_effort",
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
     ):

         if hicache_io_backend == "direct":
@@ -87,6 +89,8 @@ class HiRadixCache(RadixCache):
             io_backend=hicache_io_backend,
             storage_backend=hicache_storage_backend,
             prefetch_threshold=self.prefetch_threshold,
+            model_name=model_name,
+            storage_backend_extra_config=storage_backend_extra_config,
         )

         # record the nodes with ongoing write through
diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
index f5d5a5344..f2c5ec0fa 100644
--- a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
+++ b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
@@ -11,12 +11,7 @@ from typing import Any, List, Optional, Tuple

 import torch

-from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    is_dp_attention_enabled,
-)
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
 from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient

 logger = logging.getLogger(__name__)
@@ -172,19 +167,16 @@ class HiCacheHF3FS(HiCacheStorage):

     @staticmethod
     def from_env_config(
-        bytes_per_page: int, dtype: torch.dtype, rank: int = None
+        bytes_per_page: int,
+        dtype: torch.dtype,
+        storage_config: HiCacheStorageConfig = None,
     ) -> "HiCacheHF3FS":
         from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
             Hf3fsGlobalMetadataClient,
             Hf3fsLocalMetadataClient,
         )

-        if rank is None:
-            rank = (
-                get_attention_tp_rank()
-                if is_dp_attention_enabled()
-                else get_tensor_model_parallel_rank()
-            )
+        rank = storage_config.tp_rank if storage_config is not None else 0

         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
diff --git a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
index 704f6787e..a82a2a413 100644
--- a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
+++ b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
@@ -10,7 +10,7 @@ import numpy as np
 import torch

 from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig

 DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB
 DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024  # 16 MB
@@ -84,15 +84,7 @@ class MooncakeStoreConfig:


 class MooncakeStore(HiCacheStorage):
-    def __init__(self, is_mla_backend: bool = False):
-        """
-        Initialize MooncakeStore.
-
-        Args:
-            is_mla_backend: If the backend is MLA
-        """
-        self.is_mla_backend = is_mla_backend
-
+    def __init__(self, storage_config: HiCacheStorageConfig = None):
         try:
             from mooncake.store import MooncakeDistributedStore
         except ImportError as e:
@@ -123,6 +115,13 @@ class MooncakeStore(HiCacheStorage):
             self.warmup()
             logger.info("Mooncake store warmup successfully.")

+            if storage_config is not None:
+                self.is_mla_backend = storage_config.is_mla_model
+                self.local_rank = storage_config.tp_rank
+            else:
+                self.is_mla_backend = False
+                self.local_rank = 0
+
         except ValueError as e:
             logger.error("Configuration loading failed: %s", e)
             raise
@@ -130,8 +129,6 @@ class MooncakeStore(HiCacheStorage):
             logger.error("An error occurred while loading the configuration: %s", exc)
             raise

-        self.local_rank = get_tensor_model_parallel_rank()
-
     def warmup(self):
         warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex
         warmup_value = bytes(4 * 1024)  # 4 KB
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index b5c846b94..aa973dec1 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -216,6 +216,7 @@ class ServerArgs:
     hicache_mem_layout: str = "layer_first"
     hicache_storage_backend: Optional[str] = None
     hicache_storage_prefetch_policy: str = "best_effort"
+    hicache_storage_backend_extra_config: Optional[str] = None

     # Double Sparsity
     enable_double_sparsity: bool = False
@@ -1641,6 +1642,12 @@
             default=ServerArgs.hicache_storage_prefetch_policy,
             help="Control when prefetching from the storage backend should stop.",
         )
+        parser.add_argument(
+            "--hicache-storage-backend-extra-config",
+            type=str,
+            default=ServerArgs.hicache_storage_backend_extra_config,
+            help="A dictionary in JSON string format containing extra configuration for the storage backend.",
+        )

         # Double Sparsity
         parser.add_argument(
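
A minimal usage sketch of the config plumbing introduced above, assuming only the signatures added in this diff; the JSON key and the served model name below are hypothetical placeholders:

# Sketch only: mirrors what HiCacheController._generate_storage_config builds from
# --hicache-storage-backend-extra-config and hands to a storage backend.
import json

from sglang.srt.mem_cache.hicache_storage import HiCacheFile, HiCacheStorageConfig

# "example_key" is a made-up option; each backend defines its own recognized keys.
extra_config = json.loads('{"example_key": "example_value"}')

storage_config = HiCacheStorageConfig(
    tp_rank=0,
    tp_size=1,
    is_mla_model=False,
    model_name="my-model",  # hypothetical served model name
    extra_config=extra_config,
)

# With tp_size == 1 (or an MLA model), HiCacheFile omits the "_{rank}_{size}" key suffix.
backend = HiCacheFile(storage_config)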