fix file and object naming scheme in HiCacheNixl to avoid data corruption (#10969)
Signed-off-by: Zirui Liu <ziliu@ddn.com>
This commit is contained in:
@@ -161,7 +161,7 @@ class StorageBackendFactory:
|
||||
if backend_name == "file":
|
||||
return backend_class(storage_config)
|
||||
elif backend_name == "nixl":
|
||||
return backend_class()
|
||||
return backend_class(storage_config)
|
||||
elif backend_name == "mooncake":
|
||||
backend = backend_class(storage_config)
|
||||
return backend
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
|
||||
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
|
||||
|
||||
from .nixl_utils import NixlBackendSelection, NixlFileManager, NixlRegistration
|
||||
|
||||
@@ -26,7 +26,12 @@ logger = logging.getLogger(__name__)
|
||||
class HiCacheNixl(HiCacheStorage):
|
||||
"""HiCacheNixl provides high-performance storage using NIXL plugins."""
|
||||
|
||||
def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"):
|
||||
def __init__(
|
||||
self,
|
||||
storage_config: HiCacheStorageConfig,
|
||||
file_path: str = "/tmp/hicache_storage",
|
||||
plugin: str = "auto",
|
||||
):
|
||||
"""Initialize NIXL storage connector."""
|
||||
# This might be better unified across HiCache backends and moved to HiCacheController
|
||||
file_path = os.getenv("SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR", file_path)
|
||||
@@ -36,6 +41,19 @@ class HiCacheNixl(HiCacheStorage):
|
||||
else None
|
||||
)
|
||||
|
||||
# Initialize suffix based on storage config
|
||||
tp_rank, tp_size, model_name, is_mla_model = (
|
||||
storage_config.tp_rank,
|
||||
storage_config.tp_size,
|
||||
storage_config.model_name,
|
||||
storage_config.is_mla_model,
|
||||
)
|
||||
model_name = "-".join(model_name.split("/")) if model_name else ""
|
||||
if is_mla_model:
|
||||
self.config_suffix = f"_{model_name}"
|
||||
else:
|
||||
self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}"
|
||||
|
||||
agent_config = nixl_agent_config(backends=[])
|
||||
self.agent_name = f"hicache_nixl_{str(uuid.uuid4())}"
|
||||
self.agent = nixl_agent(self.agent_name, agent_config)
|
||||
@@ -46,6 +64,9 @@ class HiCacheNixl(HiCacheStorage):
|
||||
|
||||
self.registration = NixlRegistration(self.agent)
|
||||
|
||||
def _get_suffixed_key(self, key: str) -> str:
|
||||
return key + self.config_suffix
|
||||
|
||||
def register_buffers(
|
||||
self, buffers: Union[torch.Tensor, List[torch.Tensor], List[tuple]]
|
||||
) -> Optional[Any]:
|
||||
@@ -194,11 +215,14 @@ class HiCacheNixl(HiCacheStorage):
|
||||
else:
|
||||
dest = target_locations
|
||||
|
||||
# Add suffix to keys
|
||||
suffixed_keys = [self._get_suffixed_key(key) for key in keys]
|
||||
|
||||
if self.backend_selector.mem_type == "FILE":
|
||||
file_paths = [self.file_manager.get_file_path(key) for key in keys]
|
||||
file_paths = [self.file_manager.get_file_path(key) for key in suffixed_keys]
|
||||
success = self._execute_transfer(dest, file_paths, "READ")
|
||||
else:
|
||||
success = self._execute_transfer(dest, keys, "READ")
|
||||
success = self._execute_transfer(dest, suffixed_keys, "READ")
|
||||
return target_locations if success and not target_sizes else [None] * len(keys)
|
||||
|
||||
def set(
|
||||
@@ -227,9 +251,12 @@ class HiCacheNixl(HiCacheStorage):
|
||||
if not values:
|
||||
values = list(zip(target_locations, target_sizes))
|
||||
|
||||
# Add suffix to keys
|
||||
suffixed_keys = [self._get_suffixed_key(key) for key in keys]
|
||||
|
||||
if self.backend_selector.mem_type == "FILE":
|
||||
file_paths = []
|
||||
for key in keys:
|
||||
for key in suffixed_keys:
|
||||
file_path = self.file_manager.get_file_path(key)
|
||||
# New file per set; to be updated when partial writes are added to HiCache
|
||||
if not self.file_manager.create_file(file_path):
|
||||
@@ -238,11 +265,14 @@ class HiCacheNixl(HiCacheStorage):
|
||||
file_paths.append(file_path)
|
||||
return self._execute_transfer(values, file_paths, "WRITE")
|
||||
else: # mem_type == "OBJ"
|
||||
return self._execute_transfer(values, keys, "WRITE")
|
||||
return self._execute_transfer(values, suffixed_keys, "WRITE")
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
# Add suffix to key
|
||||
suffixed_key = self._get_suffixed_key(key)
|
||||
|
||||
tuples = self.registration.create_query_tuples(
|
||||
key,
|
||||
suffixed_key,
|
||||
self.backend_selector.mem_type,
|
||||
self.file_manager if self.backend_selector.mem_type == "FILE" else None,
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
|
||||
from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
|
||||
from sglang.srt.mem_cache.storage.nixl.nixl_utils import (
|
||||
NixlFileManager,
|
||||
@@ -31,8 +32,22 @@ class TestNixlUnified(unittest.TestCase):
|
||||
# Create instances
|
||||
self.file_manager = NixlFileManager(self.test_dir)
|
||||
self.registration = NixlRegistration(self.mock_agent)
|
||||
|
||||
# Create storage config for testing
|
||||
self.storage_config = HiCacheStorageConfig(
|
||||
tp_rank=0,
|
||||
tp_size=2,
|
||||
is_mla_model=False,
|
||||
is_page_first_layout=False,
|
||||
model_name="test_model",
|
||||
)
|
||||
|
||||
try:
|
||||
self.hicache = HiCacheNixl(file_path=self.test_dir, plugin="POSIX")
|
||||
self.hicache = HiCacheNixl(
|
||||
storage_config=self.storage_config,
|
||||
file_path=self.test_dir,
|
||||
plugin="POSIX",
|
||||
)
|
||||
except ImportError:
|
||||
self.skipTest("NIXL not available, skipping NIXL storage tests")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user