Add support for extensions of interface and pre-registrations to NIXL HiCache (#9211)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -36,6 +36,21 @@ Consolidated utility classes:
|
|||||||
- **NixlRegistration** - Manages memory registration for tensors, files and objects
|
- **NixlRegistration** - Manages memory registration for tensors, files and objects
|
||||||
- **NixlFileManager** - Handles file system operations and NIXL tuple creation
|
- **NixlFileManager** - Handles file system operations and NIXL tuple creation
|
||||||
|
|
||||||
|
## Using NIXL for HiCache backend
|
||||||
|
When launching the SGLang server, pass `nixl` as the value of the `--hicache-storage-backend` parameter, for instance:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m sglang.launch_server --model-path <model> --host <ip> --port <port> --page-size 64 --enable-hierarchical-cache --hicache-ratio 2 --hicache-size 64 --hicache-write-policy write_through --hicache-storage-backend nixl
|
||||||
|
```
|
||||||
|
|
||||||
|
To customize the base directory for files, you can set the following environment variable:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR=/path/to/desired/dir
|
||||||
|
```
|
||||||
|
|
||||||
|
Selecting a storage backend such as 3FS requires the corresponding library to be available on the system; among the available backends, one is chosen according to the priority order described above.
|
||||||
|
|
||||||
## Running Unit Tests
|
## Running Unit Tests
|
||||||
|
|
||||||
### Prerequisites
|
### Prerequisites
|
||||||
@@ -43,33 +58,26 @@ Consolidated utility classes:
|
|||||||
- PyTorch installed
|
- PyTorch installed
|
||||||
- Python 3.8+
|
- Python 3.8+
|
||||||
|
|
||||||
### Unit tests from Project root
|
### Unit tests from current directory
|
||||||
Navigate to the project root directory (`/path/to/sglang`) and run:
|
From the current directory run:
|
||||||
|
|
||||||
#### Run all NIXL tests:
|
#### Run all NIXL tests:
|
||||||
```bash
|
```bash
|
||||||
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -o asyncio_mode=strict
|
PYTHONPATH=. python -m pytest test_hicache_nixl_storage.py -o asyncio_mode=strict
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Run with verbose output:
|
#### Run with verbose output:
|
||||||
```bash
|
```bash
|
||||||
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -v -o asyncio_mode=strict
|
PYTHONPATH=. python -m pytest test_hicache_nixl_storage.py -v -o asyncio_mode=strict
|
||||||
```
|
```
|
||||||
|
|
||||||
Note: The `-v` flag provides more detailed output, showing each test case name and its result.
|
Note: The `-v` flag provides more detailed output, showing each test case name and its result.
|
||||||
|
|
||||||
#### Run a specific test:
|
#### Run a specific test:
|
||||||
```bash
|
```bash
|
||||||
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -v -k test_single_set_get -o asyncio_mode=strict
|
PYTHONPATH=. python -m pytest test_hicache_nixl_storage.py -v -k test_single_set_get -o asyncio_mode=strict
|
||||||
```
|
```
|
||||||
|
|
||||||
### From Tests Directory
|
|
||||||
Navigate to the tests directory and run:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd test/srt
|
|
||||||
PYTHONPATH=../.. python -m pytest test_hicache_nixl_storage.py -o asyncio_mode=strict
|
|
||||||
```
|
|
||||||
Note: The `-o asyncio_mode=strict` flag is added to suppress warnings about asyncio configuration. This is not required for test functionality but provides cleaner output.
|
Note: The `-o asyncio_mode=strict` flag is added to suppress warnings about asyncio configuration. This is not required for test functionality but provides cleaner output.
|
||||||
|
|
||||||
## Test Coverage
|
## Test Coverage
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -28,6 +28,8 @@ class HiCacheNixl(HiCacheStorage):
|
|||||||
|
|
||||||
def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"):
|
def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"):
|
||||||
"""Initialize NIXL storage connector."""
|
"""Initialize NIXL storage connector."""
|
||||||
|
# Might be better to be unified across HiCache backends and moved to HiCacheController
|
||||||
|
file_path = os.getenv("SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR", file_path)
|
||||||
self.file_manager = (
|
self.file_manager = (
|
||||||
NixlFileManager(file_path)
|
NixlFileManager(file_path)
|
||||||
if plugin not in NixlBackendSelection.OBJ_PLUGINS
|
if plugin not in NixlBackendSelection.OBJ_PLUGINS
|
||||||
@@ -44,59 +46,109 @@ class HiCacheNixl(HiCacheStorage):
|
|||||||
|
|
||||||
self.registration = NixlRegistration(self.agent)
|
self.registration = NixlRegistration(self.agent)
|
||||||
|
|
||||||
|
def register_buffers(
    self, buffers: Union[torch.Tensor, List[torch.Tensor], List[tuple]]
) -> Optional[Any]:
    """Register tensor(s) or host-memory target locations with NIXL.

    Args:
        buffers: A single tensor, a list of tensors, or a list of
            ``(addr, len)`` tuples describing regions in host memory.

    Returns:
        The value returned by ``self.registration._register_memory`` (which
        is ``None`` on failure), or ``None`` when an empty list is passed.
    """
    # A lone tensor is forwarded untouched. Probing ``buffers[0]`` on it would
    # index into the tensor's data and raises for a 0-dim tensor.
    if isinstance(buffers, torch.Tensor):
        return self.registration._register_memory(buffers)

    # Nothing to register; this also avoids an IndexError on ``buffers[0]``.
    if not buffers:
        return None

    if isinstance(buffers[0], tuple):
        # Expand (addr, len) into the 4-field tuples used for "DRAM" registration.
        tuples = [(x[0], x[1], 0, "") for x in buffers]
        return self.registration._register_memory(tuples, "DRAM")

    return self.registration._register_memory(buffers)
|
||||||
|
|
||||||
|
def register_files(
    self, file_paths: List[str], open_file: Optional[bool] = True
) -> Optional[Any]:
    """Register the given files with NIXL as "FILE" memory.

    The paths are turned into NIXL tuples by the file manager and handed to
    the registration helper.

    Args:
        file_paths: Paths of the files to register.
        open_file: Not read by this method.
            NOTE(review): looks like a leftover from the earlier
            ``files_to_nixl_tuples(..., open_file)`` signature - confirm.

    Returns:
        The value returned by ``self.registration._register_memory``.
    """
    nixl_tuples = self.file_manager.files_to_nixl_tuples(file_paths)
    return self.registration._register_memory(nixl_tuples, "FILE")
|
||||||
|
|
||||||
|
def register_objects(
    self, keys: List[str], sizes: Optional[List[int]] = None
) -> Optional[Any]:
    """Register object keys with NIXL as "OBJ" memory.

    Args:
        keys: Object keys to register; an empty list registers nothing.
        sizes: Not read by this method; each key is registered with a
            zero offset and zero length.

    Returns:
        ``None`` when ``keys`` is empty, otherwise the value returned by
        ``self.registration._register_memory``.
    """
    if keys:
        obj_tuples = [(0, 0, name, "") for name in keys]
        return self.registration._register_memory(obj_tuples, "OBJ")
    return None
|
||||||
|
|
||||||
def _execute_transfer(
|
def _execute_transfer(
|
||||||
self, tensors: List[torch.Tensor], keys: List[str], direction: str
|
self,
|
||||||
|
buffers: Optional[List[torch.Tensor | tuple]],
|
||||||
|
keys: List[str],
|
||||||
|
direction: str,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
if len(tensors) != len(keys):
|
if len(buffers) != len(keys):
|
||||||
logger.error("Mismatch between number of tensors and files/objects")
|
logger.error("Mismatch between number of tensors/buffers and files/objects")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not self.registration.register_buffers(tensors):
|
# Registering file and object keys per transfer, to be updated when
|
||||||
logger.error("Failed to register tensors")
|
# pre-registration for file and object is added to HiCache.
|
||||||
return False
|
|
||||||
|
|
||||||
# Get transfer tuples based on backend type
|
|
||||||
tensor_sizes = [tensor.element_size() * tensor.numel() for tensor in tensors]
|
|
||||||
if self.backend_selector.mem_type == "FILE":
|
if self.backend_selector.mem_type == "FILE":
|
||||||
file_tuples = self.file_manager.files_to_nixl_tuples(keys)
|
tuples = self.file_manager.files_to_nixl_tuples(keys)
|
||||||
if not file_tuples or not self.registration.register_files(file_tuples):
|
if not tuples or not self.registration._register_memory(tuples, "FILE"):
|
||||||
logger.error("Failed to prepare files for transfer")
|
logger.error("Failed to prepare files for transfer")
|
||||||
return False
|
return False
|
||||||
transfer_tuples = [
|
else: # mem_type == "OBJ"
|
||||||
(x[0], s, x[2]) for x, s in zip(file_tuples, tensor_sizes)
|
tuples = [(0, 0, key, "") for key in keys]
|
||||||
]
|
if not tuples or not self.registration._register_memory(tuples, "OBJ"):
|
||||||
else:
|
|
||||||
if not self.registration.register_objects(keys, tensors):
|
|
||||||
logger.error("Failed to register objects")
|
logger.error("Failed to register objects")
|
||||||
return False
|
return False
|
||||||
transfer_tuples = [(0, s, key) for s, key in zip(tensor_sizes, keys)]
|
|
||||||
|
|
||||||
|
# Prepare transfer descriptors
|
||||||
|
if isinstance(buffers[0], torch.Tensor):
|
||||||
|
tensor_sizes = [
|
||||||
|
tensor.element_size() * tensor.numel() for tensor in buffers
|
||||||
|
]
|
||||||
|
storage_tuples = [(x[0], s, x[2]) for x, s in zip(tuples, tensor_sizes)]
|
||||||
|
host_descs = self.agent.get_xfer_descs(buffers)
|
||||||
|
elif isinstance(buffers[0], tuple):
|
||||||
|
storage_tuples = [(x[0], y[1], x[2]) for x, y in zip(tuples, buffers)]
|
||||||
|
host_descs = self.agent.get_xfer_descs(
|
||||||
|
[(x[0], x[1], 0) for x in buffers], "DRAM"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
storage_descs = self.agent.get_xfer_descs(
|
||||||
|
storage_tuples, self.backend_selector.mem_type
|
||||||
|
)
|
||||||
|
|
||||||
|
if (host_descs is None) or (storage_descs is None):
|
||||||
|
logger.error("Failed to get transfer descriptors")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Initialize transfer, default assumption that tensor was registered
|
||||||
try:
|
try:
|
||||||
# Get transfer descriptors
|
xfer_req = self.agent.initialize_xfer(
|
||||||
if (tensor_descs := self.agent.get_xfer_descs(tensors)) is None or (
|
direction, host_descs, storage_descs, self.agent_name
|
||||||
file_descs := self.agent.get_xfer_descs(
|
)
|
||||||
transfer_tuples, self.backend_selector.mem_type
|
except Exception:
|
||||||
)
|
# Check if it was due to missing pre-registration
|
||||||
) is None:
|
if not self.register_buffers(buffers):
|
||||||
logger.error("Failed to get transfer descriptors")
|
logger.error("Failed to register tensors/buffers")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Initialize and execute transfer
|
try:
|
||||||
if (
|
xfer_req = self.agent.initialize_xfer(
|
||||||
xfer_req := self.agent.initialize_xfer(
|
direction, host_descs, storage_descs, self.agent_name
|
||||||
direction, tensor_descs, file_descs, self.agent_name
|
|
||||||
)
|
)
|
||||||
) is None:
|
except Exception as e:
|
||||||
logger.error("Failed to create transfer request")
|
logger.error(f"Failed to create transfer request: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# Execute transfer and wait for its completion
|
||||||
|
try:
|
||||||
state = self.agent.transfer(xfer_req)
|
state = self.agent.transfer(xfer_req)
|
||||||
while state != "DONE":
|
while state != "DONE":
|
||||||
state = self.agent.check_xfer_state(xfer_req)
|
state = self.agent.check_xfer_state(xfer_req)
|
||||||
if state == "ERR":
|
if state == "ERR":
|
||||||
|
self.agent.release_xfer_handle(xfer_req)
|
||||||
logger.error("Transfer failed")
|
logger.error("Transfer failed")
|
||||||
return False
|
return False
|
||||||
time.sleep(0.0001) # Can be changed to os.sched_yield() or parametrized
|
time.sleep(0.0001) # Can be changed to os.sched_yield() or parametrized
|
||||||
|
|
||||||
|
self.agent.release_xfer_handle(xfer_req)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -106,46 +158,88 @@ class HiCacheNixl(HiCacheStorage):
|
|||||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
|
def get(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
target_location: Optional[torch.Tensor | int] = None,
|
||||||
|
target_sizes: Optional[int] = None,
|
||||||
|
) -> torch.Tensor | None:
|
||||||
|
# To be removed, being compatible with the current API
|
||||||
|
if target_location is None:
|
||||||
|
return None
|
||||||
|
if target_sizes:
|
||||||
|
result = self.batch_get([key], [target_location], [target_sizes])
|
||||||
|
else:
|
||||||
|
result = self.batch_get([key], [target_location])
|
||||||
|
return result[0] if result else None
|
||||||
|
|
||||||
|
def batch_get(
|
||||||
|
self,
|
||||||
|
keys: List[str],
|
||||||
|
target_locations: Optional[List[torch.Tensor | int]] = None,
|
||||||
|
target_sizes: Optional[List[int]] = None,
|
||||||
|
) -> List[torch.Tensor | None]:
|
||||||
if not keys:
|
if not keys:
|
||||||
return True
|
return []
|
||||||
|
|
||||||
|
# To be removed, being compatible with the current API
|
||||||
|
if not target_locations:
|
||||||
|
return [None] * len(keys)
|
||||||
|
|
||||||
|
if target_sizes and (len(target_sizes) != len(target_locations)):
|
||||||
|
logger.error("Mismatch between number of target_locations and target_sizes")
|
||||||
|
return [None] * len(keys)
|
||||||
|
if target_sizes:
|
||||||
|
dest = list(zip(target_locations, target_sizes))
|
||||||
|
else:
|
||||||
|
dest = target_locations
|
||||||
|
|
||||||
|
if self.backend_selector.mem_type == "FILE":
|
||||||
|
file_paths = [self.file_manager.get_file_path(key) for key in keys]
|
||||||
|
success = self._execute_transfer(dest, file_paths, "READ")
|
||||||
|
else:
|
||||||
|
success = self._execute_transfer(dest, keys, "READ")
|
||||||
|
return target_locations if success and not target_sizes else [None] * len(keys)
|
||||||
|
|
||||||
|
def set(
    self,
    key: str,
    value: Optional[torch.Tensor] = None,
    target_location: Optional[int] = None,
    target_sizes: Optional[int] = None,
) -> bool:
    """Store a single key by delegating to ``batch_set``.

    Args:
        key: Key to write.
        value: Source tensor (tensor mode).
        target_location: Source address in host memory (address mode).
        target_sizes: Length paired with ``target_location``.

    Returns:
        The ``batch_set`` result, or ``False`` when neither a tensor nor a
        complete address/length pair was supplied.
    """
    if target_location and target_sizes:
        return self.batch_set([key], None, [target_location], [target_sizes])

    # Without this guard ``[None]`` (a truthy list) would slip past the
    # emptiness check in ``batch_set`` and, on the FILE backend, create the
    # file before the transfer is rejected.
    if value is None:
        logger.error("Neither a value nor a target location/size pair was passed")
        return False

    return self.batch_set([key], [value])
|
||||||
|
|
||||||
|
def batch_set(
    self,
    keys: List[str],
    values: Optional[List[torch.Tensor]] = None,
    target_locations: Optional[List[int]] = None,
    target_sizes: Optional[List[int]] = None,
) -> bool:
    """Write several keys from tensors or from host-memory (addr, len) pairs.

    Args:
        keys: Keys to write.
        values: Source tensors (tensor mode).
        target_locations: Source addresses, used when ``values`` is empty.
        target_sizes: Lengths paired one-to-one with ``target_locations``.

    Returns:
        ``True`` when the transfer succeeds. ``False`` on missing or
        mismatched arguments, on file-creation failure, or on transfer failure.
    """
    if not keys or (not values and (not target_locations or not target_sizes)):
        logger.error("Keys or values were not passed")
        return False

    if not values:
        # zip() would silently drop the unmatched tail, so reject uneven input,
        # mirroring the check done on the read path.
        if len(target_locations) != len(target_sizes):
            logger.error("Mismatch between number of target_locations and target_sizes")
            return False
        values = list(zip(target_locations, target_sizes))

    # Validate before touching storage so a bad call does not leave
    # freshly created files behind.
    if len(values) != len(keys):
        logger.error("Mismatch between number of tensors/buffers and files/objects")
        return False

    if self.backend_selector.mem_type != "FILE":  # mem_type == "OBJ"
        return self._execute_transfer(values, keys, "WRITE")

    file_paths = []
    for key in keys:
        file_path = self.file_manager.get_file_path(key)
        # New file per set, to be updated when partial writes is added to HiCache
        if not self.file_manager.create_file(file_path):
            logger.error(f"Failed to create file {file_path}")
            return False
        file_paths.append(file_path)
    return self._execute_transfer(values, file_paths, "WRITE")
|
||||||
|
|
||||||
def set(self, key: str, value: torch.Tensor) -> bool:
|
|
||||||
return self.batch_set([key], [value])
|
|
||||||
|
|
||||||
def get(
|
|
||||||
self, key: str, dst_tensor: Optional[torch.Tensor] = None
|
|
||||||
) -> torch.Tensor | None:
|
|
||||||
if dst_tensor is None: # To be removed, being compatible with the current API
|
|
||||||
return None
|
|
||||||
result = self.batch_get([key], [dst_tensor])
|
|
||||||
return result[0] if result else None
|
|
||||||
|
|
||||||
def batch_get(
|
|
||||||
self, keys: List[str], dst_tensors: List[torch.Tensor]
|
|
||||||
) -> List[Optional[torch.Tensor]]:
|
|
||||||
if not keys:
|
|
||||||
return []
|
|
||||||
|
|
||||||
if self.backend_selector.mem_type == "FILE":
|
|
||||||
file_paths = [self.file_manager.get_file_path(key) for key in keys]
|
|
||||||
success = self._execute_transfer(dst_tensors, file_paths, "READ")
|
|
||||||
else:
|
|
||||||
success = self._execute_transfer(dst_tensors, keys, "READ")
|
|
||||||
return dst_tensors if success else [None] * len(keys)
|
|
||||||
|
|
||||||
def exists(self, key: str) -> bool:
|
def exists(self, key: str) -> bool:
|
||||||
tuples = self.registration.create_query_tuples(
|
tuples = self.registration.create_query_tuples(
|
||||||
key,
|
key,
|
||||||
|
|||||||
@@ -109,66 +109,35 @@ class NixlRegistration:
|
|||||||
return [(0, 0, key)]
|
return [(0, 0, key)]
|
||||||
|
|
||||||
def _register_memory(
|
def _register_memory(
|
||||||
self, items: Union[List[tuple], List[torch.Tensor]], mem_type: str, desc: str
|
self,
|
||||||
|
items: Union[List[tuple], torch.Tensor, List[torch.Tensor]],
|
||||||
|
mem_type: Optional[str] = None,
|
||||||
) -> Optional[Any]:
|
) -> Optional[Any]:
|
||||||
"""Common registration logic for files, objects, and buffers.
|
"""Common registration logic for files, objects, and buffers.
|
||||||
Args:
|
Args:
|
||||||
items: List of tuples or tensors to register
|
items: List of tuples or tensors to register
|
||||||
mem_type: Memory type ("FILE", "OBJ", "DRAM", "VRAM")
|
mem_type: Memory type ("FILE", "OBJ") or None for tensor or list of tensors
|
||||||
desc: Description for logging
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(items, list) and not items:
|
||||||
|
return None
|
||||||
|
|
||||||
|
reg_descs = self.agent.get_reg_descs(items, mem_type)
|
||||||
|
if reg_descs is None:
|
||||||
|
logger.error("Failed to create registration descriptors")
|
||||||
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not items:
|
|
||||||
return None
|
|
||||||
|
|
||||||
reg_descs = self.agent.get_reg_descs(items, mem_type)
|
|
||||||
if reg_descs is None:
|
|
||||||
logger.error("Failed to create registration descriptors")
|
|
||||||
return None
|
|
||||||
|
|
||||||
registered_memory = self.agent.register_memory(reg_descs)
|
registered_memory = self.agent.register_memory(reg_descs)
|
||||||
if registered_memory:
|
return registered_memory # Could be None in case of error
|
||||||
return registered_memory
|
|
||||||
else:
|
|
||||||
logger.error("Failed to register with NIXL")
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to register {desc}: {e}")
|
if not mem_type:
|
||||||
|
logger.error(f"Failed to register Tensors with NIXL: {e}")
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to register memory of type {mem_type} with NIXL: {e}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def register_buffers(
|
|
||||||
self, buffers: Union[torch.Tensor, List[torch.Tensor]]
|
|
||||||
) -> Optional[Any]:
|
|
||||||
"""Register tensors/buffers with NIXL."""
|
|
||||||
if isinstance(buffers, torch.Tensor):
|
|
||||||
buffers = [buffers]
|
|
||||||
|
|
||||||
if not buffers:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Determine memory type based on tensor device
|
|
||||||
mem_type = "VRAM" if buffers[0].device.type == "cuda" else "DRAM"
|
|
||||||
return self._register_memory(buffers, mem_type, "buffers")
|
|
||||||
|
|
||||||
def register_files(self, tuples: List[tuple]) -> Optional[Any]:
|
|
||||||
"""Register files with NIXL using (0, 0, fd, file_path) tuples."""
|
|
||||||
return self._register_memory(tuples, "FILE", "files")
|
|
||||||
|
|
||||||
def register_objects(
|
|
||||||
self, keys: List[str], tensors: Optional[List[torch.Tensor]] = None
|
|
||||||
) -> Optional[Any]:
|
|
||||||
"""Register objects with NIXL."""
|
|
||||||
if not keys:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Create object tuples with proper sizes
|
|
||||||
tuples = [
|
|
||||||
(0, tensor.element_size() * tensor.numel() if tensor else 0, key)
|
|
||||||
for key, tensor in zip(keys, tensors or [None] * len(keys))
|
|
||||||
]
|
|
||||||
return self._register_memory(tuples, "OBJ", "objects")
|
|
||||||
|
|
||||||
|
|
||||||
class NixlFileManager:
|
class NixlFileManager:
|
||||||
"""Handles file system operations for NIXL."""
|
"""Handles file system operations for NIXL."""
|
||||||
@@ -221,12 +190,9 @@ class NixlFileManager:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def files_to_nixl_tuples(
|
def files_to_nixl_tuples(
|
||||||
self, file_paths: List[str], open_file: bool = True
|
self, file_paths: List[str]
|
||||||
) -> List[Tuple[int, int, int, str]]:
|
) -> List[Tuple[int, int, int, str]]:
|
||||||
"""Create NIXL tuples (offset, length, fd, file_path) for given files."""
|
"""Create NIXL tuples (offset, length, fd, file_path) for given files."""
|
||||||
if not open_file:
|
|
||||||
return [(0, 0, 0, path) for path in file_paths]
|
|
||||||
|
|
||||||
tuples = []
|
tuples = []
|
||||||
for path in file_paths:
|
for path in file_paths:
|
||||||
if (fd := self.open_file(path)) is None:
|
if (fd := self.open_file(path)) is None:
|
||||||
|
|||||||
@@ -7,8 +7,11 @@ from unittest.mock import MagicMock
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
|
from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
|
||||||
from sglang.srt.mem_cache.nixl.nixl_utils import NixlFileManager, NixlRegistration
|
from sglang.srt.mem_cache.storage.nixl.nixl_utils import (
|
||||||
|
NixlFileManager,
|
||||||
|
NixlRegistration,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestNixlUnified(unittest.TestCase):
|
class TestNixlUnified(unittest.TestCase):
|
||||||
@@ -88,8 +91,27 @@ class TestNixlUnified(unittest.TestCase):
|
|||||||
|
|
||||||
# Test get
|
# Test get
|
||||||
retrieved = self.hicache.get(key, dst_tensor)
|
retrieved = self.hicache.get(key, dst_tensor)
|
||||||
|
self.verify_tensors_equal(value, dst_tensor)
|
||||||
self.verify_tensors_equal(value, retrieved)
|
self.verify_tensors_equal(value, retrieved)
|
||||||
|
|
||||||
|
# Same test in addr,len mode with another key and dst_tensor
|
||||||
|
key2 = "test_key2"
|
||||||
|
dst_tensor2 = torch.zeros_like(value, device="cpu")
|
||||||
|
src_addr, src_len = value.data_ptr(), value.numel() * value.element_size()
|
||||||
|
dst_addr, dst_len = (
|
||||||
|
dst_tensor2.data_ptr(),
|
||||||
|
dst_tensor2.numel() * dst_tensor2.element_size(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test set
|
||||||
|
self.assertTrue(self.hicache.set(key, None, src_addr, src_len))
|
||||||
|
self.assertTrue(self.hicache.exists(key))
|
||||||
|
|
||||||
|
# Test get
|
||||||
|
retrieved2 = self.hicache.get(key, dst_addr, dst_len)
|
||||||
|
self.assertTrue(retrieved2 == None)
|
||||||
|
self.verify_tensors_equal(value, dst_tensor2)
|
||||||
|
|
||||||
def test_batch_set_get(self):
|
def test_batch_set_get(self):
|
||||||
"""Test batch tensor set/get operations."""
|
"""Test batch tensor set/get operations."""
|
||||||
keys = ["key1", "key2", "key3"]
|
keys = ["key1", "key2", "key3"]
|
||||||
@@ -108,6 +130,23 @@ class TestNixlUnified(unittest.TestCase):
|
|||||||
retrieved = self.hicache.batch_get(keys, dst_tensors)
|
retrieved = self.hicache.batch_get(keys, dst_tensors)
|
||||||
self.verify_tensor_lists_equal(values, retrieved)
|
self.verify_tensor_lists_equal(values, retrieved)
|
||||||
|
|
||||||
|
# Same test in addr,len mode with another key and dst_tensor
|
||||||
|
keys2 = ["key4", "key5", "key6"]
|
||||||
|
dst_tensors2 = [torch.zeros_like(v, device="cpu") for v in values]
|
||||||
|
src_addrs = [v.data_ptr() for v in values]
|
||||||
|
src_lens = [v.numel() * v.element_size() for v in values]
|
||||||
|
dst_addrs = [dt.data_ptr() for dt in dst_tensors2]
|
||||||
|
dst_lens = [dt.numel() * dt.element_size() for dt in dst_tensors2]
|
||||||
|
|
||||||
|
# Test batch set
|
||||||
|
self.assertTrue(self.hicache.batch_set(keys2, None, src_addrs, src_lens))
|
||||||
|
self.assertTrue(all(self.hicache.exists(key) for key in keys2))
|
||||||
|
|
||||||
|
# Test batch get
|
||||||
|
retrieved2 = self.hicache.batch_get(keys, dst_addrs, dst_lens)
|
||||||
|
self.assertTrue(all(ret == None for ret in retrieved2))
|
||||||
|
self.verify_tensor_lists_equal(values, dst_tensors2)
|
||||||
|
|
||||||
def test_mixed_operations(self):
|
def test_mixed_operations(self):
|
||||||
"""Test mixing single and batch operations."""
|
"""Test mixing single and batch operations."""
|
||||||
# Test interleaved set/get operations
|
# Test interleaved set/get operations
|
||||||
@@ -170,7 +209,7 @@ class TestNixlUnified(unittest.TestCase):
|
|||||||
self.file_manager.create_file(test_file)
|
self.file_manager.create_file(test_file)
|
||||||
|
|
||||||
# Test tuple creation
|
# Test tuple creation
|
||||||
tuples = self.file_manager.files_to_nixl_tuples([test_file], False)
|
tuples = self.file_manager.files_to_nixl_tuples([test_file])
|
||||||
self.assertIsNotNone(tuples)
|
self.assertIsNotNone(tuples)
|
||||||
self.assertTrue(len(tuples) > 0)
|
self.assertTrue(len(tuples) > 0)
|
||||||
|
|
||||||
@@ -190,11 +229,11 @@ class TestNixlUnified(unittest.TestCase):
|
|||||||
tensor = torch.randn(10, 10)
|
tensor = torch.randn(10, 10)
|
||||||
|
|
||||||
# Test buffer registration
|
# Test buffer registration
|
||||||
self.assertIsNotNone(self.registration.register_buffers(tensor))
|
self.assertIsNotNone(self.hicache.register_buffers(tensor))
|
||||||
|
|
||||||
# Test batch registration
|
# Test batch registration
|
||||||
tensors = [torch.randn(5, 5) for _ in range(3)]
|
tensors = [torch.randn(5, 5) for _ in range(3)]
|
||||||
self.assertIsNotNone(self.registration.register_buffers(tensors))
|
self.assertIsNotNone(self.hicache.register_buffers(tensors))
|
||||||
|
|
||||||
def test_register_files_with_tuples(self):
|
def test_register_files_with_tuples(self):
|
||||||
"""Test registration of files using NIXL tuples."""
|
"""Test registration of files using NIXL tuples."""
|
||||||
@@ -203,8 +242,8 @@ class TestNixlUnified(unittest.TestCase):
|
|||||||
self.file_manager.create_file(file)
|
self.file_manager.create_file(file)
|
||||||
|
|
||||||
# Create tuples and register
|
# Create tuples and register
|
||||||
tuples = self.file_manager.files_to_nixl_tuples(files, False)
|
tuples = self.file_manager.files_to_nixl_tuples(files)
|
||||||
self.registration.register_files(tuples)
|
self.hicache.register_files(tuples)
|
||||||
|
|
||||||
# Verify tuples
|
# Verify tuples
|
||||||
self.assertEqual(len(tuples), len(files))
|
self.assertEqual(len(tuples), len(files))
|
||||||
|
|||||||
Reference in New Issue
Block a user