From fce170480af716482736e80217e56dd1f86a8352 Mon Sep 17 00:00:00 2001 From: yi wang Date: Thu, 25 Sep 2025 14:47:09 +0800 Subject: [PATCH] integrate AIBrix KVcache (#10376) --- .pre-commit-config.yaml | 2 +- .../sglang/srt/managers/cache_controller.py | 8 + .../storage/aibrix_kvcache/README.md | 37 +++++ .../aibrix_kvcache/aibrix_kvcache_storage.py | 151 ++++++++++++++++++ .../storage/aibrix_kvcache/unit_test.py | 109 +++++++++++++ python/sglang/srt/server_args.py | 2 +- 6 files changed, 307 insertions(+), 2 deletions(-) create mode 100644 python/sglang/srt/mem_cache/storage/aibrix_kvcache/README.md create mode 100644 python/sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py create mode 100644 python/sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2b4ff8072..f0fdbb5b6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,7 +41,7 @@ repos: hooks: - id: codespell additional_dependencies: ['tomli'] - args: ['--toml', 'python/pyproject.toml', '-L', 'cann,thi,makro,wil,rouge'] + args: ['--toml', 'python/pyproject.toml', '-L', 'cann,thi,makro,wil,rouge,PRIS'] exclude: | (?x)^( test/srt/test_reasoning_parser\.py| diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 041117753..3c5b497e1 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -289,6 +289,14 @@ class HiCacheController: ) self.storage_backend = MooncakeStore(self.storage_config) + elif storage_backend == "aibrix": + from sglang.srt.mem_cache.storage.aibrix_kvcache.aibrix_kvcache_storage import ( + AibrixKVCacheStorage, + ) + + self.storage_backend = AibrixKVCacheStorage( + self.storage_config, self.mem_pool_host + ) elif storage_backend == "hf3fs": from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import ( HiCacheHF3FS, diff --git 
a/python/sglang/srt/mem_cache/storage/aibrix_kvcache/README.md b/python/sglang/srt/mem_cache/storage/aibrix_kvcache/README.md new file mode 100644 index 000000000..16941967f --- /dev/null +++ b/python/sglang/srt/mem_cache/storage/aibrix_kvcache/README.md @@ -0,0 +1,37 @@ +# AIBrix KVCache as L3 KV Cache +This document provides brief instructions for setting up an AIBrixKVCache storage backend + AIBrixKVCache + SGLang runtime environment from scratch, describing how to utilize AIBrixKVCache as the L3 KV cache for SGLang. +The process consists of three main steps: + +## Step1: Install AIBrix KVCache +Refer to the [AIBrix KVCache documentation](https://github.com/vllm-project/aibrix/blob/main/python/aibrix_kvcache/README.md) to install AIBrix KVCache. + +## Step2: Deploy AIBrix Distributed KVCache Storage + +AIBrix KVCache currently supports multiple distributed KVCache backends, including ByteDance's open-source Infinistore and the not-yet-open source PrisKV incubated by ByteDance's PrisDB & IAAS & DMI team. + +For the Infinistore installation process, please refer to [this link](https://github.com/bytedance/InfiniStore). + +PrisKV for AIBrix KVCache is currently in the open-source preparation stage, and no public documentation is available yet.
+ + +## Step3: Deploy Model Serving + +For information on configuring a distributed KVCache backend for AIBrixKVCache, please refer to [this link](https://aibrix.readthedocs.io/latest/designs/aibrix-kvcache-offloading-framework.html) + +Using PrisKV as an example, the startup command is as follows: +```bash +export AIBRIX_KV_CACHE_OL_L1_CACHE_ENABLED="0" +export AIBRIX_KV_CACHE_OL_L2_CACHE_BACKEND="PRIS" +export AIBRIX_KV_CACHE_OL_PRIS_REMOTE_ADDR="127.0.0.1" +export AIBRIX_KV_CACHE_OL_PRIS_REMOTE_PORT="6379" +export AIBRIX_KV_CACHE_OL_PRIS_PASSWORD="kvcache-redis" +MODEL_LENGTH=32768&&NCCL_MIN_NCHANNELS=24&&NCCL_IB_QPS_PER_CONNECTION=8&&NCCL_DEBUG=INFO \ +python3 -m sglang.launch_server \ + --model-path /code/models/Qwen3-32B \ + --host 0.0.0.0 --port 8080 \ + --enable-hierarchical-cache \ + --hicache-storage-backend aibrix \ + --page-size 16 \ + --hicache-write-policy write_back \ + --enable-metrics --hicache-ratio=2 +``` diff --git a/python/sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py b/python/sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py new file mode 100644 index 000000000..59aacc11d --- /dev/null +++ b/python/sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py @@ -0,0 +1,151 @@ +import logging +from typing import Any, List, Optional + +import torch +from aibrix_kvcache import ( + BaseKVCacheManager, + BlockHashes, + KVCacheBlockLayout, + KVCacheBlockSpec, + KVCacheConfig, + KVCacheTensorSpec, + ModelSpec, +) +from aibrix_kvcache.common.absl_logging import log_every_n_seconds + +from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig +from sglang.srt.mem_cache.memory_pool_host import HostKVCache + +logger = logging.getLogger(__name__) + + +class AibrixKVCacheStorage(HiCacheStorage): + def __init__(self, storage_config: HiCacheStorageConfig, mem_pool: HostKVCache): + if storage_config is not None: + self.is_mla_backend = storage_config.is_mla_model + self.local_rank 
= storage_config.tp_rank + else: + self.is_mla_backend = False + self.local_rank = 0 + kv_cache = mem_pool.device_pool + self.page_size = mem_pool.page_size + self.kv_cache_dtype = kv_cache.dtype + self.layer_num = kv_cache.layer_num + self.kv_head_ids = [ + self.local_rank * kv_cache.head_num + i for i in range(kv_cache.head_num) + ] + if not self.is_mla_backend: + self.layer_ids = range( + kv_cache.start_layer, kv_cache.end_layer + ) # for pipeline parallel + + self.block_spec = KVCacheBlockSpec( + block_ntokens=self.page_size, + block_dtype=self.kv_cache_dtype, + block_layout=KVCacheBlockLayout(KVCacheBlockLayout.NCLD), + tensor_spec=KVCacheTensorSpec( + heads=self.kv_head_ids, + layers=self.layer_ids, + head_size=kv_cache.head_dim, + ), + ) + logger.info(self.block_spec) + config = KVCacheConfig( + block_spec=self.block_spec, model_spec=ModelSpec(102400) + ) + self.kv_cache_manager = BaseKVCacheManager(config) + else: + raise NotImplementedError( + "MLA is not supported by AibrixKVCacheStorage yet." 
+ ) + + def _aibrix_kvcache_metrics_report(self): + self.kv_cache_manager.metrics.summary() + self.kv_cache_manager.metrics.reset() + + def batch_get( + self, + keys: List[str], + target_locations: List[torch.Tensor], + target_sizes: Optional[Any] = None, + ) -> List[torch.Tensor | None]: + block_hash = BlockHashes(keys, self.page_size) + status = self.kv_cache_manager.acquire(None, block_hash) + log_every_n_seconds( + logger, logging.INFO, self._aibrix_kvcache_metrics_report(), 1 + ) + if status.is_ok(): + num_fetched_tokens, handle = status.value + kv_blocks = handle.to_tensors() + assert len(kv_blocks) == len(target_locations) + for i in range(len(kv_blocks)): + assert ( + target_locations[i].nbytes == kv_blocks[i].nbytes + ), f"{target_locations[i].nbytes}, {kv_blocks[i].nbytes}" + target_locations[i].copy_(kv_blocks[i].flatten()) + handle.release() + return target_locations + + return [None] * len(keys) + + def get( + self, + key: str, + target_location: Optional[Any] = None, + target_size: Optional[Any] = None, + ) -> torch.Tensor | None: + return self.batch_get([key], [target_location], [target_size])[0] + + def batch_set( + self, + keys: List[str], + values: Optional[Any] = None, + target_locations: Optional[Any] = None, + target_sizes: Optional[Any] = None, + ) -> bool: + block_hash = BlockHashes(keys, self.page_size) + status = self.kv_cache_manager.allocate_for(None, block_hash) + if not status.is_ok(): + logger.warning( + f"aibrix_kvcache set allocate failed, error_code {status.error_code}" + ) + return False + handle = status.value + tensors = handle.to_tensors() + if len(tensors) != len(values): + logger.warning("aibrix_kvcache set allocate not enough") + return False + for i in range(len(tensors)): + assert ( + tensors[i].nbytes == values[i].nbytes + ), f"{tensors[i].nbytes}, {values[i].nbytes}" + tensors[i].reshape(values[i].shape).copy_(values[i]).reshape( + tensors[i].shape + ) + status = self.kv_cache_manager.put(None, block_hash, handle) + if 
not status.is_ok(): + logger.info( + f"AIBrix KVCache Storage set failed, error_code {status.error_code}" + ) + return False + completed = status.value + return completed == len(keys) * self.page_size + + def set( + self, + key: str, + value: Optional[Any] = None, + target_location: Optional[Any] = None, + target_size: Optional[Any] = None, + ) -> bool: + return self.batch_set([key], [value], [target_location], [target_size]) + + def batch_exists(self, keys: List[str]) -> int: + block_hash = BlockHashes(keys, self.page_size) + status = self.kv_cache_manager.exists(None, block_hash) + if status.is_ok(): + return status.value // self.page_size + return 0 + + def exists(self, key: str) -> bool | dict: + return self.batch_exists([key]) > 0 diff --git a/python/sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py b/python/sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py new file mode 100644 index 000000000..2e54e9816 --- /dev/null +++ b/python/sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py @@ -0,0 +1,109 @@ +import logging +import os + +import torch +import torch.distributed +from aibrix_kvcache import ( + BaseKVCacheManager, + GroupAwareKVCacheManager, + KVCacheBlockLayout, + KVCacheBlockSpec, + KVCacheConfig, + KVCacheMetrics, + KVCacheTensorSpec, + ModelSpec, + TokenListView, +) +from aibrix_kvcache.common.absl_logging import getLogger, log_every_n_seconds, log_if +from aibrix_kvcache_storage import AibrixKVCacheStorage +from torch.distributed import Backend, ProcessGroup + +from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig +from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool +from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +logger = logging.getLogger(__name__) + + +def setup(): + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + os.environ["MASTER_ADDR"] = "127.0.0.1" + 
os.environ["MASTER_PORT"] = "63886" + + +class AIBrixKVCacheStorageTest: + def test_with_page_size(self): + config = HiCacheStorageConfig( + tp_rank=0, + tp_size=1, + is_mla_model=False, + is_page_first_layout=True, + model_name="test", + ) + for page_size in range(1, 3): + logger.info(f"page_size: {page_size}") + batch_size = 2 + head_num = 1 + layer_num = 64 + head_dim = 128 + kv_cache = MHATokenToKVPool( + 1024, + page_size, + torch.float16, + head_num, + head_dim, + layer_num, + "cpu", + False, + 0, + layer_num, + ) + mem_pool = MHATokenToKVPoolHost(kv_cache, 2, 0, page_size, "layer_first") + query_length = batch_size * 2 + partial = batch_size + self.aibrix_kvcache = AibrixKVCacheStorage(config, mem_pool) + target_shape = (2, layer_num, page_size, head_num, head_dim) + rand_tensor = [ + torch.rand(target_shape, dtype=torch.float16) + for _ in range(query_length) + ] + keys = ["hash" + str(i) for i in range(query_length)] + partial_keys = keys[batch_size:query_length] + assert self.aibrix_kvcache.batch_exists(keys) == 0 + assert self.aibrix_kvcache.batch_set(keys, rand_tensor) + get_tensor = [ + torch.rand(target_shape, dtype=torch.float16).flatten() + for _ in range(query_length) + ] + self.aibrix_kvcache.batch_get(keys, get_tensor) + for i in range(query_length): + assert torch.equal(get_tensor[i], rand_tensor[i].flatten()) + ret = self.aibrix_kvcache.batch_exists(keys) + assert self.aibrix_kvcache.batch_exists(keys) == query_length + assert self.aibrix_kvcache.batch_exists(partial_keys) == partial + partial_get_tensor = [ + torch.rand(target_shape, dtype=torch.float16).flatten() + for _ in range(partial) + ] + self.aibrix_kvcache.batch_get(partial_keys, partial_get_tensor) + for i in range(partial): + assert torch.equal( + partial_get_tensor[i], rand_tensor[i + partial].flatten() + ) + log_every_n_seconds( + logger, + logging.INFO, + self.aibrix_kvcache.kv_cache_manager.metrics.summary(), + 1, + ) + + +if __name__ == "__main__": + setup() + test = 
AIBrixKVCacheStorageTest() + test.test_with_page_size() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 71ddcbf3e..d5d1cecf2 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2154,7 +2154,7 @@ class ServerArgs: parser.add_argument( "--hicache-storage-backend", type=str, - choices=["file", "mooncake", "hf3fs", "nixl"], + choices=["file", "mooncake", "hf3fs", "nixl", "aibrix"], default=ServerArgs.hicache_storage_backend, help="The storage backend for hierarchical KV cache.", )