From 7e880286b5d1cadf08bb0a1528fc93b5e54a7a84 Mon Sep 17 00:00:00 2001
From: Moein Khazraee <33970824+mkhazraee@users.noreply.github.com>
Date: Fri, 22 Aug 2025 20:06:13 -0700
Subject: [PATCH] Add support for extensions of interface and pre-registrations
 to NIXL HiCache (#9211)

Co-authored-by: Zhiqiang Xie
---
 .../srt/mem_cache/storage/nixl/README.md      |  32 ++-
 .../mem_cache/storage/nixl/hicache_nixl.py    | 220 +++++++++++++-----
 .../srt/mem_cache/storage/nixl/nixl_utils.py  |  74 ++----
 .../storage/nixl/test_hicache_nixl_storage.py |  53 ++++-
 4 files changed, 243 insertions(+), 136 deletions(-)

diff --git a/python/sglang/srt/mem_cache/storage/nixl/README.md b/python/sglang/srt/mem_cache/storage/nixl/README.md
index b00e0774e..d33cd5d05 100644
--- a/python/sglang/srt/mem_cache/storage/nixl/README.md
+++ b/python/sglang/srt/mem_cache/storage/nixl/README.md
@@ -36,6 +36,21 @@ Consolidated utility classes:
 - **NixlRegistration** - Manages memory registration for tensors, files and objects
 - **NixlFileManager** - Handles file system operations and NIXL tuple creation
 
+## Using NIXL for HiCache backend
+When running the SGLang server, pass `nixl` as the `--hicache-storage-backend` parameter, for instance:
+
+```bash
+python3 -m sglang.launch_server --model-path <model> --host <host> --port <port> --page-size 64 --enable-hierarchical-cache --hicache-ratio 2 --hicache-size 64 --hicache-write-policy write_through --hicache-storage-backend nixl
+```
+
+To customize the base directory for files, set the following environment variable:
+
+```bash
+export SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR=/path/to/desired/dir
+```
+
+Selecting a storage plugin such as 3FS requires that plugin's library to be available on the system; NIXL picks the plugin according to the priority order mentioned above.
+
 ## Running Unit Tests
 
 ### Prerequisites
@@ -43,33 +58,26 @@ Consolidated utility classes:
 - PyTorch installed
 - Python 3.8+
 
-### Unit tests from Project root
-Navigate to the project root directory (`/path/to/sglang`) and run:
+### Unit tests from the current directory
+From the current directory, run:
 
 #### Run all NIXL tests:
 ```bash
-PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -o asyncio_mode=strict
+PYTHONPATH=. python -m pytest test_hicache_nixl_storage.py -o asyncio_mode=strict
 ```
 
 #### Run with verbose output:
 ```bash
-PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -v -o asyncio_mode=strict
+PYTHONPATH=. python -m pytest test_hicache_nixl_storage.py -v -o asyncio_mode=strict
 ```
 
 Note: The `-v` flag provides more detailed output, showing each test case name and its result.
 
 #### Run a specific test:
 ```bash
-PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -v -k test_single_set_get -o asyncio_mode=strict
+PYTHONPATH=. python -m pytest test_hicache_nixl_storage.py -v -k test_single_set_get -o asyncio_mode=strict
 ```
 
-### From Tests Directory
-Navigate to the tests directory and run:
-
-```bash
-cd test/srt
-PYTHONPATH=../.. python -m pytest test_hicache_nixl_storage.py -o asyncio_mode=strict
-```
 Note: The `-o asyncio_mode=strict` flag is added to suppress warnings about asyncio configuration. It is not required for the tests to pass but gives cleaner output.
 
 ## Test Coverage
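Before the connector diff below, a short illustration of what the interface extension enables: `HiCacheNixl.set`/`get` now accept raw host locations as `(addr, len)` pairs in addition to tensors. This is a minimal sketch of both call styles, assuming a working `nixl` installation and the module paths introduced in this patch; keys and shapes are illustrative:

```python
import torch

from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl

# Standalone use for illustration only; the server normally drives this
# through HiCacheController when --hicache-storage-backend nixl is set.
storage = HiCacheNixl(file_path="/tmp/hicache_storage", plugin="auto")

value = torch.randn(64, 64)
dst = torch.zeros_like(value)

# Tensor interface (existing behavior).
assert storage.set("demo_key", value)
storage.get("demo_key", target_location=dst)

# Extended (addr, len) interface added by this patch: raw host pointers.
addr, length = value.data_ptr(), value.numel() * value.element_size()
assert storage.set("demo_key2", None, target_location=addr, target_sizes=length)
storage.get("demo_key2", dst.data_ptr(), dst.numel() * dst.element_size())
```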
diff --git a/python/sglang/srt/mem_cache/storage/nixl/hicache_nixl.py b/python/sglang/srt/mem_cache/storage/nixl/hicache_nixl.py
index 35d8ec38a..327c90502 100644
--- a/python/sglang/srt/mem_cache/storage/nixl/hicache_nixl.py
+++ b/python/sglang/srt/mem_cache/storage/nixl/hicache_nixl.py
@@ -3,7 +3,7 @@ import logging
 import os
 import time
 import uuid
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 
@@ -28,6 +28,8 @@ class HiCacheNixl(HiCacheStorage):
 
     def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"):
         """Initialize NIXL storage connector."""
+        # Might be better to unify this across HiCache backends and move it to HiCacheController
+        file_path = os.getenv("SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR", file_path)
         self.file_manager = (
             NixlFileManager(file_path)
             if plugin not in NixlBackendSelection.OBJ_PLUGINS
@@ -44,59 +46,109 @@ class HiCacheNixl(HiCacheStorage):
 
         self.registration = NixlRegistration(self.agent)
 
+    def register_buffers(
+        self, buffers: Union[torch.Tensor, List[torch.Tensor], List[tuple]]
+    ) -> Optional[Any]:
+        """Register tensor(s) or target locations in host memory, given as a list of (addr, len) tuples, with NIXL."""
+        if isinstance(buffers, list) and not buffers:
+            return None
+        if isinstance(buffers[0], tuple):
+            tuples = [(x[0], x[1], 0, "") for x in buffers]
+            return self.registration._register_memory(tuples, "DRAM")
+        else:
+            return self.registration._register_memory(buffers)
+
+    def register_files(self, file_paths: List[str]) -> Optional[Any]:
+        """Register files (given by path) with NIXL."""
+        tuples = self.file_manager.files_to_nixl_tuples(file_paths)
+        return self.registration._register_memory(tuples, "FILE")
+
+    def register_objects(
+        self, keys: List[str], sizes: Optional[List[int]] = None
+    ) -> Optional[Any]:
+        """Register objects (given by key, with optional sizes) with NIXL."""
+        if not keys:
+            return None
+        sizes = sizes if sizes else [0] * len(keys)
+        tuples = [(0, size, key, "") for key, size in zip(keys, sizes)]
+        return self.registration._register_memory(tuples, "OBJ")
+
     def _execute_transfer(
-        self, tensors: List[torch.Tensor], keys: List[str], direction: str
+        self,
+        buffers: List[torch.Tensor | tuple],
+        keys: List[str],
+        direction: str,
     ) -> bool:
-        if len(tensors) != len(keys):
-            logger.error("Mismatch between number of tensors and files/objects")
+        if len(buffers) != len(keys):
+            logger.error("Mismatch between number of tensors/buffers and files/objects")
             return False
 
-        if not self.registration.register_buffers(tensors):
-            logger.error("Failed to register tensors")
-            return False
-
-        # Get transfer tuples based on backend type
-        tensor_sizes = [tensor.element_size() * tensor.numel() for tensor in tensors]
+        # File and object keys are registered per transfer; to be updated once
+        # pre-registration for files and objects is added to HiCache.
         if self.backend_selector.mem_type == "FILE":
-            file_tuples = self.file_manager.files_to_nixl_tuples(keys)
-            if not file_tuples or not self.registration.register_files(file_tuples):
+            tuples = self.file_manager.files_to_nixl_tuples(keys)
+            if not tuples or not self.registration._register_memory(tuples, "FILE"):
                 logger.error("Failed to prepare files for transfer")
                 return False
-            transfer_tuples = [
-                (x[0], s, x[2]) for x, s in zip(file_tuples, tensor_sizes)
-            ]
-        else:
-            if not self.registration.register_objects(keys, tensors):
+        else:  # mem_type == "OBJ"
+            tuples = [(0, 0, key, "") for key in keys]
+            if not tuples or not self.registration._register_memory(tuples, "OBJ"):
                 logger.error("Failed to register objects")
                 return False
-            transfer_tuples = [(0, s, key) for s, key in zip(tensor_sizes, keys)]
 
+        # Prepare transfer descriptors
+        if isinstance(buffers[0], torch.Tensor):
+            tensor_sizes = [
+                tensor.element_size() * tensor.numel() for tensor in buffers
+            ]
+            storage_tuples = [(x[0], s, x[2]) for x, s in zip(tuples, tensor_sizes)]
+            host_descs = self.agent.get_xfer_descs(buffers)
+        elif isinstance(buffers[0], tuple):
+            storage_tuples = [(x[0], y[1], x[2]) for x, y in zip(tuples, buffers)]
+            host_descs = self.agent.get_xfer_descs(
+                [(x[0], x[1], 0) for x in buffers], "DRAM"
+            )
+        else:
+            return False
+
+        storage_descs = self.agent.get_xfer_descs(
+            storage_tuples, self.backend_selector.mem_type
+        )
+
+        if (host_descs is None) or (storage_descs is None):
+            logger.error("Failed to get transfer descriptors")
+            return False
+
+        # Initialize the transfer, assuming by default that the host buffers
+        # were already registered
         try:
-            # Get transfer descriptors
-            if (tensor_descs := self.agent.get_xfer_descs(tensors)) is None or (
-                file_descs := self.agent.get_xfer_descs(
-                    transfer_tuples, self.backend_selector.mem_type
-                )
-            ) is None:
-                logger.error("Failed to get transfer descriptors")
+            xfer_req = self.agent.initialize_xfer(
+                direction, host_descs, storage_descs, self.agent_name
+            )
+        except Exception:
+            # Retry once after registering, in case the failure was due to
+            # missing pre-registration
+            if not self.register_buffers(buffers):
+                logger.error("Failed to register tensors/buffers")
                 return False
 
-            # Initialize and execute transfer
-            if (
-                xfer_req := self.agent.initialize_xfer(
-                    direction, tensor_descs, file_descs, self.agent_name
+            try:
+                xfer_req = self.agent.initialize_xfer(
+                    direction, host_descs, storage_descs, self.agent_name
                 )
-            ) is None:
-                logger.error("Failed to create transfer request")
+            except Exception as e:
+                logger.error(f"Failed to create transfer request: {e}")
                 return False
 
+        # Execute the transfer and wait for its completion
+        try:
             state = self.agent.transfer(xfer_req)
             while state != "DONE":
                 state = self.agent.check_xfer_state(xfer_req)
                 if state == "ERR":
+                    self.agent.release_xfer_handle(xfer_req)
                     logger.error("Transfer failed")
                     return False
-                time.sleep(0.0001)  # Can be changed to os.sched_yield() or parametrized
+                time.sleep(0.0001)  # Could be os.sched_yield() or parametrized
+
+            self.agent.release_xfer_handle(xfer_req)
             return True
 
         except Exception as e:
@@ -106,46 +158,88 @@ class HiCacheNixl(HiCacheStorage):
             logger.error(f"Traceback: {traceback.format_exc()}")
             return False
 
-    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+    def get(
+        self,
+        key: str,
+        target_location: Optional[torch.Tensor | int] = None,
+        target_sizes: Optional[int] = None,
+    ) -> torch.Tensor | None:
+        # To be removed; kept for compatibility with the current API
+        if target_location is None:
+            return None
+        if target_sizes:
+            result = self.batch_get([key], [target_location], [target_sizes])
+        else:
+            result = self.batch_get([key], [target_location])
+        return result[0] if result else None
+
+    def batch_get(
+        self,
+        keys: List[str],
+        target_locations: Optional[List[torch.Tensor | int]] = None,
+        target_sizes: Optional[List[int]] = None,
+    ) -> List[torch.Tensor | None]:
         if not keys:
-            return True
+            return []
+
+        # To be removed; kept for compatibility with the current API
+        if not target_locations:
+            return [None] * len(keys)
+
+        if target_sizes and (len(target_sizes) != len(target_locations)):
+            logger.error("Mismatch between number of target_locations and target_sizes")
+            return [None] * len(keys)
+
+        if target_sizes:
+            dest = list(zip(target_locations, target_sizes))
+        else:
+            dest = target_locations
+
+        if self.backend_selector.mem_type == "FILE":
+            file_paths = [self.file_manager.get_file_path(key) for key in keys]
+            success = self._execute_transfer(dest, file_paths, "READ")
+        else:
+            success = self._execute_transfer(dest, keys, "READ")
+        return target_locations if success and not target_sizes else [None] * len(keys)
+
+    def set(
+        self,
+        key: str,
+        value: Optional[torch.Tensor] = None,
+        target_location: Optional[int] = None,
+        target_sizes: Optional[int] = None,
+    ) -> bool:
+        if target_location and target_sizes:
+            return self.batch_set([key], None, [target_location], [target_sizes])
+        else:
+            return self.batch_set([key], [value])
+
+    def batch_set(
+        self,
+        keys: List[str],
+        values: Optional[List[torch.Tensor]] = None,
+        target_locations: Optional[List[int]] = None,
+        target_sizes: Optional[List[int]] = None,
+    ) -> bool:
+        if not keys or (not values and (not target_locations or not target_sizes)):
+            logger.error("Keys or values were not passed")
+            return False
+
+        if not values:
+            values = list(zip(target_locations, target_sizes))
 
         if self.backend_selector.mem_type == "FILE":
             file_paths = []
             for key in keys:
-                tensor_path = self.file_manager.get_file_path(key)
-                if not self.file_manager.create_file(tensor_path):
-                    logger.error(f"Failed to create file {tensor_path}")
+                file_path = self.file_manager.get_file_path(key)
+                # New file per set; to be updated when partial writes are added to HiCache
+                if not self.file_manager.create_file(file_path):
+                    logger.error(f"Failed to create file {file_path}")
                     return False
-                file_paths.append(tensor_path)
+                file_paths.append(file_path)
             return self._execute_transfer(values, file_paths, "WRITE")
-        else:
+        else:  # mem_type == "OBJ"
             return self._execute_transfer(values, keys, "WRITE")
 
-    def set(self, key: str, value: torch.Tensor) -> bool:
-        return self.batch_set([key], [value])
-
-    def get(
-        self, key: str, dst_tensor: Optional[torch.Tensor] = None
-    ) -> torch.Tensor | None:
-        if dst_tensor is None:  # To be removed, being compatible with the current API
-            return None
-        result = self.batch_get([key], [dst_tensor])
-        return result[0] if result else None
-
-    def batch_get(
-        self, keys: List[str], dst_tensors: List[torch.Tensor]
-    ) -> List[Optional[torch.Tensor]]:
-        if not keys:
-            return []
-
-        if self.backend_selector.mem_type == "FILE":
-            file_paths = [self.file_manager.get_file_path(key) for key in keys]
-            success = self._execute_transfer(dst_tensors, file_paths, "READ")
-        else:
-            success = self._execute_transfer(dst_tensors, keys, "READ")
-        return dst_tensors if success else [None] * len(keys)
-
     def exists(self, key: str) -> bool:
         tuples = self.registration.create_query_tuples(
             key,
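The `register_buffers`/`register_files`/`register_objects` hooks added above exist so that hot buffers can be registered once up front, letting `initialize_xfer` succeed on the first attempt instead of falling into the registration retry path in `_execute_transfer`. A minimal warm-up sketch under the same assumptions as the earlier example; pool shapes and sizes are illustrative:

```python
import torch

from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl

storage = HiCacheNixl()  # file-backed by default; plugin="auto"

# Pre-register a reusable host tensor pool once, so transfers over these
# buffers skip the registration retry inside _execute_transfer.
pool = [torch.empty(4096, dtype=torch.uint8) for _ in range(8)]
storage.register_buffers(pool)

# Raw (addr, len) spans register the same way, e.g. for a staging buffer.
staging = torch.empty(1 << 20, dtype=torch.uint8)
storage.register_buffers([(staging.data_ptr(), staging.numel())])

# With an object-store plugin, keys could be pre-registered instead
# (sizes are optional), e.g.:
#   storage.register_objects(["k1", "k2"], sizes=[4096, 4096])
```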
diff --git a/python/sglang/srt/mem_cache/storage/nixl/nixl_utils.py b/python/sglang/srt/mem_cache/storage/nixl/nixl_utils.py
index 476aed3a4..6e3d2a900 100644
--- a/python/sglang/srt/mem_cache/storage/nixl/nixl_utils.py
+++ b/python/sglang/srt/mem_cache/storage/nixl/nixl_utils.py
@@ -109,66 +109,35 @@ class NixlRegistration:
         return [(0, 0, key)]
 
     def _register_memory(
-        self, items: Union[List[tuple], List[torch.Tensor]], mem_type: str, desc: str
+        self,
+        items: Union[List[tuple], torch.Tensor, List[torch.Tensor]],
+        mem_type: Optional[str] = None,
     ) -> Optional[Any]:
         """Common registration logic for files, objects, and buffers.
 
         Args:
             items: List of tuples or tensors to register
-            mem_type: Memory type ("FILE", "OBJ", "DRAM", "VRAM")
-            desc: Description for logging
+            mem_type: Memory type ("FILE", "OBJ", "DRAM"), or None for a tensor
+                or list of tensors (inferred from the tensor device)
         """
+        if isinstance(items, list) and not items:
+            return None
+
+        reg_descs = self.agent.get_reg_descs(items, mem_type)
+        if reg_descs is None:
+            logger.error("Failed to create registration descriptors")
+            return None
+
         try:
-            if not items:
-                return None
-
-            reg_descs = self.agent.get_reg_descs(items, mem_type)
-            if reg_descs is None:
-                logger.error("Failed to create registration descriptors")
-                return None
-
             registered_memory = self.agent.register_memory(reg_descs)
-            if registered_memory:
-                return registered_memory
-            else:
-                logger.error("Failed to register with NIXL")
-                return None
-
+            return registered_memory  # Could be None in case of error
         except Exception as e:
-            logger.error(f"Failed to register {desc}: {e}")
+            if not mem_type:
+                logger.error(f"Failed to register tensors with NIXL: {e}")
+            else:
+                logger.error(
+                    f"Failed to register memory of type {mem_type} with NIXL: {e}"
+                )
             return None
 
-    def register_buffers(
-        self, buffers: Union[torch.Tensor, List[torch.Tensor]]
-    ) -> Optional[Any]:
-        """Register tensors/buffers with NIXL."""
-        if isinstance(buffers, torch.Tensor):
-            buffers = [buffers]
-
-        if not buffers:
-            return None
-
-        # Determine memory type based on tensor device
-        mem_type = "VRAM" if buffers[0].device.type == "cuda" else "DRAM"
-        return self._register_memory(buffers, mem_type, "buffers")
-
-    def register_files(self, tuples: List[tuple]) -> Optional[Any]:
-        """Register files with NIXL using (0, 0, fd, file_path) tuples."""
-        return self._register_memory(tuples, "FILE", "files")
-
-    def register_objects(
-        self, keys: List[str], tensors: Optional[List[torch.Tensor]] = None
-    ) -> Optional[Any]:
-        """Register objects with NIXL."""
-        if not keys:
-            return None
-
-        # Create object tuples with proper sizes
-        tuples = [
-            (0, tensor.element_size() * tensor.numel() if tensor else 0, key)
-            for key, tensor in zip(keys, tensors or [None] * len(keys))
-        ]
-        return self._register_memory(tuples, "OBJ", "objects")
-
 
 class NixlFileManager:
     """Handles file system operations for NIXL."""
@@ -221,12 +190,9 @@ class NixlFileManager:
             return False
 
     def files_to_nixl_tuples(
-        self, file_paths: List[str], open_file: bool = True
+        self, file_paths: List[str]
    ) -> List[Tuple[int, int, int, str]]:
         """Create NIXL tuples (offset, length, fd, file_path) for given files."""
-        if not open_file:
-            return [(0, 0, 0, path) for path in file_paths]
-
         tuples = []
         for path in file_paths:
             if (fd := self.open_file(path)) is None:
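One consequence of the simplified `files_to_nixl_tuples` above is that it now always opens the files, so every returned tuple carries a live fd. A small sketch of the expected shape, using the classes from this patch; the paths are illustrative:

```python
from sglang.srt.mem_cache.storage.nixl.nixl_utils import NixlFileManager

fm = NixlFileManager("/tmp/hicache_storage")
fm.create_file("/tmp/hicache_storage/demo.bin")

# Each tuple is (offset, length, fd, file_path); the fd comes from an
# actual open, so callers should treat these tuples as holding resources.
tuples = fm.files_to_nixl_tuples(["/tmp/hicache_storage/demo.bin"])
assert tuples and tuples[0][3] == "/tmp/hicache_storage/demo.bin"
```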
diff --git a/python/sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py b/python/sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py
index 572a032bf..951e5a4ea 100755
--- a/python/sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py
+++ b/python/sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py
@@ -7,8 +7,11 @@ from unittest.mock import MagicMock
 
 import torch
 
-from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
-from sglang.srt.mem_cache.nixl.nixl_utils import NixlFileManager, NixlRegistration
+from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
+from sglang.srt.mem_cache.storage.nixl.nixl_utils import (
+    NixlFileManager,
+    NixlRegistration,
+)
 
 
 class TestNixlUnified(unittest.TestCase):
@@ -88,8 +91,27 @@ class TestNixlUnified(unittest.TestCase):
         # Test get
         retrieved = self.hicache.get(key, dst_tensor)
+        self.verify_tensors_equal(value, dst_tensor)
         self.verify_tensors_equal(value, retrieved)
 
+        # Same test in (addr, len) mode with another key and dst_tensor
+        key2 = "test_key2"
+        dst_tensor2 = torch.zeros_like(value, device="cpu")
+        src_addr, src_len = value.data_ptr(), value.numel() * value.element_size()
+        dst_addr, dst_len = (
+            dst_tensor2.data_ptr(),
+            dst_tensor2.numel() * dst_tensor2.element_size(),
+        )
+
+        # Test set
+        self.assertTrue(self.hicache.set(key2, None, src_addr, src_len))
+        self.assertTrue(self.hicache.exists(key2))
+
+        # Test get
+        retrieved2 = self.hicache.get(key2, dst_addr, dst_len)
+        self.assertIsNone(retrieved2)
+        self.verify_tensors_equal(value, dst_tensor2)
+
     def test_batch_set_get(self):
         """Test batch tensor set/get operations."""
         keys = ["key1", "key2", "key3"]
@@ -108,6 +130,23 @@ class TestNixlUnified(unittest.TestCase):
         retrieved = self.hicache.batch_get(keys, dst_tensors)
         self.verify_tensor_lists_equal(values, retrieved)
 
+        # Same test in (addr, len) mode with new keys and dst_tensors
+        keys2 = ["key4", "key5", "key6"]
+        dst_tensors2 = [torch.zeros_like(v, device="cpu") for v in values]
+        src_addrs = [v.data_ptr() for v in values]
+        src_lens = [v.numel() * v.element_size() for v in values]
+        dst_addrs = [dt.data_ptr() for dt in dst_tensors2]
+        dst_lens = [dt.numel() * dt.element_size() for dt in dst_tensors2]
+
+        # Test batch set
+        self.assertTrue(self.hicache.batch_set(keys2, None, src_addrs, src_lens))
+        self.assertTrue(all(self.hicache.exists(key) for key in keys2))
+
+        # Test batch get
+        retrieved2 = self.hicache.batch_get(keys2, dst_addrs, dst_lens)
+        self.assertTrue(all(ret is None for ret in retrieved2))
+        self.verify_tensor_lists_equal(values, dst_tensors2)
+
     def test_mixed_operations(self):
         """Test mixing single and batch operations."""
         # Test interleaved set/get operations
@@ -170,7 +209,7 @@ class TestNixlUnified(unittest.TestCase):
         self.file_manager.create_file(test_file)
 
         # Test tuple creation
-        tuples = self.file_manager.files_to_nixl_tuples([test_file], False)
+        tuples = self.file_manager.files_to_nixl_tuples([test_file])
         self.assertIsNotNone(tuples)
         self.assertTrue(len(tuples) > 0)
 
@@ -190,11 +229,11 @@ class TestNixlUnified(unittest.TestCase):
         tensor = torch.randn(10, 10)
 
         # Test buffer registration
-        self.assertIsNotNone(self.registration.register_buffers(tensor))
+        self.assertIsNotNone(self.hicache.register_buffers(tensor))
 
         # Test batch registration
         tensors = [torch.randn(5, 5) for _ in range(3)]
-        self.assertIsNotNone(self.registration.register_buffers(tensors))
+        self.assertIsNotNone(self.hicache.register_buffers(tensors))
 
     def test_register_files_with_tuples(self):
         """Test registration of files using NIXL tuples."""
@@ -203,8 +242,8 @@ class TestNixlUnified(unittest.TestCase):
             self.file_manager.create_file(file)
 
         # Create tuples and register
-        tuples = self.file_manager.files_to_nixl_tuples(files, False)
-        self.registration.register_files(tuples)
+        tuples = self.file_manager.files_to_nixl_tuples(files)
+        self.hicache.register_files(files)
 
         # Verify tuples
         self.assertEqual(len(tuples), len(files))
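Following the fix above (`register_files` now takes paths, with tuple creation handled internally by the file manager), a matching warm-up for file backends might look like this sketch; the key names are illustrative, and per the TODO in `_execute_transfer`, transfers still re-register file keys until HiCache gains file/object pre-registration:

```python
import torch

from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl

storage = HiCacheNixl("/tmp/hicache_storage")

# Create and pre-register the backing files by path; HiCacheNixl builds
# the (offset, length, fd, path) tuples internally via its file manager.
paths = [storage.file_manager.get_file_path(k) for k in ("k1", "k2")]
for p in paths:
    storage.file_manager.create_file(p)
storage.register_files(paths)

# Writes against those keys go through the usual batch interface.
storage.batch_set(["k1", "k2"], [torch.randn(16), torch.randn(16)])
```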