integrate AIBrix KVcache (#10376)
This commit is contained in:
@@ -41,7 +41,7 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: codespell
|
- id: codespell
|
||||||
additional_dependencies: ['tomli']
|
additional_dependencies: ['tomli']
|
||||||
args: ['--toml', 'python/pyproject.toml', '-L', 'cann,thi,makro,wil,rouge']
|
args: ['--toml', 'python/pyproject.toml', '-L', 'cann,thi,makro,wil,rouge,PRIS']
|
||||||
exclude: |
|
exclude: |
|
||||||
(?x)^(
|
(?x)^(
|
||||||
test/srt/test_reasoning_parser\.py|
|
test/srt/test_reasoning_parser\.py|
|
||||||
|
|||||||
@@ -289,6 +289,14 @@ class HiCacheController:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.storage_backend = MooncakeStore(self.storage_config)
|
self.storage_backend = MooncakeStore(self.storage_config)
|
||||||
|
elif storage_backend == "aibrix":
|
||||||
|
from sglang.srt.mem_cache.storage.aibrix_kvcache.aibrix_kvcache_storage import (
|
||||||
|
AibrixKVCacheStorage,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.storage_backend = AibrixKVCacheStorage(
|
||||||
|
self.storage_config, self.mem_pool_host
|
||||||
|
)
|
||||||
elif storage_backend == "hf3fs":
|
elif storage_backend == "hf3fs":
|
||||||
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
|
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
|
||||||
HiCacheHF3FS,
|
HiCacheHF3FS,
|
||||||
|
|||||||
37
python/sglang/srt/mem_cache/storage/aibrix_kvcache/README.md
Normal file
37
python/sglang/srt/mem_cache/storage/aibrix_kvcache/README.md
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
# AIBrix KVCache as L3 KV Cache
|
||||||
|
This document provides brief instructions for setting up a AIBrixKVCache storage backend + AIBrixKVCache + SGLang runtime environment from scratch, describing how to utilize AIBrixKVCache as the L3 KV cache for SGLang.
|
||||||
|
The process consists of three main steps:
|
||||||
|
|
||||||
|
## Step1:Install AIbrix KVCache
|
||||||
|
Refer to the [AIBrix KVCache documentation](https://github.com/vllm-project/aibrix/blob/main/python/aibrix_kvcache/README.md) to install AIBrix KVCache.
|
||||||
|
|
||||||
|
## Step2: Deploy AIBrix Distributed KVCache Storage
|
||||||
|
|
||||||
|
AIBrix KVCache currently supports multiple distributed KVCache backends, including ByteDance's open-source Infinistore and the not-yet-open source PrisKV incubated by ByteDance's PrisDB & IAAS & DMI team.
|
||||||
|
|
||||||
|
For the Infinistore installation process, please refer to [this link](https://github.com/bytedance/InfiniStore).
|
||||||
|
|
||||||
|
PrisKV for AIBrix KVCache is currently in the open-source preparation stage, and no public documentation is available yet.
|
||||||
|
|
||||||
|
|
||||||
|
## Step3: Deploy Model Serving
|
||||||
|
|
||||||
|
For information on configuring a distributed KVCache backend for AIBrixKVCache, please refer to [this link](https://aibrix.readthedocs.io/latest/designs/aibrix-kvcache-offloading-framework.html)
|
||||||
|
|
||||||
|
Using PrisKV as an example, the startup command is as follows:
|
||||||
|
```bash
|
||||||
|
export AIBRIX_KV_CACHE_OL_L1_CACHE_ENABLED="0"
|
||||||
|
export AIBRIX_KV_CACHE_OL_L2_CACHE_BACKEND="PRIS"
|
||||||
|
export AIBRIX_KV_CACHE_OL_PRIS_REMOTE_ADDR="127.0.0.1"
|
||||||
|
export AIBRIX_KV_CACHE_OL_PRIS_REMOTE_PORT="6379"
|
||||||
|
export AIBRIX_KV_CACHE_OL_PRIS_PASSWORD="kvcache-redis"
|
||||||
|
MODEL_LENGTH=32768&&NCCL_MIN_NCHANNELS=24&&NCCL_IB_QPS_PER_CONNECTION=8&&NCCL_DEBUG=INFO \
|
||||||
|
python3 -m sglang.launch_server \
|
||||||
|
--model-path /code/models/Qwen3-32B \
|
||||||
|
--host 0.0.0.0 --port 8080 \
|
||||||
|
--enable-hierarchical-cache \
|
||||||
|
--hicache-storage-backend aibrix \
|
||||||
|
--page-size 16 \
|
||||||
|
--hicache-write-policy write_back \
|
||||||
|
--enable-metrics --hicache-ratio=2
|
||||||
|
```
|
||||||
@@ -0,0 +1,151 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from aibrix_kvcache import (
|
||||||
|
BaseKVCacheManager,
|
||||||
|
BlockHashes,
|
||||||
|
KVCacheBlockLayout,
|
||||||
|
KVCacheBlockSpec,
|
||||||
|
KVCacheConfig,
|
||||||
|
KVCacheTensorSpec,
|
||||||
|
ModelSpec,
|
||||||
|
)
|
||||||
|
from aibrix_kvcache.common.absl_logging import log_every_n_seconds
|
||||||
|
|
||||||
|
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
|
||||||
|
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AibrixKVCacheStorage(HiCacheStorage):
|
||||||
|
def __init__(self, storage_config: HiCacheStorageConfig, mem_pool: HostKVCache):
|
||||||
|
if storage_config is not None:
|
||||||
|
self.is_mla_backend = storage_config.is_mla_model
|
||||||
|
self.local_rank = storage_config.tp_rank
|
||||||
|
else:
|
||||||
|
self.is_mla_backend = False
|
||||||
|
self.local_rank = 0
|
||||||
|
kv_cache = mem_pool.device_pool
|
||||||
|
self.page_size = mem_pool.page_size
|
||||||
|
self.kv_cache_dtype = kv_cache.dtype
|
||||||
|
self.layer_num = kv_cache.layer_num
|
||||||
|
self.kv_head_ids = [
|
||||||
|
self.local_rank * kv_cache.head_num + i for i in range(kv_cache.head_num)
|
||||||
|
]
|
||||||
|
if not self.is_mla_backend:
|
||||||
|
self.layer_ids = range(
|
||||||
|
kv_cache.start_layer, kv_cache.end_layer
|
||||||
|
) # for pipeline parallel
|
||||||
|
|
||||||
|
self.block_spec = KVCacheBlockSpec(
|
||||||
|
block_ntokens=self.page_size,
|
||||||
|
block_dtype=self.kv_cache_dtype,
|
||||||
|
block_layout=KVCacheBlockLayout(KVCacheBlockLayout.NCLD),
|
||||||
|
tensor_spec=KVCacheTensorSpec(
|
||||||
|
heads=self.kv_head_ids,
|
||||||
|
layers=self.layer_ids,
|
||||||
|
head_size=kv_cache.head_dim,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
logger.info(self.block_spec)
|
||||||
|
config = KVCacheConfig(
|
||||||
|
block_spec=self.block_spec, model_spec=ModelSpec(102400)
|
||||||
|
)
|
||||||
|
self.kv_cache_manager = BaseKVCacheManager(config)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"MLA is not supported by AibrixKVCacheStorage yet."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _aibrix_kvcache_metrics_report(self):
|
||||||
|
self.kv_cache_manager.metrics.summary()
|
||||||
|
self.kv_cache_manager.metrics.reset()
|
||||||
|
|
||||||
|
def batch_get(
|
||||||
|
self,
|
||||||
|
keys: List[str],
|
||||||
|
target_locations: List[torch.Tensor],
|
||||||
|
target_sizes: Optional[Any] = None,
|
||||||
|
) -> List[torch.Tensor | None]:
|
||||||
|
block_hash = BlockHashes(keys, self.page_size)
|
||||||
|
status = self.kv_cache_manager.acquire(None, block_hash)
|
||||||
|
log_every_n_seconds(
|
||||||
|
logger, logging.INFO, self._aibrix_kvcache_metrics_report(), 1
|
||||||
|
)
|
||||||
|
if status.is_ok():
|
||||||
|
num_fetched_tokens, handle = status.value
|
||||||
|
kv_blocks = handle.to_tensors()
|
||||||
|
assert len(kv_blocks) == len(target_locations)
|
||||||
|
for i in range(len(kv_blocks)):
|
||||||
|
assert (
|
||||||
|
target_locations[i].nbytes == kv_blocks[i].nbytes
|
||||||
|
), f"{target_locations[i].nbytes}, {kv_blocks[i].nbytes}"
|
||||||
|
target_locations[i].copy_(kv_blocks[i].flatten())
|
||||||
|
handle.release()
|
||||||
|
return target_locations
|
||||||
|
|
||||||
|
return [None] * len(keys)
|
||||||
|
|
||||||
|
def get(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
target_location: Optional[Any] = None,
|
||||||
|
target_size: Optional[Any] = None,
|
||||||
|
) -> torch.Tensor | None:
|
||||||
|
return self.batch_get([key], [target_location], [target_size])[0]
|
||||||
|
|
||||||
|
def batch_set(
|
||||||
|
self,
|
||||||
|
keys: List[str],
|
||||||
|
values: Optional[Any] = None,
|
||||||
|
target_locations: Optional[Any] = None,
|
||||||
|
target_sizes: Optional[Any] = None,
|
||||||
|
) -> bool:
|
||||||
|
block_hash = BlockHashes(keys, self.page_size)
|
||||||
|
status = self.kv_cache_manager.allocate_for(None, block_hash)
|
||||||
|
if not status.is_ok():
|
||||||
|
logger.warning(
|
||||||
|
f"aibrix_kvcache set allocate failed, error_code {status.error_code}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
handle = status.value
|
||||||
|
tensors = handle.to_tensors()
|
||||||
|
if len(tensors) != len(values):
|
||||||
|
logger.warning("aibrix_kvcache set allocate not enough")
|
||||||
|
return False
|
||||||
|
for i in range(len(tensors)):
|
||||||
|
assert (
|
||||||
|
tensors[i].nbytes == values[i].nbytes
|
||||||
|
), f"{tensors[i].nbytes}, {values[i].nbytes}"
|
||||||
|
tensors[i].reshape(values[i].shape).copy_(values[i]).reshape(
|
||||||
|
tensors[i].shape
|
||||||
|
)
|
||||||
|
status = self.kv_cache_manager.put(None, block_hash, handle)
|
||||||
|
if not status.is_ok():
|
||||||
|
logger.info(
|
||||||
|
f"AIBrix KVCache Storage set failed, error_code {status.error_code}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
completed = status.value
|
||||||
|
return completed == len(keys) * self.page_size
|
||||||
|
|
||||||
|
def set(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
value: Optional[Any] = None,
|
||||||
|
target_location: Optional[Any] = None,
|
||||||
|
target_size: Optional[Any] = None,
|
||||||
|
) -> bool:
|
||||||
|
return self.batch_set([key], [value], [target_location], [target_size])
|
||||||
|
|
||||||
|
def batch_exists(self, keys: List[str]) -> int:
|
||||||
|
block_hash = BlockHashes(keys, self.page_size)
|
||||||
|
status = self.kv_cache_manager.exists(None, block_hash)
|
||||||
|
if status.is_ok():
|
||||||
|
return status.value // self.page_size
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def exists(self, key: str) -> bool | dict:
|
||||||
|
return self.batch_exists([key]) > 0
|
||||||
109
python/sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py
Normal file
109
python/sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
from aibrix_kvcache import (
|
||||||
|
BaseKVCacheManager,
|
||||||
|
GroupAwareKVCacheManager,
|
||||||
|
KVCacheBlockLayout,
|
||||||
|
KVCacheBlockSpec,
|
||||||
|
KVCacheConfig,
|
||||||
|
KVCacheMetrics,
|
||||||
|
KVCacheTensorSpec,
|
||||||
|
ModelSpec,
|
||||||
|
TokenListView,
|
||||||
|
)
|
||||||
|
from aibrix_kvcache.common.absl_logging import getLogger, log_every_n_seconds, log_if
|
||||||
|
from aibrix_kvcache_storage import AibrixKVCacheStorage
|
||||||
|
from torch.distributed import Backend, ProcessGroup
|
||||||
|
|
||||||
|
from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
|
||||||
|
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
|
||||||
|
from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def setup():
|
||||||
|
os.environ["RANK"] = "0"
|
||||||
|
os.environ["WORLD_SIZE"] = "1"
|
||||||
|
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
||||||
|
os.environ["MASTER_PORT"] = "63886"
|
||||||
|
|
||||||
|
|
||||||
|
class AIBrixKVCacheStorageTest:
|
||||||
|
def test_with_page_size(self):
|
||||||
|
config = HiCacheStorageConfig(
|
||||||
|
tp_rank=0,
|
||||||
|
tp_size=1,
|
||||||
|
is_mla_model=False,
|
||||||
|
is_page_first_layout=True,
|
||||||
|
model_name="test",
|
||||||
|
)
|
||||||
|
for page_size in range(1, 3):
|
||||||
|
logger.info(f"page_size: {page_size}")
|
||||||
|
batch_size = 2
|
||||||
|
head_num = 1
|
||||||
|
layer_num = 64
|
||||||
|
head_dim = 128
|
||||||
|
kv_cache = MHATokenToKVPool(
|
||||||
|
1024,
|
||||||
|
page_size,
|
||||||
|
torch.float16,
|
||||||
|
head_num,
|
||||||
|
head_dim,
|
||||||
|
layer_num,
|
||||||
|
"cpu",
|
||||||
|
False,
|
||||||
|
0,
|
||||||
|
layer_num,
|
||||||
|
)
|
||||||
|
mem_pool = MHATokenToKVPoolHost(kv_cache, 2, 0, page_size, "layer_first")
|
||||||
|
query_length = batch_size * 2
|
||||||
|
partial = batch_size
|
||||||
|
self.aibrix_kvcache = AibrixKVCacheStorage(config, mem_pool)
|
||||||
|
target_shape = (2, layer_num, page_size, head_num, head_dim)
|
||||||
|
rand_tensor = [
|
||||||
|
torch.rand(target_shape, dtype=torch.float16)
|
||||||
|
for _ in range(query_length)
|
||||||
|
]
|
||||||
|
keys = ["hash" + str(i) for i in range(query_length)]
|
||||||
|
partial_keys = keys[batch_size:query_length]
|
||||||
|
assert self.aibrix_kvcache.batch_exists(keys) == 0
|
||||||
|
assert self.aibrix_kvcache.batch_set(keys, rand_tensor)
|
||||||
|
get_tensor = [
|
||||||
|
torch.rand(target_shape, dtype=torch.float16).flatten()
|
||||||
|
for _ in range(query_length)
|
||||||
|
]
|
||||||
|
self.aibrix_kvcache.batch_get(keys, get_tensor)
|
||||||
|
for i in range(query_length):
|
||||||
|
assert torch.equal(get_tensor[i], rand_tensor[i].flatten())
|
||||||
|
ret = self.aibrix_kvcache.batch_exists(keys)
|
||||||
|
assert self.aibrix_kvcache.batch_exists(keys) == query_length
|
||||||
|
assert self.aibrix_kvcache.batch_exists(partial_keys) == partial
|
||||||
|
partial_get_tensor = [
|
||||||
|
torch.rand(target_shape, dtype=torch.float16).flatten()
|
||||||
|
for _ in range(partial)
|
||||||
|
]
|
||||||
|
self.aibrix_kvcache.batch_get(partial_keys, partial_get_tensor)
|
||||||
|
for i in range(partial):
|
||||||
|
assert torch.equal(
|
||||||
|
partial_get_tensor[i], rand_tensor[i + partial].flatten()
|
||||||
|
)
|
||||||
|
log_every_n_seconds(
|
||||||
|
logger,
|
||||||
|
logging.INFO,
|
||||||
|
self.aibrix_kvcache.kv_cache_manager.metrics.summary(),
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
setup()
|
||||||
|
test = AIBrixKVCacheStorageTest()
|
||||||
|
test.test_with_page_size()
|
||||||
@@ -2154,7 +2154,7 @@ class ServerArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--hicache-storage-backend",
|
"--hicache-storage-backend",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["file", "mooncake", "hf3fs", "nixl"],
|
choices=["file", "mooncake", "hf3fs", "nixl", "aibrix"],
|
||||||
default=ServerArgs.hicache_storage_backend,
|
default=ServerArgs.hicache_storage_backend,
|
||||||
help="The storage backend for hierarchical KV cache.",
|
help="The storage backend for hierarchical KV cache.",
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user