SGLang HiCache NIXL Connector (#8488)
Signed-off-by: Vishwanath Venkatesan <vvenkatesan@nvidia.com> Co-authored-by: Moein Khazraee <moein@nvidia.com> Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
committed by
GitHub
parent
743638bc03
commit
2cd2e27f80
@@ -265,6 +265,11 @@ class HiCacheController:
|
||||
if storage_backend == "file":
|
||||
self.storage_backend = HiCacheFile()
|
||||
self.get_hash_str = get_hash_str
|
||||
elif storage_backend == "nixl":
|
||||
from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
|
||||
|
||||
self.storage_backend = HiCacheNixl()
|
||||
self.get_hash_str = get_hash_str
|
||||
elif storage_backend == "mooncake":
|
||||
self.storage_backend = MooncakeStore()
|
||||
self.get_hash_str = get_hash_str_mooncake
|
||||
@@ -545,7 +550,11 @@ class HiCacheController:
|
||||
def generic_page_transfer(self, operation, batch_size=8):
|
||||
for i in range(0, len(operation.hash_value), batch_size):
|
||||
page_hashes = operation.hash_value[i : i + batch_size]
|
||||
page_data = self.storage_backend.batch_get(page_hashes)
|
||||
# todo: zero copy
|
||||
dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
|
||||
page_hashes
|
||||
)
|
||||
page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
|
||||
if page_data is None:
|
||||
logger.warning(
|
||||
f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
|
||||
@@ -679,7 +688,7 @@ class HiCacheController:
|
||||
for i in range(0, len(operation.hash_value), batch_size):
|
||||
page_hashes = operation.hash_value[i : i + batch_size]
|
||||
page_data = [
|
||||
self.mem_pool_host.get_flat_data_pages(
|
||||
self.mem_pool_host.get_flat_data_page(
|
||||
operation.host_indices[j * self.page_size]
|
||||
)
|
||||
for j in range(i, i + len(page_hashes))
|
||||
|
||||
@@ -123,13 +123,22 @@ class HiCacheFile(HiCacheStorage):
|
||||
key = self._get_suffixed_key(key)
|
||||
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
||||
try:
|
||||
# todo: fixing the target_location logic to enable in-place loading
|
||||
loaded_tensor = torch.load(tensor_path)
|
||||
if isinstance(loaded_tensor, torch.Tensor):
|
||||
return loaded_tensor
|
||||
if target_location is not None:
|
||||
# Load directly into target_location's memory buffer
|
||||
with open(tensor_path, "rb") as f:
|
||||
target_location.set_(
|
||||
torch.frombuffer(f.read(), dtype=target_location.dtype)
|
||||
.reshape(target_location.shape)
|
||||
.storage()
|
||||
)
|
||||
return target_location
|
||||
else:
|
||||
logger.error(f"Loaded data for key {key} is not a tensor.")
|
||||
return None
|
||||
loaded_tensor = torch.load(tensor_path)
|
||||
if isinstance(loaded_tensor, torch.Tensor):
|
||||
return loaded_tensor
|
||||
else:
|
||||
logger.error(f"Loaded data for key {key} is not a tensor.")
|
||||
return None
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
|
||||
@@ -105,6 +105,14 @@ class HostKVCache(abc.ABC):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_dummy_flat_data_page(self) -> torch.Tensor:
|
||||
"""
|
||||
Get a dummy flat data page from the host memory pool.
|
||||
This is used for prefetching or initializing empty pages.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
|
||||
"""
|
||||
@@ -256,6 +264,14 @@ class MHATokenToKVPoolHost(HostKVCache):
|
||||
def get_flat_data_page(self, index) -> torch.Tensor:
|
||||
return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
|
||||
|
||||
def get_dummy_flat_data_page(self) -> torch.Tensor:
|
||||
return torch.zeros(
|
||||
(2, self.layer_num, self.page_size, self.head_num, self.head_dim),
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
).flatten()
|
||||
|
||||
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
|
||||
self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
|
||||
2,
|
||||
@@ -355,6 +371,19 @@ class MLATokenToKVPoolHost(HostKVCache):
|
||||
def get_flat_data_page(self, index) -> torch.Tensor:
|
||||
return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
|
||||
|
||||
def get_dummy_flat_data_page(self) -> torch.Tensor:
|
||||
return torch.zeros(
|
||||
(
|
||||
self.layer_num,
|
||||
self.page_size,
|
||||
1,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
),
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
).flatten()
|
||||
|
||||
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
|
||||
self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
|
||||
self.layer_num,
|
||||
|
||||
164
python/sglang/srt/mem_cache/nixl/README.md
Normal file
164
python/sglang/srt/mem_cache/nixl/README.md
Normal file
@@ -0,0 +1,164 @@
|
||||
# NIXL Integration for HiCache
|
||||
|
||||
This directory contains the **NIXL (NVIDIA Inference Xfer Library)** integration for **HiCache**, enabling high-performance storage across multiple backends.
|
||||
|
||||
NIXL provides a unified API for accessing various storage plugins, including but not limited to:
|
||||
|
||||
- **Deepseek's 3FS APIs** for high-throughput file operations
|
||||
- **GPU Direct Storage (GDS)** for direct data movement between storage and GPU memory, bypassing CPU memory copies
|
||||
- **Amazon S3-compatible object storage** for key-value access patterns
|
||||
|
||||
Additional backend integrations are planned for future releases.
|
||||
|
||||
## NIXL Resources
|
||||
|
||||
- **Project Repository**: [NIXL on GitHub](https://github.com/ai-dynamo/nixl)
|
||||
- **Documentation**: [NIXL Documentation](https://github.com/ai-dynamo/nixl/tree/main/docs)
|
||||
|
||||
## Overview
|
||||
|
||||
The NIXL integration consists of two main files:
|
||||
|
||||
- **`hicache_nixl.py`** - Main HiCache storage connector using NIXL
|
||||
- **`nixl_utils.py`** - Utility classes for backend selection, registration, and file management
|
||||
|
||||
## Components
|
||||
|
||||
### HiCacheNixl
|
||||
The main storage connector that provides:
|
||||
- Single and batch tensor set/get operations
|
||||
- Automatic backend selection (3FS > POSIX > GDS_MT > GDS > OBJ)
|
||||
- High-performance file-based (or) object based storage access using NIXL
|
||||
|
||||
### NixlUtils
|
||||
Consolidated utility classes:
|
||||
- **NixlBackendSelection** - Handles backend selection and creation
|
||||
- **NixlRegistration** - Manages memory registration for tensors, files and objects
|
||||
- **NixlFileManager** - Handles file system operations and NIXL tuple creation
|
||||
|
||||
## Running Unit Tests
|
||||
|
||||
### Prerequisites
|
||||
- NIXL library installed and available (latest main required for supporting object query)
|
||||
- PyTorch installed
|
||||
- Python 3.8+
|
||||
|
||||
### Unit tests from Project root
|
||||
Navigate to the project root directory (`/path/to/sglang`) and run:
|
||||
|
||||
#### Run all NIXL tests:
|
||||
```bash
|
||||
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -o asyncio_mode=strict
|
||||
```
|
||||
|
||||
#### Run with verbose output:
|
||||
```bash
|
||||
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -v -o asyncio_mode=strict
|
||||
```
|
||||
|
||||
Note: The `-v` flag provides more detailed output, showing each test case name and its result.
|
||||
|
||||
#### Run a specific test:
|
||||
```bash
|
||||
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -v -k test_single_set_get -o asyncio_mode=strict
|
||||
```
|
||||
|
||||
### From Tests Directory
|
||||
Navigate to the tests directory and run:
|
||||
|
||||
```bash
|
||||
cd test/srt
|
||||
PYTHONPATH=../.. python -m pytest test_hicache_nixl_storage.py -o asyncio_mode=strict
|
||||
```
|
||||
Note: The `-o asyncio_mode=strict` flag is added to suppress warnings about asyncio configuration. This is not required for test functionality but provides cleaner output.
|
||||
|
||||
## Test Coverage
|
||||
|
||||
Tests for this integration, a test suite can be found at `test_hicache_nixl_storage.py` which covers:
|
||||
|
||||
### HiCache Integration Tests (4 tests)
|
||||
- Single tensor set/get operations
|
||||
- Batch tensor set/get operations
|
||||
- Mixed single and batch operations
|
||||
- Data integrity for various tensor types
|
||||
|
||||
### File Management Tests (5 tests)
|
||||
- Basic file operations
|
||||
- NIXL tuple creation
|
||||
- Error handling in file operations
|
||||
|
||||
### Registration Tests (2 tests)
|
||||
- Tensor registration with memory type detection
|
||||
- File registration using NIXL tuples
|
||||
|
||||
## Expected Output
|
||||
|
||||
When tests run successfully, you should see:
|
||||
- NIXL agent initialization messages
|
||||
- Backend selection messages (e.g., "Backend POSIX was instantiated")
|
||||
- Test results with "ok" for passed tests
|
||||
- Summary showing "Ran X tests in Y seconds" and "OK"
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Import Errors
|
||||
If you encounter `ModuleNotFoundError`, ensure:
|
||||
- You're running from the correct directory
|
||||
- `PYTHONPATH` is set correctly
|
||||
- NIXL library is properly installed
|
||||
|
||||
### NIXL Errors
|
||||
If NIXL operations fail:
|
||||
- Check that NIXL is properly installed
|
||||
- Verify that required plugins are available
|
||||
- Ensure file permissions are correct for test directories
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
python/sglang/srt/mem_cache/nixl/
|
||||
├── hicache_nixl.py # Main HiCache storage connector
|
||||
├── nixl_utils.py # All NIXL utility classes
|
||||
├── README.md # This file
|
||||
└── tests/
|
||||
└── test_nixl_unified.py # All tests in one file
|
||||
```
|
||||
|
||||
## Dependencies
|
||||
|
||||
- **NIXL**: NVIDIA Inference Xfer Library (version 0.4 or later)
|
||||
- Required plugins: POSIX (minimum), 3FS/GDS (optional for better performance)
|
||||
- See [NIXL Installation Guide](https://github.com/ai-dynamo/nixl/blob/main/README.md#installation)
|
||||
- **PyTorch**: For tensor operations (version 1.8 or later)
|
||||
- **Python 3.8+**: For type hints and modern features
|
||||
|
||||
## Supported Features
|
||||
|
||||
### Memory Types
|
||||
- **Tensor side**: multi-dimensional tensors of all numeric types (int32, int64, float32, float64) are supported.
|
||||
- Tensors can be on CPU or GPU (as long as a GPU capable backend such as GDS_MT is available).
|
||||
- Currently each tensor is mapped to a file or key, but it can be extended to support multiple keys per file or key.
|
||||
|
||||
- **Storage side**: file and object are supported through their relevant backends (e.g., 3FS or OBJ).
|
||||
|
||||
### Backend Priority
|
||||
|
||||
The NIXL backend selection follows this priority order:
|
||||
1. **3FS** - Highest performance (if available)
|
||||
- Best for high-throughput file operations using Deepseek 3FS APIs
|
||||
2. **POSIX** - Standard file I/O (fallback)
|
||||
- Universal compatibility
|
||||
- Good for development and testing - Leverages both libaio/liburing
|
||||
3. **GDS_MT** - Multi-threaded GDS (if available)
|
||||
- Optimized for concurrent operations
|
||||
- Supports GPU Direct storage with multiple light weight threads
|
||||
4. **GDS** - GPU Direct Storage (if available)
|
||||
- Direct GPU-storage data path
|
||||
- Best for filesystems benefiting from batch operations and smaller IOs.
|
||||
5. **OBJ** - Amazon S3 based Object Storage
|
||||
- Key-value based storage
|
||||
The system automatically selects the best available backend, with POSIX as the default fallback.
|
||||
|
||||
## Note
|
||||
|
||||
This is v0 of the NIXL connector. Future versions will focus on further performance optimizations such as memory pre-registration (pre-allocating and registering memory buffers to reduce registration overhead during transfers) and block merging (combining related blocks as offsets within the same file to reduce file operations and improve throughput). These optimizations require changes at a higher layer, as the current HiCache API doesn't expose information like block relationships or hash patterns that would enable these optimizations.
|
||||
163
python/sglang/srt/mem_cache/nixl/hicache_nixl.py
Normal file
163
python/sglang/srt/mem_cache/nixl/hicache_nixl.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
|
||||
|
||||
from .nixl_utils import NixlBackendSelection, NixlFileManager, NixlRegistration
|
||||
|
||||
try:
|
||||
from nixl._api import nixl_agent, nixl_agent_config
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install NIXL by following the instructions at "
|
||||
"https://github.com/ai-dynamo/nixl/blob/main/README.md "
|
||||
"to use HiCacheNixl storage backend."
|
||||
) from e
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HiCacheNixl(HiCacheStorage):
|
||||
"""HiCacheNixl provides high-performance storage using NIXL plugins."""
|
||||
|
||||
def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"):
|
||||
"""Initialize NIXL storage connector."""
|
||||
self.file_manager = (
|
||||
NixlFileManager(file_path)
|
||||
if plugin not in NixlBackendSelection.OBJ_PLUGINS
|
||||
else None
|
||||
)
|
||||
|
||||
agent_config = nixl_agent_config(backends=[])
|
||||
self.agent_name = f"hicache_nixl_{str(uuid.uuid4())}"
|
||||
self.agent = nixl_agent(self.agent_name, agent_config)
|
||||
|
||||
self.backend_selector = NixlBackendSelection(plugin)
|
||||
if not self.backend_selector.create_backend(self.agent):
|
||||
raise RuntimeError("Failed to create NIXL backend")
|
||||
|
||||
self.registration = NixlRegistration(self.agent)
|
||||
|
||||
def _execute_transfer(
|
||||
self, tensors: List[torch.Tensor], keys: List[str], direction: str
|
||||
) -> bool:
|
||||
if len(tensors) != len(keys):
|
||||
logger.error("Mismatch between number of tensors and files/objects")
|
||||
return False
|
||||
|
||||
if not self.registration.register_buffers(tensors):
|
||||
logger.error("Failed to register tensors")
|
||||
return False
|
||||
|
||||
# Get transfer tuples based on backend type
|
||||
tensor_sizes = [tensor.element_size() * tensor.numel() for tensor in tensors]
|
||||
if self.backend_selector.mem_type == "FILE":
|
||||
file_tuples = self.file_manager.files_to_nixl_tuples(keys)
|
||||
if not file_tuples or not self.registration.register_files(file_tuples):
|
||||
logger.error("Failed to prepare files for transfer")
|
||||
return False
|
||||
transfer_tuples = [
|
||||
(x[0], s, x[2]) for x, s in zip(file_tuples, tensor_sizes)
|
||||
]
|
||||
else:
|
||||
if not self.registration.register_objects(keys, tensors):
|
||||
logger.error("Failed to register objects")
|
||||
return False
|
||||
transfer_tuples = [(0, s, key) for s, key in zip(tensor_sizes, keys)]
|
||||
|
||||
try:
|
||||
# Get transfer descriptors
|
||||
if (tensor_descs := self.agent.get_xfer_descs(tensors)) is None or (
|
||||
file_descs := self.agent.get_xfer_descs(
|
||||
transfer_tuples, self.backend_selector.mem_type
|
||||
)
|
||||
) is None:
|
||||
logger.error("Failed to get transfer descriptors")
|
||||
return False
|
||||
|
||||
# Initialize and execute transfer
|
||||
if (
|
||||
xfer_req := self.agent.initialize_xfer(
|
||||
direction, tensor_descs, file_descs, self.agent_name
|
||||
)
|
||||
) is None:
|
||||
logger.error("Failed to create transfer request")
|
||||
return False
|
||||
|
||||
state = self.agent.transfer(xfer_req)
|
||||
while state != "DONE":
|
||||
state = self.agent.check_xfer_state(xfer_req)
|
||||
if state == "ERR":
|
||||
logger.error("Transfer failed")
|
||||
return False
|
||||
time.sleep(0.0001) # Can be changed to os.sched_yield() or parametrized
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute transfer: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
|
||||
if not keys:
|
||||
return True
|
||||
|
||||
if self.backend_selector.mem_type == "FILE":
|
||||
file_paths = []
|
||||
for key in keys:
|
||||
tensor_path = self.file_manager.get_file_path(key)
|
||||
if not self.file_manager.create_file(tensor_path):
|
||||
logger.error(f"Failed to create file {tensor_path}")
|
||||
return False
|
||||
file_paths.append(tensor_path)
|
||||
return self._execute_transfer(values, file_paths, "WRITE")
|
||||
else:
|
||||
return self._execute_transfer(values, keys, "WRITE")
|
||||
|
||||
def set(self, key: str, value: torch.Tensor) -> bool:
|
||||
return self.batch_set([key], [value])
|
||||
|
||||
def get(
|
||||
self, key: str, dst_tensor: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor | None:
|
||||
if dst_tensor is None: # To be removed, being compatible with the current API
|
||||
return None
|
||||
result = self.batch_get([key], [dst_tensor])
|
||||
return result[0] if result else None
|
||||
|
||||
def batch_get(
|
||||
self, keys: List[str], dst_tensors: List[torch.Tensor]
|
||||
) -> List[Optional[torch.Tensor]]:
|
||||
if not keys:
|
||||
return []
|
||||
|
||||
if self.backend_selector.mem_type == "FILE":
|
||||
file_paths = [self.file_manager.get_file_path(key) for key in keys]
|
||||
success = self._execute_transfer(dst_tensors, file_paths, "READ")
|
||||
else:
|
||||
success = self._execute_transfer(dst_tensors, keys, "READ")
|
||||
return dst_tensors if success else [None] * len(keys)
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
tuples = self.registration.create_query_tuples(
|
||||
key,
|
||||
self.backend_selector.mem_type,
|
||||
self.file_manager if self.backend_selector.mem_type == "FILE" else None,
|
||||
)
|
||||
if not tuples:
|
||||
return False
|
||||
|
||||
query_res = self.agent.query_memory(
|
||||
tuples,
|
||||
self.backend_selector.backend_name,
|
||||
mem_type=self.backend_selector.mem_type,
|
||||
)
|
||||
return query_res[0] is not None # can be expanded to multiple keys
|
||||
238
python/sglang/srt/mem_cache/nixl/nixl_utils.py
Normal file
238
python/sglang/srt/mem_cache/nixl/nixl_utils.py
Normal file
@@ -0,0 +1,238 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NixlBackendSelection:
|
||||
"""Handles NIXL backend selection and creation."""
|
||||
|
||||
# Priority order for File-based plugins in case of auto selection
|
||||
FILE_PLUGINS = ["3FS", "POSIX", "GDS_MT", "GDS"]
|
||||
# Priority order for File-based plugins in case of auto selection (add more as needed)
|
||||
OBJ_PLUGINS = ["OBJ"] # Based on Amazon S3 SDK
|
||||
|
||||
def __init__(self, plugin: str = "auto"):
|
||||
"""Initialize backend selection.
|
||||
Args:
|
||||
plugin: Plugin to use (default "auto" selects best available).
|
||||
Can be a file plugin (3FS, POSIX, GDS, GDS_MT) or
|
||||
an object plugin (OBJ).
|
||||
"""
|
||||
self.plugin = plugin
|
||||
self.backend_name = None
|
||||
self.mem_type = None
|
||||
|
||||
def set_bucket(self, bucket_name: str) -> None:
|
||||
"""Set AWS bucket name in environment variable."""
|
||||
os.environ["AWS_DEFAULT_BUCKET"] = bucket_name
|
||||
logger.debug(f"Set AWS bucket name to: {bucket_name}")
|
||||
|
||||
def create_backend(self, agent) -> bool:
|
||||
"""Create the appropriate NIXL backend based on configuration."""
|
||||
try:
|
||||
plugin_list = agent.get_plugin_list()
|
||||
logger.debug(f"Available NIXL plugins: {plugin_list}")
|
||||
|
||||
# Handle explicit plugin selection or auto priority
|
||||
if self.plugin == "auto":
|
||||
# Try all file plugins first
|
||||
for plugin in self.FILE_PLUGINS:
|
||||
if plugin in plugin_list:
|
||||
self.backend_name = plugin
|
||||
break
|
||||
# If no file plugin found, try object plugins
|
||||
if not self.backend_name:
|
||||
for plugin in self.OBJ_PLUGINS:
|
||||
if plugin in plugin_list:
|
||||
self.backend_name = plugin
|
||||
break
|
||||
else:
|
||||
# Use explicitly requested plugin
|
||||
self.backend_name = self.plugin
|
||||
|
||||
if self.backend_name not in plugin_list:
|
||||
logger.error(
|
||||
f"Backend {self.backend_name} not available in plugins: {plugin_list}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Create backend and set memory type
|
||||
if self.backend_name in self.OBJ_PLUGINS:
|
||||
bucket = os.environ.get("AWS_DEFAULT_BUCKET")
|
||||
if not bucket:
|
||||
logger.error(
|
||||
"AWS_DEFAULT_BUCKET environment variable must be set for object storage"
|
||||
)
|
||||
return False
|
||||
agent.create_backend(self.backend_name, {"bucket": bucket})
|
||||
else:
|
||||
agent.create_backend(self.backend_name)
|
||||
|
||||
self.mem_type = "OBJ" if self.backend_name in self.OBJ_PLUGINS else "FILE"
|
||||
logger.debug(
|
||||
f"Created NIXL backend: {self.backend_name} with memory type: {self.mem_type}"
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create NIXL backend: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class NixlRegistration:
|
||||
"""Handles NIXL memory registration."""
|
||||
|
||||
def __init__(self, agent):
|
||||
self.agent = agent
|
||||
|
||||
def create_query_tuples(
|
||||
self, key: str, mem_type: str, file_manager=None
|
||||
) -> List[Tuple]:
|
||||
"""Create NIXL tuples for querying memory.
|
||||
Args:
|
||||
key: Key to query (file path for FILE or object key for OBJ)
|
||||
mem_type: Memory type ("FILE" or "OBJ")
|
||||
file_manager: Optional NixlFileManager for FILE memory type
|
||||
Returns:
|
||||
List of NIXL tuples for querying
|
||||
"""
|
||||
if mem_type == "FILE":
|
||||
if file_manager is None:
|
||||
logger.error("file_manager required for FILE memory type")
|
||||
return []
|
||||
return [(0, 0, 0, file_manager.get_file_path(key))]
|
||||
else: # OBJ
|
||||
return [(0, 0, key)]
|
||||
|
||||
def _register_memory(
|
||||
self, items: Union[List[tuple], List[torch.Tensor]], mem_type: str, desc: str
|
||||
) -> Optional[Any]:
|
||||
"""Common registration logic for files, objects, and buffers.
|
||||
Args:
|
||||
items: List of tuples or tensors to register
|
||||
mem_type: Memory type ("FILE", "OBJ", "DRAM", "VRAM")
|
||||
desc: Description for logging
|
||||
"""
|
||||
try:
|
||||
if not items:
|
||||
return None
|
||||
|
||||
reg_descs = self.agent.get_reg_descs(items, mem_type)
|
||||
if reg_descs is None:
|
||||
logger.error("Failed to create registration descriptors")
|
||||
return None
|
||||
|
||||
registered_memory = self.agent.register_memory(reg_descs)
|
||||
if registered_memory:
|
||||
return registered_memory
|
||||
else:
|
||||
logger.error("Failed to register with NIXL")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register {desc}: {e}")
|
||||
return None
|
||||
|
||||
def register_buffers(
|
||||
self, buffers: Union[torch.Tensor, List[torch.Tensor]]
|
||||
) -> Optional[Any]:
|
||||
"""Register tensors/buffers with NIXL."""
|
||||
if isinstance(buffers, torch.Tensor):
|
||||
buffers = [buffers]
|
||||
|
||||
if not buffers:
|
||||
return None
|
||||
|
||||
# Determine memory type based on tensor device
|
||||
mem_type = "VRAM" if buffers[0].device.type == "cuda" else "DRAM"
|
||||
return self._register_memory(buffers, mem_type, "buffers")
|
||||
|
||||
def register_files(self, tuples: List[tuple]) -> Optional[Any]:
|
||||
"""Register files with NIXL using (0, 0, fd, file_path) tuples."""
|
||||
return self._register_memory(tuples, "FILE", "files")
|
||||
|
||||
def register_objects(
|
||||
self, keys: List[str], tensors: Optional[List[torch.Tensor]] = None
|
||||
) -> Optional[Any]:
|
||||
"""Register objects with NIXL."""
|
||||
if not keys:
|
||||
return None
|
||||
|
||||
# Create object tuples with proper sizes
|
||||
tuples = [
|
||||
(0, tensor.element_size() * tensor.numel() if tensor else 0, key)
|
||||
for key, tensor in zip(keys, tensors or [None] * len(keys))
|
||||
]
|
||||
return self._register_memory(tuples, "OBJ", "objects")
|
||||
|
||||
|
||||
class NixlFileManager:
|
||||
"""Handles file system operations for NIXL."""
|
||||
|
||||
def __init__(self, base_dir: str):
|
||||
"""
|
||||
Initialize file manager.
|
||||
Args:
|
||||
base_dir: Base directory for storing tensor files
|
||||
"""
|
||||
self.base_dir = base_dir
|
||||
if base_dir == "":
|
||||
logger.debug(f"Initialized file manager without a base directory")
|
||||
else:
|
||||
os.makedirs(base_dir, exist_ok=True)
|
||||
logger.debug(f"Initialized file manager with base directory: {base_dir}")
|
||||
|
||||
def get_file_path(self, key: str) -> str:
|
||||
"""Get full file path for a given key."""
|
||||
return os.path.join(self.base_dir, key)
|
||||
|
||||
def create_file(self, file_path: str) -> bool:
|
||||
"""Create a file if it doesn't exist."""
|
||||
try:
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
if not os.path.exists(file_path):
|
||||
with open(file_path, "wb") as f:
|
||||
pass # Create empty file
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create file {file_path}: {e}")
|
||||
return False
|
||||
|
||||
def open_file(self, file_path: str) -> Optional[int]:
|
||||
"""Open a file and return its file descriptor."""
|
||||
try:
|
||||
fd = os.open(file_path, os.O_RDWR)
|
||||
return fd
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to open file {file_path}: {e}")
|
||||
return None
|
||||
|
||||
def close_file(self, fd: int) -> bool:
|
||||
"""Close a file descriptor."""
|
||||
try:
|
||||
os.close(fd)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to close file descriptor {fd}: {e}")
|
||||
return False
|
||||
|
||||
def files_to_nixl_tuples(
|
||||
self, file_paths: List[str], open_file: bool = True
|
||||
) -> List[Tuple[int, int, int, str]]:
|
||||
"""Create NIXL tuples (offset, length, fd, file_path) for given files."""
|
||||
if not open_file:
|
||||
return [(0, 0, 0, path) for path in file_paths]
|
||||
|
||||
tuples = []
|
||||
for path in file_paths:
|
||||
if (fd := self.open_file(path)) is None:
|
||||
# Clean up on failure
|
||||
for t in tuples:
|
||||
self.close_file(t[2])
|
||||
return []
|
||||
tuples.append((0, 0, fd, path))
|
||||
return tuples
|
||||
216
python/sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py
Executable file
216
python/sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py
Executable file
@@ -0,0 +1,216 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from typing import List, Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
|
||||
from sglang.srt.mem_cache.nixl.nixl_utils import NixlFileManager, NixlRegistration
|
||||
|
||||
|
||||
class TestNixlUnified(unittest.TestCase):
|
||||
"""Unified test suite for all NIXL components."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
# Create test directories
|
||||
self.test_dir = "/tmp/test_nixl_unified"
|
||||
os.makedirs(self.test_dir, exist_ok=True)
|
||||
|
||||
# Mock NIXL agent for registration tests
|
||||
self.mock_agent = MagicMock()
|
||||
self.mock_agent.get_reg_descs.return_value = "mock_reg_descs"
|
||||
self.mock_agent.register_memory.return_value = "mock_registered_memory"
|
||||
|
||||
# Create instances
|
||||
self.file_manager = NixlFileManager(self.test_dir)
|
||||
self.registration = NixlRegistration(self.mock_agent)
|
||||
try:
|
||||
self.hicache = HiCacheNixl(file_path=self.test_dir, plugin="POSIX")
|
||||
except ImportError:
|
||||
self.skipTest("NIXL not available, skipping NIXL storage tests")
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test directories."""
|
||||
if os.path.exists(self.test_dir):
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(self.test_dir)
|
||||
|
||||
def delete_test_file(self, file_path: str) -> bool:
|
||||
"""Helper method to delete a test file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to delete
|
||||
|
||||
Returns:
|
||||
bool: True if file was deleted or didn't exist, False on error
|
||||
"""
|
||||
try:
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
return True
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
def verify_tensors_equal(self, expected: torch.Tensor, actual: torch.Tensor):
|
||||
"""Helper to verify tensor equality."""
|
||||
self.assertIsNotNone(actual, "Retrieved tensor is None")
|
||||
self.assertTrue(
|
||||
torch.allclose(expected, actual, atol=1e-6),
|
||||
f"Tensors not equal:\nExpected: {expected}\nActual: {actual}",
|
||||
)
|
||||
|
||||
def verify_tensor_lists_equal(
|
||||
self, expected: List[torch.Tensor], actual: List[torch.Tensor]
|
||||
):
|
||||
"""Helper to verify lists of tensors are equal."""
|
||||
self.assertEqual(len(expected), len(actual), "Lists have different lengths")
|
||||
for exp, act in zip(expected, actual):
|
||||
self.verify_tensors_equal(exp, act)
|
||||
|
||||
# ============================================================================
|
||||
# HiCache Integration Tests
|
||||
# ============================================================================
|
||||
|
||||
def test_single_set_get(self):
|
||||
"""Test single tensor set/get operations."""
|
||||
key = "test_key"
|
||||
value = torch.randn(10, 10, device="cpu")
|
||||
dst_tensor = torch.zeros_like(value, device="cpu")
|
||||
|
||||
# Test set
|
||||
self.assertTrue(self.hicache.set(key, value))
|
||||
self.assertTrue(self.hicache.exists(key))
|
||||
|
||||
# Test get
|
||||
retrieved = self.hicache.get(key, dst_tensor)
|
||||
self.verify_tensors_equal(value, retrieved)
|
||||
|
||||
def test_batch_set_get(self):
|
||||
"""Test batch tensor set/get operations."""
|
||||
keys = ["key1", "key2", "key3"]
|
||||
values = [
|
||||
torch.randn(5, 5, device="cpu"),
|
||||
torch.randn(3, 3, device="cpu"),
|
||||
torch.randn(7, 7, device="cpu"),
|
||||
]
|
||||
dst_tensors = [torch.zeros_like(v, device="cpu") for v in values]
|
||||
|
||||
# Test batch set
|
||||
self.assertTrue(self.hicache.batch_set(keys, values))
|
||||
self.assertTrue(all(self.hicache.exists(key) for key in keys))
|
||||
|
||||
# Test batch get
|
||||
retrieved = self.hicache.batch_get(keys, dst_tensors)
|
||||
self.verify_tensor_lists_equal(values, retrieved)
|
||||
|
||||
def test_mixed_operations(self):
|
||||
"""Test mixing single and batch operations."""
|
||||
# Test interleaved set/get operations
|
||||
key1, key2 = "key1", "key2"
|
||||
value1 = torch.randn(4, 4, device="cpu")
|
||||
value2 = torch.randn(6, 6, device="cpu")
|
||||
dst1 = torch.zeros_like(value1)
|
||||
dst2 = torch.zeros_like(value2)
|
||||
|
||||
# Single set/get
|
||||
self.assertTrue(self.hicache.set(key1, value1))
|
||||
retrieved1 = self.hicache.get(key1, dst1)
|
||||
self.verify_tensors_equal(value1, retrieved1)
|
||||
|
||||
# Batch set/get
|
||||
self.assertTrue(self.hicache.batch_set([key2], [value2]))
|
||||
retrieved2 = self.hicache.batch_get([key2], [dst2])
|
||||
self.verify_tensors_equal(value2, retrieved2[0])
|
||||
|
||||
def test_data_integrity(self):
|
||||
"""Test data integrity across operations."""
|
||||
# Test with various tensor types and sizes
|
||||
test_cases = [
|
||||
("float32", torch.randn(10, 10, dtype=torch.float32)),
|
||||
("float64", torch.randn(5, 5, dtype=torch.float64)),
|
||||
("int32", torch.randint(-100, 100, (8, 8), dtype=torch.int32)),
|
||||
("int64", torch.randint(-100, 100, (6, 6), dtype=torch.int64)),
|
||||
("bool", torch.randint(0, 2, (4, 4)).bool()),
|
||||
]
|
||||
|
||||
for name, tensor in test_cases:
|
||||
with self.subTest(tensor_type=name):
|
||||
key = f"test_{name}"
|
||||
dst_tensor = torch.zeros_like(tensor)
|
||||
|
||||
# Set and immediately get
|
||||
self.assertTrue(self.hicache.set(key, tensor))
|
||||
retrieved1 = self.hicache.get(key, dst_tensor)
|
||||
self.verify_tensors_equal(tensor, retrieved1)
|
||||
|
||||
# Get again to verify persistence
|
||||
dst_tensor.zero_()
|
||||
retrieved2 = self.hicache.get(key, dst_tensor)
|
||||
self.verify_tensors_equal(tensor, retrieved2)
|
||||
|
||||
def test_basic_file_operations(self):
|
||||
"""Test basic file operations."""
|
||||
test_file = os.path.join(self.test_dir, "test_file.bin")
|
||||
self.file_manager.create_file(test_file)
|
||||
self.assertTrue(os.path.exists(test_file))
|
||||
self.assertEqual(os.path.getsize(test_file), 0) # Empty file
|
||||
|
||||
# Test file deletion
|
||||
self.assertTrue(self.delete_test_file(test_file))
|
||||
self.assertFalse(os.path.exists(test_file))
|
||||
|
||||
def test_create_nixl_tuples(self):
|
||||
"""Test creation of NIXL tuples."""
|
||||
test_file = os.path.join(self.test_dir, "test_file.bin")
|
||||
self.file_manager.create_file(test_file)
|
||||
|
||||
# Test tuple creation
|
||||
tuples = self.file_manager.files_to_nixl_tuples([test_file], False)
|
||||
self.assertIsNotNone(tuples)
|
||||
self.assertTrue(len(tuples) > 0)
|
||||
|
||||
def test_error_handling(self):
|
||||
"""Test error handling in file operations."""
|
||||
# Test non-existent file
|
||||
self.assertTrue(
|
||||
self.delete_test_file("nonexistent_file.bin")
|
||||
) # Returns True if file doesn't exist
|
||||
|
||||
# Test invalid file path
|
||||
self.assertFalse(self.file_manager.create_file("")) # Empty path should fail
|
||||
|
||||
def test_register_buffers(self):
|
||||
"""Test registration of memory buffers."""
|
||||
# Create test tensor
|
||||
tensor = torch.randn(10, 10)
|
||||
|
||||
# Test buffer registration
|
||||
self.assertIsNotNone(self.registration.register_buffers(tensor))
|
||||
|
||||
# Test batch registration
|
||||
tensors = [torch.randn(5, 5) for _ in range(3)]
|
||||
self.assertIsNotNone(self.registration.register_buffers(tensors))
|
||||
|
||||
def test_register_files_with_tuples(self):
|
||||
"""Test registration of files using NIXL tuples."""
|
||||
files = [os.path.join(self.test_dir, f"test_file_{i}.bin") for i in range(3)]
|
||||
for file in files:
|
||||
self.file_manager.create_file(file)
|
||||
|
||||
# Create tuples and register
|
||||
tuples = self.file_manager.files_to_nixl_tuples(files, False)
|
||||
self.registration.register_files(tuples)
|
||||
|
||||
# Verify tuples
|
||||
self.assertEqual(len(tuples), len(files))
|
||||
for t, f in zip(tuples, files):
|
||||
self.assertEqual(t[3], f) # Check file path
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1471,7 +1471,7 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--hicache-storage-backend",
|
||||
type=str,
|
||||
choices=["file", "mooncake", "hf3fs"],
|
||||
choices=["file", "mooncake", "hf3fs", "nixl"],
|
||||
default=ServerArgs.hicache_storage_backend,
|
||||
help="The storage backend for hierarchical KV cache.",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user