diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 6ba9571b5..8fa8ab00c 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -26,6 +26,8 @@ if TYPE_CHECKING: from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool_host import HostKVCache +from sglang.srt.distributed import get_tensor_model_parallel_rank +from sglang.srt.mem_cache.memory_pool_host import MLATokenToKVPoolHost logger = logging.getLogger(__name__) @@ -238,13 +240,14 @@ class HiCacheController: self.io_backend = io_backend self.enable_storage = False + self.is_mla = isinstance(self.mem_pool_host, MLATokenToKVPoolHost) # todo: move backend initialization to storage backend module if storage_backend is not None: self.storage_backend_type = storage_backend from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str if storage_backend == "file": - self.storage_backend = HiCacheFile() + self.storage_backend = HiCacheFile(is_mla=self.is_mla) self.get_hash_str = get_hash_str elif storage_backend == "nixl": from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl @@ -257,12 +260,11 @@ class HiCacheController: get_hash_str_mooncake, ) - self.storage_backend = MooncakeStore() + self.storage_backend = MooncakeStore(is_mla=self.is_mla) self.get_hash_str = get_hash_str_mooncake self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer) assert self.mem_pool_host.layout == "page_first" elif storage_backend == "hf3fs": - from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import ( HiCacheHF3FS, ) @@ -399,6 +401,15 @@ class HiCacheController: self.prefetch_thread.start() self.backup_thread.start() + @property + def backup_skip(self): + return ( + self.is_mla + and get_tensor_model_parallel_rank() != 0 + # todo: only support file and mooncake + and self.storage_backend_type in ["file", "mooncake"] + ) + def write( self, device_indices: torch.Tensor, @@ -809,17 +820,20 @@ class HiCacheController: if operation is None: continue - if self.is_mooncake_backend(): - self.mooncake_page_backup(operation) - elif self.storage_backend_type == "hf3fs": - if self.mem_pool_host.layout == "page_first": - self.zerocopy_page_backup(operation, batch_size=128) - elif self.mem_pool_host.layout == "layer_first": - self.generic_page_backup(operation, batch_size=128) + if not self.backup_skip: + if self.is_mooncake_backend(): + self.mooncake_page_backup(operation) + elif self.storage_backend_type == "hf3fs": + if self.mem_pool_host.layout == "page_first": + self.zerocopy_page_backup(operation, batch_size=128) + elif self.mem_pool_host.layout == "layer_first": + self.generic_page_backup(operation, batch_size=128) + else: + self.generic_page_backup(operation) + min_completed_tokens = operation.completed_tokens else: - self.generic_page_backup(operation) + min_completed_tokens = len(operation.token_ids) - min_completed_tokens = operation.completed_tokens if self.tp_world_size > 1: completed_tokens_tensor = torch.tensor( min_completed_tokens, dtype=torch.int diff --git a/python/sglang/srt/mem_cache/hicache_storage.py b/python/sglang/srt/mem_cache/hicache_storage.py index 90a468cc3..ed5908bd9 100644 --- a/python/sglang/srt/mem_cache/hicache_storage.py +++ b/python/sglang/srt/mem_cache/hicache_storage.py @@ -101,11 +101,11 @@ class HiCacheStorage(ABC): class HiCacheFile(HiCacheStorage): - def __init__(self, file_path: str = "/tmp/hicache"): + def __init__(self, file_path: str = "/tmp/hicache", is_mla: bool = False): self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path) tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() - self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else "" + self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else "" if not os.path.exists(self.file_path) and tp_rank == 0: os.makedirs(self.file_path) logger.info(f"Created HiCacheFile storage directory at {self.file_path}") diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index 4abc6dc0a..a2cc5bd37 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -7,6 +7,7 @@ from functools import wraps import psutil import torch +from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool from sglang.srt.utils import is_npu @@ -487,8 +488,8 @@ class MHATokenToKVPoolHost(HostKVCache): ptr_list.append(k_ptr) ptr_list.append(v_ptr) key_ = keys[index // self.page_size] - key_list.append(f"{key_}_k") - key_list.append(f"{key_}_v") + key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_k") + key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_v") element_size = ( self.layer_num * self.dtype.itemsize diff --git a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py index 51b47335e..1cddd0092 100644 --- a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +++ b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py @@ -19,14 +19,13 @@ logger = logging.getLogger(__name__) def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None): - local_rank = get_tensor_model_parallel_rank() prefix_str = "" if prior_hash: prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest() current_token_ids_bytes = np.array(token_ids).tobytes() current_hash_object = hashlib.sha256(current_token_ids_bytes) current_hash_hex = current_hash_object.hexdigest() - return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}" + return f"{prefix_str}_{int(current_hash_hex[:16], 16)}" @dataclass @@ -97,7 +96,7 @@ class MooncakeStoreConfig: class MooncakeStore(HiCacheStorage): - def __init__(self): + def __init__(self, is_mla: bool = False): try: from mooncake.store import MooncakeDistributedStore except ImportError as e: @@ -127,6 +126,7 @@ class MooncakeStore(HiCacheStorage): logger.info("Connect to Mooncake store successfully.") self.warmup() logger.info("Mooncake store warmup successfully.") + self.is_mla = is_mla except ValueError as e: logger.error("Configuration loading failed: %s", e) @@ -223,11 +223,15 @@ class MooncakeStore(HiCacheStorage): def exists(self, keys) -> bool | dict: _keys = [] + local_rank = get_tensor_model_parallel_rank() for key in keys: if key is None: return None - _keys.append(f"{key}_k") + if self.is_mla: + _keys.append(f"{key}_k") + else: + _keys.append(f"{key}_{local_rank}_k") result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))} return result