integrate AIBrix KVcache (#10376)
This commit is contained in:
@@ -41,7 +41,7 @@ repos:
|
||||
hooks:
|
||||
- id: codespell
|
||||
additional_dependencies: ['tomli']
|
||||
args: ['--toml', 'python/pyproject.toml', '-L', 'cann,thi,makro,wil,rouge']
|
||||
args: ['--toml', 'python/pyproject.toml', '-L', 'cann,thi,makro,wil,rouge,PRIS']
|
||||
exclude: |
|
||||
(?x)^(
|
||||
test/srt/test_reasoning_parser\.py|
|
||||
|
||||
@@ -289,6 +289,14 @@ class HiCacheController:
|
||||
)
|
||||
|
||||
self.storage_backend = MooncakeStore(self.storage_config)
|
||||
elif storage_backend == "aibrix":
|
||||
from sglang.srt.mem_cache.storage.aibrix_kvcache.aibrix_kvcache_storage import (
|
||||
AibrixKVCacheStorage,
|
||||
)
|
||||
|
||||
self.storage_backend = AibrixKVCacheStorage(
|
||||
self.storage_config, self.mem_pool_host
|
||||
)
|
||||
elif storage_backend == "hf3fs":
|
||||
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
|
||||
HiCacheHF3FS,
|
||||
|
||||
37
python/sglang/srt/mem_cache/storage/aibrix_kvcache/README.md
Normal file
37
python/sglang/srt/mem_cache/storage/aibrix_kvcache/README.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# AIBrix KVCache as L3 KV Cache
|
||||
This document provides brief instructions for setting up an AIBrix KVCache storage backend + SGLang runtime environment from scratch, describing how to utilize AIBrix KVCache as the L3 KV cache for SGLang.
|
||||
The process consists of three main steps:
|
||||
|
||||
## Step 1: Install AIBrix KVCache
|
||||
Refer to the [AIBrix KVCache documentation](https://github.com/vllm-project/aibrix/blob/main/python/aibrix_kvcache/README.md) to install AIBrix KVCache.
|
||||
|
||||
## Step 2: Deploy AIBrix Distributed KVCache Storage
|
||||
|
||||
AIBrix KVCache currently supports multiple distributed KVCache backends, including ByteDance's open-source InfiniStore and the not-yet-open-source PrisKV incubated by ByteDance's PrisDB & IAAS & DMI teams.
|
||||
|
||||
For the Infinistore installation process, please refer to [this link](https://github.com/bytedance/InfiniStore).
|
||||
|
||||
PrisKV for AIBrix KVCache is currently in the open-source preparation stage, and no public documentation is available yet.
|
||||
|
||||
|
||||
## Step 3: Deploy Model Serving
|
||||
|
||||
For information on configuring a distributed KVCache backend for AIBrixKVCache, please refer to [this link](https://aibrix.readthedocs.io/latest/designs/aibrix-kvcache-offloading-framework.html).
|
||||
|
||||
Using PrisKV as an example, the startup command is as follows:
|
||||
```bash
|
||||
export AIBRIX_KV_CACHE_OL_L1_CACHE_ENABLED="0"
|
||||
export AIBRIX_KV_CACHE_OL_L2_CACHE_BACKEND="PRIS"
|
||||
export AIBRIX_KV_CACHE_OL_PRIS_REMOTE_ADDR="127.0.0.1"
|
||||
export AIBRIX_KV_CACHE_OL_PRIS_REMOTE_PORT="6379"
|
||||
export AIBRIX_KV_CACHE_OL_PRIS_PASSWORD="kvcache-redis"
|
||||
MODEL_LENGTH=32768&&NCCL_MIN_NCHANNELS=24&&NCCL_IB_QPS_PER_CONNECTION=8&&NCCL_DEBUG=INFO \
|
||||
python3 -m sglang.launch_server \
|
||||
--model-path /code/models/Qwen3-32B \
|
||||
--host 0.0.0.0 --port 8080 \
|
||||
--enable-hierarchical-cache \
|
||||
--hicache-storage-backend aibrix \
|
||||
--page-size 16 \
|
||||
--hicache-write-policy write_back \
|
||||
--enable-metrics --hicache-ratio=2
|
||||
```
|
||||
@@ -0,0 +1,151 @@
|
||||
import logging
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import torch
|
||||
from aibrix_kvcache import (
|
||||
BaseKVCacheManager,
|
||||
BlockHashes,
|
||||
KVCacheBlockLayout,
|
||||
KVCacheBlockSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheTensorSpec,
|
||||
ModelSpec,
|
||||
)
|
||||
from aibrix_kvcache.common.absl_logging import log_every_n_seconds
|
||||
|
||||
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
|
||||
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AibrixKVCacheStorage(HiCacheStorage):
    """HiCacheStorage backend backed by AIBrix KVCache (L3 KV cache).

    Wraps aibrix_kvcache's ``BaseKVCacheManager`` so SGLang's hierarchical
    cache can offload KV pages to an AIBrix-managed distributed store.
    MLA models are not supported yet.
    """

    def __init__(self, storage_config: HiCacheStorageConfig, mem_pool: HostKVCache):
        """Build the block spec from the device KV pool and start the manager.

        Args:
            storage_config: SGLang hierarchical-cache config; may be None,
                in which case a single-rank, non-MLA setup is assumed.
            mem_pool: host KV cache whose ``device_pool`` describes the
                layer/head/dtype layout of the KV tensors.
        """
        if storage_config is not None:
            self.is_mla_backend = storage_config.is_mla_model
            self.local_rank = storage_config.tp_rank
        else:
            self.is_mla_backend = False
            self.local_rank = 0
        kv_cache = mem_pool.device_pool
        self.page_size = mem_pool.page_size
        self.kv_cache_dtype = kv_cache.dtype
        self.layer_num = kv_cache.layer_num
        # Globally-unique head ids for this TP rank so different ranks do
        # not collide on the same entries in the shared store.
        self.kv_head_ids = [
            self.local_rank * kv_cache.head_num + i for i in range(kv_cache.head_num)
        ]
        if self.is_mla_backend:
            raise NotImplementedError(
                "MLA is not supported by AibrixKVCacheStorage yet."
            )
        self.layer_ids = range(
            kv_cache.start_layer, kv_cache.end_layer
        )  # for pipeline parallel
        self.block_spec = KVCacheBlockSpec(
            block_ntokens=self.page_size,
            block_dtype=self.kv_cache_dtype,
            block_layout=KVCacheBlockLayout(KVCacheBlockLayout.NCLD),
            tensor_spec=KVCacheTensorSpec(
                heads=self.kv_head_ids,
                layers=self.layer_ids,
                head_size=kv_cache.head_dim,
            ),
        )
        logger.info(self.block_spec)
        # NOTE(review): 102400 is a hard-coded max model length for
        # ModelSpec — consider deriving it from server args.
        config = KVCacheConfig(
            block_spec=self.block_spec, model_spec=ModelSpec(102400)
        )
        self.kv_cache_manager = BaseKVCacheManager(config)

    def _aibrix_kvcache_metrics_report(self) -> str:
        """Return a metrics summary string and reset counters for the next window."""
        summary = self.kv_cache_manager.metrics.summary()
        self.kv_cache_manager.metrics.reset()
        return summary

    def batch_get(
        self,
        keys: List[str],
        target_locations: List[torch.Tensor],
        target_sizes: Optional[Any] = None,
    ) -> List[torch.Tensor | None]:
        """Fetch the pages for ``keys`` into ``target_locations``.

        Returns ``target_locations`` on a successful acquire, otherwise a
        list of ``None`` (one per key).
        """
        block_hash = BlockHashes(keys, self.page_size)
        status = self.kv_cache_manager.acquire(None, block_hash)
        # The log line is throttled to once per second; the report string is
        # still built (and counters reset) on every call.
        log_every_n_seconds(
            logger, logging.INFO, self._aibrix_kvcache_metrics_report(), 1
        )
        if not status.is_ok():
            return [None] * len(keys)
        num_fetched_tokens, handle = status.value
        kv_blocks = handle.to_tensors()
        # NOTE(review): a partial hit would trip this assert — confirm that
        # acquire() only reports ok when every requested block is present.
        assert len(kv_blocks) == len(target_locations)
        for dst, src in zip(target_locations, kv_blocks):
            assert dst.nbytes == src.nbytes, f"{dst.nbytes}, {src.nbytes}"
            dst.copy_(src.flatten())
        handle.release()
        return target_locations

    def get(
        self,
        key: str,
        target_location: Optional[Any] = None,
        target_size: Optional[Any] = None,
    ) -> torch.Tensor | None:
        """Single-key convenience wrapper around :meth:`batch_get`."""
        return self.batch_get([key], [target_location], [target_size])[0]

    def batch_set(
        self,
        keys: List[str],
        values: Optional[Any] = None,
        target_locations: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> bool:
        """Store one KV tensor per key; returns True only on a full commit."""
        block_hash = BlockHashes(keys, self.page_size)
        status = self.kv_cache_manager.allocate_for(None, block_hash)
        if not status.is_ok():
            logger.warning(
                f"aibrix_kvcache set allocate failed, error_code {status.error_code}"
            )
            return False
        handle = status.value
        tensors = handle.to_tensors()
        if len(tensors) != len(values):
            logger.warning("aibrix_kvcache set allocate not enough")
            return False
        for i in range(len(tensors)):
            assert (
                tensors[i].nbytes == values[i].nbytes
            ), f"{tensors[i].nbytes}, {values[i].nbytes}"
            # Copy into the allocated block in place. Reshaping the *source*
            # (not the destination) guarantees the write lands in tensors[i]
            # even when a destination reshape would not be a view.
            tensors[i].copy_(values[i].reshape(tensors[i].shape))
        status = self.kv_cache_manager.put(None, block_hash, handle)
        if not status.is_ok():
            logger.info(
                f"AIBrix KVCache Storage set failed, error_code {status.error_code}"
            )
            return False
        completed = status.value
        # put() reports committed tokens; success means every page landed.
        return completed == len(keys) * self.page_size

    def set(
        self,
        key: str,
        value: Optional[Any] = None,
        target_location: Optional[Any] = None,
        target_size: Optional[Any] = None,
    ) -> bool:
        """Single-key convenience wrapper around :meth:`batch_set`."""
        return self.batch_set([key], [value], [target_location], [target_size])

    def batch_exists(self, keys: List[str]) -> int:
        """Return how many of the given page keys exist in the store."""
        block_hash = BlockHashes(keys, self.page_size)
        status = self.kv_cache_manager.exists(None, block_hash)
        if status.is_ok():
            # exists() reports a token count; convert to a page/key count.
            return status.value // self.page_size
        return 0

    def exists(self, key: str) -> bool | dict:
        """Return True if the single page key is present."""
        return self.batch_exists([key]) > 0
|
||||
109
python/sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py
Normal file
109
python/sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from aibrix_kvcache import (
|
||||
BaseKVCacheManager,
|
||||
GroupAwareKVCacheManager,
|
||||
KVCacheBlockLayout,
|
||||
KVCacheBlockSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheMetrics,
|
||||
KVCacheTensorSpec,
|
||||
ModelSpec,
|
||||
TokenListView,
|
||||
)
|
||||
from aibrix_kvcache.common.absl_logging import getLogger, log_every_n_seconds, log_if
|
||||
from aibrix_kvcache_storage import AibrixKVCacheStorage
|
||||
from torch.distributed import Backend, ProcessGroup
|
||||
|
||||
from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
|
||||
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
|
||||
from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup():
    """Populate the environment for a single-process torch.distributed group."""
    env = {
        "RANK": "0",
        "WORLD_SIZE": "1",
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": "63886",
    }
    os.environ.update(env)
|
||||
|
||||
|
||||
class AIBrixKVCacheStorageTest:
    """Round-trip test for AibrixKVCacheStorage over several page sizes."""

    def test_with_page_size(self):
        """set/get/exists round trip, full batch and suffix subset, for page sizes 1-2."""
        config = HiCacheStorageConfig(
            tp_rank=0,
            tp_size=1,
            is_mla_model=False,
            is_page_first_layout=True,
            model_name="test",
        )
        for page_size in range(1, 3):
            logger.info(f"page_size: {page_size}")
            batch_size = 2
            head_num = 1
            layer_num = 64
            head_dim = 128
            kv_cache = MHATokenToKVPool(
                1024,
                page_size,
                torch.float16,
                head_num,
                head_dim,
                layer_num,
                "cpu",
                False,
                0,
                layer_num,
            )
            mem_pool = MHATokenToKVPoolHost(kv_cache, 2, 0, page_size, "layer_first")
            query_length = batch_size * 2
            partial = batch_size
            self.aibrix_kvcache = AibrixKVCacheStorage(config, mem_pool)
            # One entry per key: (K/V, layer, token-in-page, head, head_dim).
            target_shape = (2, layer_num, page_size, head_num, head_dim)
            rand_tensor = [
                torch.rand(target_shape, dtype=torch.float16)
                for _ in range(query_length)
            ]
            keys = ["hash" + str(i) for i in range(query_length)]
            partial_keys = keys[batch_size:query_length]
            # Nothing is stored yet.
            assert self.aibrix_kvcache.batch_exists(keys) == 0
            assert self.aibrix_kvcache.batch_set(keys, rand_tensor)
            get_tensor = [
                torch.rand(target_shape, dtype=torch.float16).flatten()
                for _ in range(query_length)
            ]
            self.aibrix_kvcache.batch_get(keys, get_tensor)
            for i in range(query_length):
                assert torch.equal(get_tensor[i], rand_tensor[i].flatten())
            # All keys present; a suffix query sees exactly its own entries.
            assert self.aibrix_kvcache.batch_exists(keys) == query_length
            assert self.aibrix_kvcache.batch_exists(partial_keys) == partial
            partial_get_tensor = [
                torch.rand(target_shape, dtype=torch.float16).flatten()
                for _ in range(partial)
            ]
            self.aibrix_kvcache.batch_get(partial_keys, partial_get_tensor)
            for i in range(partial):
                # partial_keys[i] corresponds to rand_tensor[i + partial]
                # because partial == batch_size (the suffix offset).
                assert torch.equal(
                    partial_get_tensor[i], rand_tensor[i + partial].flatten()
                )
            log_every_n_seconds(
                logger,
                logging.INFO,
                self.aibrix_kvcache.kv_cache_manager.metrics.summary(),
                1,
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Configure the (single-rank) distributed env, then run the test directly.
    setup()
    AIBrixKVCacheStorageTest().test_with_page_size()
|
||||
@@ -2154,7 +2154,7 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--hicache-storage-backend",
|
||||
type=str,
|
||||
choices=["file", "mooncake", "hf3fs", "nixl"],
|
||||
choices=["file", "mooncake", "hf3fs", "nixl", "aibrix"],
|
||||
default=ServerArgs.hicache_storage_backend,
|
||||
help="The storage backend for hierarchical KV cache.",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user