Add support for extensions of interface and pre-registrations to NIXL HiCache (#9211)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -36,6 +36,21 @@ Consolidated utility classes:
|
||||
- **NixlRegistration** - Manages memory registration for tensors, files and objects
|
||||
- **NixlFileManager** - Handles file system operations and NIXL tuple creation
|
||||
|
||||
## Using NIXL for HiCache backend
|
||||
When running the SGLang server, indicate `nixl` for `hicache-storage-backend` parameter, for instance:
|
||||
|
||||
```bash
|
||||
python3 -m sglang.launch_server --model-path <model> --host <ip> --port <port> --page-size 64 --enable-hierarchical-cache --hicache-ratio 2 --hicache-size 64 --hicache-write-policy write_through --hicache-storage-backend nixl
|
||||
```
|
||||
|
||||
To customize the base directory for files, you can set the following environment variable:
|
||||
|
||||
```bash
|
||||
export SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR=/path/to/desired/dir
|
||||
```
|
||||
|
||||
Selection of any storage backend like 3FS requires availability of that library on the system, and the backend is selected based on the priority mentioned above.
|
||||
|
||||
## Running Unit Tests
|
||||
|
||||
### Prerequisites
|
||||
@@ -43,33 +58,26 @@ Consolidated utility classes:
|
||||
- PyTorch installed
|
||||
- Python 3.8+
|
||||
|
||||
### Unit tests from Project root
|
||||
Navigate to the project root directory (`/path/to/sglang`) and run:
|
||||
### Unit tests from current directory
|
||||
From the current directory run:
|
||||
|
||||
#### Run all NIXL tests:
|
||||
```bash
|
||||
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -o asyncio_mode=strict
|
||||
PYTHONPATH=. python -m pytest test_hicache_nixl_storage.py -o asyncio_mode=strict
|
||||
```
|
||||
|
||||
#### Run with verbose output:
|
||||
```bash
|
||||
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -v -o asyncio_mode=strict
|
||||
PYTHONPATH=. python -m pytest test_hicache_nixl_storage.py -v -o asyncio_mode=strict
|
||||
```
|
||||
|
||||
Note: The `-v` flag provides more detailed output, showing each test case name and its result.
|
||||
|
||||
#### Run a specific test:
|
||||
```bash
|
||||
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -v -k test_single_set_get -o asyncio_mode=strict
|
||||
PYTHONPATH=. python -m pytest test_hicache_nixl_storage.py -v -k test_single_set_get -o asyncio_mode=strict
|
||||
```
|
||||
|
||||
### From Tests Directory
|
||||
Navigate to the tests directory and run:
|
||||
|
||||
```bash
|
||||
cd test/srt
|
||||
PYTHONPATH=../.. python -m pytest test_hicache_nixl_storage.py -o asyncio_mode=strict
|
||||
```
|
||||
Note: The `-o asyncio_mode=strict` flag is added to suppress warnings about asyncio configuration. This is not required for test functionality but provides cleaner output.
|
||||
|
||||
## Test Coverage
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -28,6 +28,8 @@ class HiCacheNixl(HiCacheStorage):
|
||||
|
||||
def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"):
|
||||
"""Initialize NIXL storage connector."""
|
||||
# Might be better to be unified across HiCache backends and moved to HiCacheController
|
||||
file_path = os.getenv("SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR", file_path)
|
||||
self.file_manager = (
|
||||
NixlFileManager(file_path)
|
||||
if plugin not in NixlBackendSelection.OBJ_PLUGINS
|
||||
@@ -44,59 +46,109 @@ class HiCacheNixl(HiCacheStorage):
|
||||
|
||||
self.registration = NixlRegistration(self.agent)
|
||||
|
||||
def register_buffers(
|
||||
self, buffers: Union[torch.Tensor, List[torch.Tensor], List[tuple]]
|
||||
) -> Optional[Any]:
|
||||
"""Register tensor(s) or target locations in host memory (list of addr,len tuples) with NIXL."""
|
||||
if isinstance(buffers[0], tuple):
|
||||
tuples = [(x[0], x[1], 0, "") for x in buffers]
|
||||
return self.registration._register_memory(tuples, "DRAM")
|
||||
else:
|
||||
return self.registration._register_memory(buffers)
|
||||
|
||||
def register_files(
|
||||
self, file_paths: List[str], open_file: Optional[bool] = True
|
||||
) -> Optional[Any]:
|
||||
"""Register files with NIXL."""
|
||||
tuples = self.file_manager.files_to_nixl_tuples(file_paths)
|
||||
return self.registration._register_memory(tuples, "FILE")
|
||||
|
||||
def register_objects(
|
||||
self, keys: List[str], sizes: Optional[List[int]] = None
|
||||
) -> Optional[Any]:
|
||||
"""Register objects with NIXL."""
|
||||
if not keys:
|
||||
return None
|
||||
tuples = [(0, 0, key, "") for key in keys]
|
||||
return self.registration._register_memory(tuples, "OBJ")
|
||||
|
||||
def _execute_transfer(
|
||||
self, tensors: List[torch.Tensor], keys: List[str], direction: str
|
||||
self,
|
||||
buffers: Optional[List[torch.Tensor | tuple]],
|
||||
keys: List[str],
|
||||
direction: str,
|
||||
) -> bool:
|
||||
if len(tensors) != len(keys):
|
||||
logger.error("Mismatch between number of tensors and files/objects")
|
||||
if len(buffers) != len(keys):
|
||||
logger.error("Mismatch between number of tensors/buffers and files/objects")
|
||||
return False
|
||||
|
||||
if not self.registration.register_buffers(tensors):
|
||||
logger.error("Failed to register tensors")
|
||||
return False
|
||||
|
||||
# Get transfer tuples based on backend type
|
||||
tensor_sizes = [tensor.element_size() * tensor.numel() for tensor in tensors]
|
||||
# Registering file and object keys per transfer, to be updated when
|
||||
# pre-registration for file and object is added to HiCache.
|
||||
if self.backend_selector.mem_type == "FILE":
|
||||
file_tuples = self.file_manager.files_to_nixl_tuples(keys)
|
||||
if not file_tuples or not self.registration.register_files(file_tuples):
|
||||
tuples = self.file_manager.files_to_nixl_tuples(keys)
|
||||
if not tuples or not self.registration._register_memory(tuples, "FILE"):
|
||||
logger.error("Failed to prepare files for transfer")
|
||||
return False
|
||||
transfer_tuples = [
|
||||
(x[0], s, x[2]) for x, s in zip(file_tuples, tensor_sizes)
|
||||
]
|
||||
else:
|
||||
if not self.registration.register_objects(keys, tensors):
|
||||
else: # mem_type == "OBJ"
|
||||
tuples = [(0, 0, key, "") for key in keys]
|
||||
if not tuples or not self.registration._register_memory(tuples, "OBJ"):
|
||||
logger.error("Failed to register objects")
|
||||
return False
|
||||
transfer_tuples = [(0, s, key) for s, key in zip(tensor_sizes, keys)]
|
||||
|
||||
# Prepare transfer descriptors
|
||||
if isinstance(buffers[0], torch.Tensor):
|
||||
tensor_sizes = [
|
||||
tensor.element_size() * tensor.numel() for tensor in buffers
|
||||
]
|
||||
storage_tuples = [(x[0], s, x[2]) for x, s in zip(tuples, tensor_sizes)]
|
||||
host_descs = self.agent.get_xfer_descs(buffers)
|
||||
elif isinstance(buffers[0], tuple):
|
||||
storage_tuples = [(x[0], y[1], x[2]) for x, y in zip(tuples, buffers)]
|
||||
host_descs = self.agent.get_xfer_descs(
|
||||
[(x[0], x[1], 0) for x in buffers], "DRAM"
|
||||
)
|
||||
else:
|
||||
return False
|
||||
|
||||
storage_descs = self.agent.get_xfer_descs(
|
||||
storage_tuples, self.backend_selector.mem_type
|
||||
)
|
||||
|
||||
if (host_descs is None) or (storage_descs is None):
|
||||
logger.error("Failed to get transfer descriptors")
|
||||
return False
|
||||
|
||||
# Initialize transfer, default assumption that tensor was registered
|
||||
try:
|
||||
# Get transfer descriptors
|
||||
if (tensor_descs := self.agent.get_xfer_descs(tensors)) is None or (
|
||||
file_descs := self.agent.get_xfer_descs(
|
||||
transfer_tuples, self.backend_selector.mem_type
|
||||
)
|
||||
) is None:
|
||||
logger.error("Failed to get transfer descriptors")
|
||||
xfer_req = self.agent.initialize_xfer(
|
||||
direction, host_descs, storage_descs, self.agent_name
|
||||
)
|
||||
except Exception:
|
||||
# Check if it was due to missing pre-registration
|
||||
if not self.register_buffers(buffers):
|
||||
logger.error("Failed to register tensors/buffers")
|
||||
return False
|
||||
|
||||
# Initialize and execute transfer
|
||||
if (
|
||||
xfer_req := self.agent.initialize_xfer(
|
||||
direction, tensor_descs, file_descs, self.agent_name
|
||||
try:
|
||||
xfer_req = self.agent.initialize_xfer(
|
||||
direction, host_descs, storage_descs, self.agent_name
|
||||
)
|
||||
) is None:
|
||||
logger.error("Failed to create transfer request")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create transfer request: {e}")
|
||||
return False
|
||||
|
||||
# Execute transfer and wait for its completion
|
||||
try:
|
||||
state = self.agent.transfer(xfer_req)
|
||||
while state != "DONE":
|
||||
state = self.agent.check_xfer_state(xfer_req)
|
||||
if state == "ERR":
|
||||
self.agent.release_xfer_handle(xfer_req)
|
||||
logger.error("Transfer failed")
|
||||
return False
|
||||
time.sleep(0.0001) # Can be changed to os.sched_yield() or parametrized
|
||||
time.sleep(0.0001) # Can be changed to os.sched_yield() or parametrized
|
||||
|
||||
self.agent.release_xfer_handle(xfer_req)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -106,46 +158,88 @@ class HiCacheNixl(HiCacheStorage):
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
target_location: Optional[torch.Tensor | int] = None,
|
||||
target_sizes: Optional[int] = None,
|
||||
) -> torch.Tensor | None:
|
||||
# To be removed, being compatible with the current API
|
||||
if target_location is None:
|
||||
return None
|
||||
if target_sizes:
|
||||
result = self.batch_get([key], [target_location], [target_sizes])
|
||||
else:
|
||||
result = self.batch_get([key], [target_location])
|
||||
return result[0] if result else None
|
||||
|
||||
def batch_get(
|
||||
self,
|
||||
keys: List[str],
|
||||
target_locations: Optional[List[torch.Tensor | int]] = None,
|
||||
target_sizes: Optional[List[int]] = None,
|
||||
) -> List[torch.Tensor | None]:
|
||||
if not keys:
|
||||
return True
|
||||
return []
|
||||
|
||||
# To be removed, being compatible with the current API
|
||||
if not target_locations:
|
||||
return [None] * len(keys)
|
||||
|
||||
if target_sizes and (len(target_sizes) != len(target_locations)):
|
||||
logger.error("Mismatch between number of target_locations and target_sizes")
|
||||
return [None] * len(keys)
|
||||
if target_sizes:
|
||||
dest = list(zip(target_locations, target_sizes))
|
||||
else:
|
||||
dest = target_locations
|
||||
|
||||
if self.backend_selector.mem_type == "FILE":
|
||||
file_paths = [self.file_manager.get_file_path(key) for key in keys]
|
||||
success = self._execute_transfer(dest, file_paths, "READ")
|
||||
else:
|
||||
success = self._execute_transfer(dest, keys, "READ")
|
||||
return target_locations if success and not target_sizes else [None] * len(keys)
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Optional[torch.Tensor] = None,
|
||||
target_location: Optional[int] = None,
|
||||
target_sizes: Optional[int] = None,
|
||||
) -> bool:
|
||||
if target_location and target_sizes:
|
||||
return self.batch_set([key], None, [target_location], [target_sizes])
|
||||
else:
|
||||
return self.batch_set([key], [value])
|
||||
|
||||
def batch_set(
|
||||
self,
|
||||
keys: List[str],
|
||||
values: Optional[List[torch.Tensor]] = None,
|
||||
target_locations: Optional[List[int]] = None,
|
||||
target_sizes: Optional[List[int]] = None,
|
||||
) -> bool:
|
||||
if not keys or (not values and (not target_locations or not target_sizes)):
|
||||
logger.error("Keys or values were not passed")
|
||||
return False
|
||||
|
||||
if not values:
|
||||
values = list(zip(target_locations, target_sizes))
|
||||
|
||||
if self.backend_selector.mem_type == "FILE":
|
||||
file_paths = []
|
||||
for key in keys:
|
||||
tensor_path = self.file_manager.get_file_path(key)
|
||||
if not self.file_manager.create_file(tensor_path):
|
||||
logger.error(f"Failed to create file {tensor_path}")
|
||||
file_path = self.file_manager.get_file_path(key)
|
||||
# New file per set, to be updated when partial writes is added to HiCache
|
||||
if not self.file_manager.create_file(file_path):
|
||||
logger.error(f"Failed to create file {file_path}")
|
||||
return False
|
||||
file_paths.append(tensor_path)
|
||||
file_paths.append(file_path)
|
||||
return self._execute_transfer(values, file_paths, "WRITE")
|
||||
else:
|
||||
else: # mem_type == "OBJ"
|
||||
return self._execute_transfer(values, keys, "WRITE")
|
||||
|
||||
def set(self, key: str, value: torch.Tensor) -> bool:
|
||||
return self.batch_set([key], [value])
|
||||
|
||||
def get(
|
||||
self, key: str, dst_tensor: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor | None:
|
||||
if dst_tensor is None: # To be removed, being compatible with the current API
|
||||
return None
|
||||
result = self.batch_get([key], [dst_tensor])
|
||||
return result[0] if result else None
|
||||
|
||||
def batch_get(
|
||||
self, keys: List[str], dst_tensors: List[torch.Tensor]
|
||||
) -> List[Optional[torch.Tensor]]:
|
||||
if not keys:
|
||||
return []
|
||||
|
||||
if self.backend_selector.mem_type == "FILE":
|
||||
file_paths = [self.file_manager.get_file_path(key) for key in keys]
|
||||
success = self._execute_transfer(dst_tensors, file_paths, "READ")
|
||||
else:
|
||||
success = self._execute_transfer(dst_tensors, keys, "READ")
|
||||
return dst_tensors if success else [None] * len(keys)
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
tuples = self.registration.create_query_tuples(
|
||||
key,
|
||||
|
||||
@@ -109,66 +109,35 @@ class NixlRegistration:
|
||||
return [(0, 0, key)]
|
||||
|
||||
def _register_memory(
|
||||
self, items: Union[List[tuple], List[torch.Tensor]], mem_type: str, desc: str
|
||||
self,
|
||||
items: Union[List[tuple], torch.Tensor, List[torch.Tensor]],
|
||||
mem_type: Optional[str] = None,
|
||||
) -> Optional[Any]:
|
||||
"""Common registration logic for files, objects, and buffers.
|
||||
Args:
|
||||
items: List of tuples or tensors to register
|
||||
mem_type: Memory type ("FILE", "OBJ", "DRAM", "VRAM")
|
||||
desc: Description for logging
|
||||
mem_type: Memory type ("FILE", "OBJ") or None for tensor or list of tensors
|
||||
"""
|
||||
if isinstance(items, list) and not items:
|
||||
return None
|
||||
|
||||
reg_descs = self.agent.get_reg_descs(items, mem_type)
|
||||
if reg_descs is None:
|
||||
logger.error("Failed to create registration descriptors")
|
||||
return None
|
||||
|
||||
try:
|
||||
if not items:
|
||||
return None
|
||||
|
||||
reg_descs = self.agent.get_reg_descs(items, mem_type)
|
||||
if reg_descs is None:
|
||||
logger.error("Failed to create registration descriptors")
|
||||
return None
|
||||
|
||||
registered_memory = self.agent.register_memory(reg_descs)
|
||||
if registered_memory:
|
||||
return registered_memory
|
||||
else:
|
||||
logger.error("Failed to register with NIXL")
|
||||
return None
|
||||
|
||||
return registered_memory # Could be None in case of error
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register {desc}: {e}")
|
||||
if not mem_type:
|
||||
logger.error(f"Failed to register Tensors with NIXL: {e}")
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to register memory of type {mem_type} with NIXL: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def register_buffers(
|
||||
self, buffers: Union[torch.Tensor, List[torch.Tensor]]
|
||||
) -> Optional[Any]:
|
||||
"""Register tensors/buffers with NIXL."""
|
||||
if isinstance(buffers, torch.Tensor):
|
||||
buffers = [buffers]
|
||||
|
||||
if not buffers:
|
||||
return None
|
||||
|
||||
# Determine memory type based on tensor device
|
||||
mem_type = "VRAM" if buffers[0].device.type == "cuda" else "DRAM"
|
||||
return self._register_memory(buffers, mem_type, "buffers")
|
||||
|
||||
def register_files(self, tuples: List[tuple]) -> Optional[Any]:
|
||||
"""Register files with NIXL using (0, 0, fd, file_path) tuples."""
|
||||
return self._register_memory(tuples, "FILE", "files")
|
||||
|
||||
def register_objects(
|
||||
self, keys: List[str], tensors: Optional[List[torch.Tensor]] = None
|
||||
) -> Optional[Any]:
|
||||
"""Register objects with NIXL."""
|
||||
if not keys:
|
||||
return None
|
||||
|
||||
# Create object tuples with proper sizes
|
||||
tuples = [
|
||||
(0, tensor.element_size() * tensor.numel() if tensor else 0, key)
|
||||
for key, tensor in zip(keys, tensors or [None] * len(keys))
|
||||
]
|
||||
return self._register_memory(tuples, "OBJ", "objects")
|
||||
|
||||
|
||||
class NixlFileManager:
|
||||
"""Handles file system operations for NIXL."""
|
||||
@@ -221,12 +190,9 @@ class NixlFileManager:
|
||||
return False
|
||||
|
||||
def files_to_nixl_tuples(
|
||||
self, file_paths: List[str], open_file: bool = True
|
||||
self, file_paths: List[str]
|
||||
) -> List[Tuple[int, int, int, str]]:
|
||||
"""Create NIXL tuples (offset, length, fd, file_path) for given files."""
|
||||
if not open_file:
|
||||
return [(0, 0, 0, path) for path in file_paths]
|
||||
|
||||
tuples = []
|
||||
for path in file_paths:
|
||||
if (fd := self.open_file(path)) is None:
|
||||
|
||||
@@ -7,8 +7,11 @@ from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
|
||||
from sglang.srt.mem_cache.nixl.nixl_utils import NixlFileManager, NixlRegistration
|
||||
from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
|
||||
from sglang.srt.mem_cache.storage.nixl.nixl_utils import (
|
||||
NixlFileManager,
|
||||
NixlRegistration,
|
||||
)
|
||||
|
||||
|
||||
class TestNixlUnified(unittest.TestCase):
|
||||
@@ -88,8 +91,27 @@ class TestNixlUnified(unittest.TestCase):
|
||||
|
||||
# Test get
|
||||
retrieved = self.hicache.get(key, dst_tensor)
|
||||
self.verify_tensors_equal(value, dst_tensor)
|
||||
self.verify_tensors_equal(value, retrieved)
|
||||
|
||||
# Same test in addr,len mode with another key and dst_tensor
|
||||
key2 = "test_key2"
|
||||
dst_tensor2 = torch.zeros_like(value, device="cpu")
|
||||
src_addr, src_len = value.data_ptr(), value.numel() * value.element_size()
|
||||
dst_addr, dst_len = (
|
||||
dst_tensor2.data_ptr(),
|
||||
dst_tensor2.numel() * dst_tensor2.element_size(),
|
||||
)
|
||||
|
||||
# Test set
|
||||
self.assertTrue(self.hicache.set(key, None, src_addr, src_len))
|
||||
self.assertTrue(self.hicache.exists(key))
|
||||
|
||||
# Test get
|
||||
retrieved2 = self.hicache.get(key, dst_addr, dst_len)
|
||||
self.assertTrue(retrieved2 == None)
|
||||
self.verify_tensors_equal(value, dst_tensor2)
|
||||
|
||||
def test_batch_set_get(self):
|
||||
"""Test batch tensor set/get operations."""
|
||||
keys = ["key1", "key2", "key3"]
|
||||
@@ -108,6 +130,23 @@ class TestNixlUnified(unittest.TestCase):
|
||||
retrieved = self.hicache.batch_get(keys, dst_tensors)
|
||||
self.verify_tensor_lists_equal(values, retrieved)
|
||||
|
||||
# Same test in addr,len mode with another key and dst_tensor
|
||||
keys2 = ["key4", "key5", "key6"]
|
||||
dst_tensors2 = [torch.zeros_like(v, device="cpu") for v in values]
|
||||
src_addrs = [v.data_ptr() for v in values]
|
||||
src_lens = [v.numel() * v.element_size() for v in values]
|
||||
dst_addrs = [dt.data_ptr() for dt in dst_tensors2]
|
||||
dst_lens = [dt.numel() * dt.element_size() for dt in dst_tensors2]
|
||||
|
||||
# Test batch set
|
||||
self.assertTrue(self.hicache.batch_set(keys2, None, src_addrs, src_lens))
|
||||
self.assertTrue(all(self.hicache.exists(key) for key in keys2))
|
||||
|
||||
# Test batch get
|
||||
retrieved2 = self.hicache.batch_get(keys, dst_addrs, dst_lens)
|
||||
self.assertTrue(all(ret == None for ret in retrieved2))
|
||||
self.verify_tensor_lists_equal(values, dst_tensors2)
|
||||
|
||||
def test_mixed_operations(self):
|
||||
"""Test mixing single and batch operations."""
|
||||
# Test interleaved set/get operations
|
||||
@@ -170,7 +209,7 @@ class TestNixlUnified(unittest.TestCase):
|
||||
self.file_manager.create_file(test_file)
|
||||
|
||||
# Test tuple creation
|
||||
tuples = self.file_manager.files_to_nixl_tuples([test_file], False)
|
||||
tuples = self.file_manager.files_to_nixl_tuples([test_file])
|
||||
self.assertIsNotNone(tuples)
|
||||
self.assertTrue(len(tuples) > 0)
|
||||
|
||||
@@ -190,11 +229,11 @@ class TestNixlUnified(unittest.TestCase):
|
||||
tensor = torch.randn(10, 10)
|
||||
|
||||
# Test buffer registration
|
||||
self.assertIsNotNone(self.registration.register_buffers(tensor))
|
||||
self.assertIsNotNone(self.hicache.register_buffers(tensor))
|
||||
|
||||
# Test batch registration
|
||||
tensors = [torch.randn(5, 5) for _ in range(3)]
|
||||
self.assertIsNotNone(self.registration.register_buffers(tensors))
|
||||
self.assertIsNotNone(self.hicache.register_buffers(tensors))
|
||||
|
||||
def test_register_files_with_tuples(self):
|
||||
"""Test registration of files using NIXL tuples."""
|
||||
@@ -203,8 +242,8 @@ class TestNixlUnified(unittest.TestCase):
|
||||
self.file_manager.create_file(file)
|
||||
|
||||
# Create tuples and register
|
||||
tuples = self.file_manager.files_to_nixl_tuples(files, False)
|
||||
self.registration.register_files(tuples)
|
||||
tuples = self.file_manager.files_to_nixl_tuples(files)
|
||||
self.hicache.register_files(tuples)
|
||||
|
||||
# Verify tuples
|
||||
self.assertEqual(len(tuples), len(files))
|
||||
|
||||
Reference in New Issue
Block a user