fix file and object naming scheme in HiCacheNixl to avoid data corruption (#10969)
Signed-off-by: Zirui Liu <ziliu@ddn.com>
This commit is contained in:
@@ -161,7 +161,7 @@ class StorageBackendFactory:
|
|||||||
if backend_name == "file":
|
if backend_name == "file":
|
||||||
return backend_class(storage_config)
|
return backend_class(storage_config)
|
||||||
elif backend_name == "nixl":
|
elif backend_name == "nixl":
|
||||||
return backend_class()
|
return backend_class(storage_config)
|
||||||
elif backend_name == "mooncake":
|
elif backend_name == "mooncake":
|
||||||
backend = backend_class(storage_config)
|
backend = backend_class(storage_config)
|
||||||
return backend
|
return backend
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
|
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
|
||||||
|
|
||||||
from .nixl_utils import NixlBackendSelection, NixlFileManager, NixlRegistration
|
from .nixl_utils import NixlBackendSelection, NixlFileManager, NixlRegistration
|
||||||
|
|
||||||
@@ -26,7 +26,12 @@ logger = logging.getLogger(__name__)
|
|||||||
class HiCacheNixl(HiCacheStorage):
|
class HiCacheNixl(HiCacheStorage):
|
||||||
"""HiCacheNixl provides high-performance storage using NIXL plugins."""
|
"""HiCacheNixl provides high-performance storage using NIXL plugins."""
|
||||||
|
|
||||||
def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
storage_config: HiCacheStorageConfig,
|
||||||
|
file_path: str = "/tmp/hicache_storage",
|
||||||
|
plugin: str = "auto",
|
||||||
|
):
|
||||||
"""Initialize NIXL storage connector."""
|
"""Initialize NIXL storage connector."""
|
||||||
# Might be better to be unified across HiCache backends and moved to HiCacheController
|
# Might be better to be unified across HiCache backends and moved to HiCacheController
|
||||||
file_path = os.getenv("SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR", file_path)
|
file_path = os.getenv("SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR", file_path)
|
||||||
@@ -36,6 +41,19 @@ class HiCacheNixl(HiCacheStorage):
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Initialize suffix based on storage config
|
||||||
|
tp_rank, tp_size, model_name, is_mla_model = (
|
||||||
|
storage_config.tp_rank,
|
||||||
|
storage_config.tp_size,
|
||||||
|
storage_config.model_name,
|
||||||
|
storage_config.is_mla_model,
|
||||||
|
)
|
||||||
|
model_name = "-".join(model_name.split("/")) if model_name else ""
|
||||||
|
if is_mla_model:
|
||||||
|
self.config_suffix = f"_{model_name}"
|
||||||
|
else:
|
||||||
|
self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}"
|
||||||
|
|
||||||
agent_config = nixl_agent_config(backends=[])
|
agent_config = nixl_agent_config(backends=[])
|
||||||
self.agent_name = f"hicache_nixl_{str(uuid.uuid4())}"
|
self.agent_name = f"hicache_nixl_{str(uuid.uuid4())}"
|
||||||
self.agent = nixl_agent(self.agent_name, agent_config)
|
self.agent = nixl_agent(self.agent_name, agent_config)
|
||||||
@@ -46,6 +64,9 @@ class HiCacheNixl(HiCacheStorage):
|
|||||||
|
|
||||||
self.registration = NixlRegistration(self.agent)
|
self.registration = NixlRegistration(self.agent)
|
||||||
|
|
||||||
|
def _get_suffixed_key(self, key: str) -> str:
|
||||||
|
return key + self.config_suffix
|
||||||
|
|
||||||
def register_buffers(
|
def register_buffers(
|
||||||
self, buffers: Union[torch.Tensor, List[torch.Tensor], List[tuple]]
|
self, buffers: Union[torch.Tensor, List[torch.Tensor], List[tuple]]
|
||||||
) -> Optional[Any]:
|
) -> Optional[Any]:
|
||||||
@@ -194,11 +215,14 @@ class HiCacheNixl(HiCacheStorage):
|
|||||||
else:
|
else:
|
||||||
dest = target_locations
|
dest = target_locations
|
||||||
|
|
||||||
|
# Add suffix to keys
|
||||||
|
suffixed_keys = [self._get_suffixed_key(key) for key in keys]
|
||||||
|
|
||||||
if self.backend_selector.mem_type == "FILE":
|
if self.backend_selector.mem_type == "FILE":
|
||||||
file_paths = [self.file_manager.get_file_path(key) for key in keys]
|
file_paths = [self.file_manager.get_file_path(key) for key in suffixed_keys]
|
||||||
success = self._execute_transfer(dest, file_paths, "READ")
|
success = self._execute_transfer(dest, file_paths, "READ")
|
||||||
else:
|
else:
|
||||||
success = self._execute_transfer(dest, keys, "READ")
|
success = self._execute_transfer(dest, suffixed_keys, "READ")
|
||||||
return target_locations if success and not target_sizes else [None] * len(keys)
|
return target_locations if success and not target_sizes else [None] * len(keys)
|
||||||
|
|
||||||
def set(
|
def set(
|
||||||
@@ -227,9 +251,12 @@ class HiCacheNixl(HiCacheStorage):
|
|||||||
if not values:
|
if not values:
|
||||||
values = list(zip(target_locations, target_sizes))
|
values = list(zip(target_locations, target_sizes))
|
||||||
|
|
||||||
|
# Add suffix to keys
|
||||||
|
suffixed_keys = [self._get_suffixed_key(key) for key in keys]
|
||||||
|
|
||||||
if self.backend_selector.mem_type == "FILE":
|
if self.backend_selector.mem_type == "FILE":
|
||||||
file_paths = []
|
file_paths = []
|
||||||
for key in keys:
|
for key in suffixed_keys:
|
||||||
file_path = self.file_manager.get_file_path(key)
|
file_path = self.file_manager.get_file_path(key)
|
||||||
# New file per set, to be updated when partial writes is added to HiCache
|
# New file per set, to be updated when partial writes is added to HiCache
|
||||||
if not self.file_manager.create_file(file_path):
|
if not self.file_manager.create_file(file_path):
|
||||||
@@ -238,11 +265,14 @@ class HiCacheNixl(HiCacheStorage):
|
|||||||
file_paths.append(file_path)
|
file_paths.append(file_path)
|
||||||
return self._execute_transfer(values, file_paths, "WRITE")
|
return self._execute_transfer(values, file_paths, "WRITE")
|
||||||
else: # mem_type == "OBJ"
|
else: # mem_type == "OBJ"
|
||||||
return self._execute_transfer(values, keys, "WRITE")
|
return self._execute_transfer(values, suffixed_keys, "WRITE")
|
||||||
|
|
||||||
def exists(self, key: str) -> bool:
|
def exists(self, key: str) -> bool:
|
||||||
|
# Add suffix to key
|
||||||
|
suffixed_key = self._get_suffixed_key(key)
|
||||||
|
|
||||||
tuples = self.registration.create_query_tuples(
|
tuples = self.registration.create_query_tuples(
|
||||||
key,
|
suffixed_key,
|
||||||
self.backend_selector.mem_type,
|
self.backend_selector.mem_type,
|
||||||
self.file_manager if self.backend_selector.mem_type == "FILE" else None,
|
self.file_manager if self.backend_selector.mem_type == "FILE" else None,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from unittest.mock import MagicMock
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
|
||||||
from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
|
from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
|
||||||
from sglang.srt.mem_cache.storage.nixl.nixl_utils import (
|
from sglang.srt.mem_cache.storage.nixl.nixl_utils import (
|
||||||
NixlFileManager,
|
NixlFileManager,
|
||||||
@@ -31,8 +32,22 @@ class TestNixlUnified(unittest.TestCase):
|
|||||||
# Create instances
|
# Create instances
|
||||||
self.file_manager = NixlFileManager(self.test_dir)
|
self.file_manager = NixlFileManager(self.test_dir)
|
||||||
self.registration = NixlRegistration(self.mock_agent)
|
self.registration = NixlRegistration(self.mock_agent)
|
||||||
|
|
||||||
|
# Create storage config for testing
|
||||||
|
self.storage_config = HiCacheStorageConfig(
|
||||||
|
tp_rank=0,
|
||||||
|
tp_size=2,
|
||||||
|
is_mla_model=False,
|
||||||
|
is_page_first_layout=False,
|
||||||
|
model_name="test_model",
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.hicache = HiCacheNixl(file_path=self.test_dir, plugin="POSIX")
|
self.hicache = HiCacheNixl(
|
||||||
|
storage_config=self.storage_config,
|
||||||
|
file_path=self.test_dir,
|
||||||
|
plugin="POSIX",
|
||||||
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
self.skipTest("NIXL not available, skipping NIXL storage tests")
|
self.skipTest("NIXL not available, skipping NIXL storage tests")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user