# Patch summary (free text placed ahead of the first "diff --git" header).
#
# * backend_factory.py: the "nixl" branch now passes storage_config to the
#   backend class, as the "file" and "mooncake" branches already do.
# * hicache_nixl.py: HiCacheNixl.__init__ takes storage_config as its first
#   parameter and derives self.config_suffix from it:
#       MLA model     -> "_{model_name}"
#       non-MLA model -> "_{model_name}_{tp_rank}_{tp_size}"
#   where "/" in model_name is replaced by "-" (empty string when the name
#   is falsy). get(), set() and exists() append this suffix to every key
#   before building file paths or issuing the transfer/query.
# * test_hicache_nixl_storage.py: setUp builds a HiCacheStorageConfig and
#   passes it to HiCacheNixl.
#
# NOTE(review): HiCacheNixl now requires storage_config positionally; only the
#   factory and the test call sites are updated in this patch -- confirm there
#   are no other callers.
# NOTE(review): line breaks below were restored from a whitespace-collapsed
#   copy and checked against the "@@" hunk line counts; leading indentation of
#   context lines is inferred from standard 4-space Python layout -- verify
#   against the target tree before applying.
diff --git a/python/sglang/srt/mem_cache/storage/backend_factory.py b/python/sglang/srt/mem_cache/storage/backend_factory.py
index 3af598abb..dd5da6a5c 100644
--- a/python/sglang/srt/mem_cache/storage/backend_factory.py
+++ b/python/sglang/srt/mem_cache/storage/backend_factory.py
@@ -161,7 +161,7 @@ class StorageBackendFactory:
         if backend_name == "file":
             return backend_class(storage_config)
         elif backend_name == "nixl":
-            return backend_class()
+            return backend_class(storage_config)
         elif backend_name == "mooncake":
             backend = backend_class(storage_config)
             return backend
diff --git a/python/sglang/srt/mem_cache/storage/nixl/hicache_nixl.py b/python/sglang/srt/mem_cache/storage/nixl/hicache_nixl.py
index 327c90502..55b3dd976 100644
--- a/python/sglang/srt/mem_cache/storage/nixl/hicache_nixl.py
+++ b/python/sglang/srt/mem_cache/storage/nixl/hicache_nixl.py
@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
 
 from .nixl_utils import NixlBackendSelection, NixlFileManager, NixlRegistration
 
@@ -26,7 +26,12 @@ logger = logging.getLogger(__name__)
 class HiCacheNixl(HiCacheStorage):
     """HiCacheNixl provides high-performance storage using NIXL plugins."""
 
-    def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"):
+    def __init__(
+        self,
+        storage_config: HiCacheStorageConfig,
+        file_path: str = "/tmp/hicache_storage",
+        plugin: str = "auto",
+    ):
         """Initialize NIXL storage connector."""
         # Might be better to be unified across HiCache backends and moved to HiCacheController
         file_path = os.getenv("SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR", file_path)
@@ -36,6 +41,19 @@ class HiCacheNixl(HiCacheStorage):
             else None
         )
 
+        # Initialize suffix based on storage config
+        tp_rank, tp_size, model_name, is_mla_model = (
+            storage_config.tp_rank,
+            storage_config.tp_size,
+            storage_config.model_name,
+            storage_config.is_mla_model,
+        )
+        model_name = "-".join(model_name.split("/")) if model_name else ""
+        if is_mla_model:
+            self.config_suffix = f"_{model_name}"
+        else:
+            self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}"
+
         agent_config = nixl_agent_config(backends=[])
         self.agent_name = f"hicache_nixl_{str(uuid.uuid4())}"
         self.agent = nixl_agent(self.agent_name, agent_config)
@@ -46,6 +64,9 @@ class HiCacheNixl(HiCacheStorage):
 
         self.registration = NixlRegistration(self.agent)
 
+    def _get_suffixed_key(self, key: str) -> str:
+        return key + self.config_suffix
+
     def register_buffers(
         self, buffers: Union[torch.Tensor, List[torch.Tensor], List[tuple]]
     ) -> Optional[Any]:
@@ -194,11 +215,14 @@ class HiCacheNixl(HiCacheStorage):
         else:
             dest = target_locations
 
+        # Add suffix to keys
+        suffixed_keys = [self._get_suffixed_key(key) for key in keys]
+
         if self.backend_selector.mem_type == "FILE":
-            file_paths = [self.file_manager.get_file_path(key) for key in keys]
+            file_paths = [self.file_manager.get_file_path(key) for key in suffixed_keys]
             success = self._execute_transfer(dest, file_paths, "READ")
         else:
-            success = self._execute_transfer(dest, keys, "READ")
+            success = self._execute_transfer(dest, suffixed_keys, "READ")
         return target_locations if success and not target_sizes else [None] * len(keys)
 
     def set(
@@ -227,9 +251,12 @@ class HiCacheNixl(HiCacheStorage):
         if not values:
             values = list(zip(target_locations, target_sizes))
 
+        # Add suffix to keys
+        suffixed_keys = [self._get_suffixed_key(key) for key in keys]
+
         if self.backend_selector.mem_type == "FILE":
             file_paths = []
-            for key in keys:
+            for key in suffixed_keys:
                 file_path = self.file_manager.get_file_path(key)
                 # New file per set, to be updated when partial writes is added to HiCache
                 if not self.file_manager.create_file(file_path):
@@ -238,11 +265,14 @@ class HiCacheNixl(HiCacheStorage):
                 file_paths.append(file_path)
             return self._execute_transfer(values, file_paths, "WRITE")
         else:  # mem_type == "OBJ"
-            return self._execute_transfer(values, keys, "WRITE")
+            return self._execute_transfer(values, suffixed_keys, "WRITE")
 
     def exists(self, key: str) -> bool:
+        # Add suffix to key
+        suffixed_key = self._get_suffixed_key(key)
+
         tuples = self.registration.create_query_tuples(
-            key,
+            suffixed_key,
             self.backend_selector.mem_type,
             self.file_manager if self.backend_selector.mem_type == "FILE" else None,
         )
diff --git a/python/sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py b/python/sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py
index 951e5a4ea..3784ab91a 100755
--- a/python/sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py
+++ b/python/sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py
@@ -7,6 +7,7 @@ from unittest.mock import MagicMock
 
 import torch
 
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
 from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
 from sglang.srt.mem_cache.storage.nixl.nixl_utils import (
     NixlFileManager,
@@ -31,8 +32,22 @@ class TestNixlUnified(unittest.TestCase):
         # Create instances
         self.file_manager = NixlFileManager(self.test_dir)
         self.registration = NixlRegistration(self.mock_agent)
+
+        # Create storage config for testing
+        self.storage_config = HiCacheStorageConfig(
+            tp_rank=0,
+            tp_size=2,
+            is_mla_model=False,
+            is_page_first_layout=False,
+            model_name="test_model",
+        )
+
         try:
-            self.hicache = HiCacheNixl(file_path=self.test_dir, plugin="POSIX")
+            self.hicache = HiCacheNixl(
+                storage_config=self.storage_config,
+                file_path=self.test_dir,
+                plugin="POSIX",
+            )
         except ImportError:
             self.skipTest("NIXL not available, skipping NIXL storage tests")
 