refactor(hicache): Introduce generic HiCacheStorageConfig for improved configuration management (#9555)
Co-authored-by: Teng Ma <805522925@qq.com>
This commit is contained in:
@@ -22,11 +22,21 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
||||
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||
from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
is_dp_attention_enabled,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -231,6 +241,8 @@ class HiCacheController:
|
||||
io_backend: str = "",
|
||||
storage_backend: Optional[str] = None,
|
||||
prefetch_threshold: int = 256,
|
||||
model_name: Optional[str] = None,
|
||||
storage_backend_extra_config: Optional[str] = None,
|
||||
):
|
||||
self.mem_pool_device_allocator = token_to_kv_pool_allocator
|
||||
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
|
||||
@@ -248,20 +260,22 @@ class HiCacheController:
|
||||
|
||||
self.get_hash_str = get_hash_str
|
||||
|
||||
# Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
|
||||
is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
|
||||
self.storage_config = self._generate_storage_config(
|
||||
model_name, storage_backend_extra_config
|
||||
)
|
||||
# In MLA backend, only one rank needs to backup the KV cache
|
||||
self.backup_skip = (
|
||||
is_mla_backend
|
||||
self.storage_config.is_mla_model
|
||||
# todo: for load balancing, decide which rank to backup the KV cache by hash value
|
||||
and get_tensor_model_parallel_rank() != 0
|
||||
and self.storage_config.tp_rank != 0
|
||||
# todo: support other storage backends
|
||||
and self.storage_backend_type in ["file", "mooncake"]
|
||||
)
|
||||
|
||||
if storage_backend == "file":
|
||||
from sglang.srt.mem_cache.hicache_storage import HiCacheFile
|
||||
|
||||
self.storage_backend = HiCacheFile(is_mla_backend=is_mla_backend)
|
||||
self.storage_backend = HiCacheFile(self.storage_config)
|
||||
elif storage_backend == "nixl":
|
||||
from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
|
||||
|
||||
@@ -271,7 +285,7 @@ class HiCacheController:
|
||||
MooncakeStore,
|
||||
)
|
||||
|
||||
self.storage_backend = MooncakeStore(is_mla_backend=is_mla_backend)
|
||||
self.storage_backend = MooncakeStore(self.storage_config)
|
||||
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
|
||||
assert self.mem_pool_host.layout == "page_first"
|
||||
elif storage_backend == "hf3fs":
|
||||
@@ -289,7 +303,7 @@ class HiCacheController:
|
||||
)
|
||||
dtype = mem_pool_host.dtype
|
||||
self.storage_backend = HiCacheHF3FS.from_env_config(
|
||||
bytes_per_page, dtype
|
||||
bytes_per_page, dtype, self.storage_config
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
@@ -370,6 +384,40 @@ class HiCacheController:
|
||||
self.prefetch_thread.start()
|
||||
self.backup_thread.start()
|
||||
|
||||
def _generate_storage_config(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
storage_backend_extra_config: Optional[str] = None,
|
||||
):
|
||||
|
||||
if is_dp_attention_enabled():
|
||||
self.tp_rank = get_attention_tp_rank()
|
||||
self.tp_size = get_attention_tp_size()
|
||||
else:
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
# Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
|
||||
is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
|
||||
|
||||
# Parse extra config JSON if provided
|
||||
extra_config = None
|
||||
if storage_backend_extra_config:
|
||||
try:
|
||||
import json
|
||||
|
||||
extra_config = json.loads(storage_backend_extra_config)
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid backend extra config JSON: {e}")
|
||||
|
||||
return HiCacheStorageConfig(
|
||||
tp_rank=self.tp_rank,
|
||||
tp_size=self.tp_size,
|
||||
is_mla_model=is_mla_backend,
|
||||
model_name=model_name,
|
||||
extra_config=extra_config,
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
self.stop_event.set()
|
||||
self.write_thread.join()
|
||||
|
||||
Reference in New Issue
Block a user