[Feature] LMCache Connector Integration (#9741)
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
Signed-off-by: YuhanLiu11 <yliu738@wisc.edu>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
@@ -656,6 +656,21 @@ class Scheduler(
                page_size=self.page_size,
                disable=server_args.disable_radix_cache,
            )
        elif server_args.enable_lmcache:
            from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
                LMCRadixCache,
            )

            self.tree_cache = LMCRadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
                disable=server_args.disable_radix_cache,
                model_config=self.model_config,
                tp_size=self.tp_size,
                rank=self.tp_rank,
                tp_group=self.tp_group,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
@@ -1411,9 +1426,11 @@ class Scheduler(
        _, _, available_size, evictable_size = self._get_token_info()
        protected_size = self.tree_cache.protected_size()
        memory_leak = (available_size + evictable_size) != (
            # self.max_total_num_tokens
            # if not self.enable_hierarchical_cache
            # else self.max_total_num_tokens - protected_size
            self.max_total_num_tokens
            - protected_size
        )
        token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
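The check above asserts a simple accounting identity: every token slot is either available, evictable, or protected, and with the LMCache connector the protected slots are excluded unconditionally. A toy sketch of the arithmetic (all numbers are invented for illustration):

```python
# Hypothetical numbers, just to illustrate the leak check above.
max_total_num_tokens = 1000
protected_size = 100   # slots pinned, e.g. by in-flight LMCache stores
available_size = 650   # free slots in the allocator
evictable_size = 250   # slots held by unlocked radix-tree nodes

memory_leak = (available_size + evictable_size) != (
    max_total_num_tokens - protected_size
)
assert not memory_leak  # 650 + 250 == 1000 - 100
```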
@@ -369,7 +369,6 @@ class MHATokenToKVPool(KVCache):
        # same applies to get_value_buffer and get_kv_buffer
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

        return self._get_key_buffer(layer_id)

    def _get_value_buffer(self, layer_id: int):
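This hook is what drives the layer-wise loading in the connector introduced below: before the attention backend reads a layer's KV buffer, `wait_until` gives LMCache a chance to finish that layer's transfer. A minimal CPU-only sketch of the handshake, with a hypothetical connector standing in for `LMCacheLayerwiseConnector`:

```python
# Toy sketch of the per-layer handshake; both classes here are mocks.
class MockConnector:
    def __init__(self, num_layers: int):
        self.loaded = [False] * num_layers

    def load_kv_layerwise(self, layer_id: int) -> None:
        # Stands in for the async host->device copy of one layer's KV.
        self.loaded[layer_id] = True


class MockLayerTransferCounter:
    """Mimics the adapter the KV pool calls via `wait_until`."""

    def __init__(self, connector: MockConnector):
        self.connector = connector

    def wait_until(self, layer_id: int) -> None:
        # The real adapter synchronizes a CUDA load stream first.
        self.connector.load_kv_layerwise(layer_id)


connector = MockConnector(num_layers=4)
counter = MockLayerTransferCounter(connector)
for layer_id in range(4):      # what get_key_buffer does, layer by layer
    counter.wait_until(layer_id)
assert all(connector.loaded)
```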
python/sglang/srt/mem_cache/storage/lmcache/README.md (new file, 43 lines)
@@ -0,0 +1,43 @@
# LMCache Connector for SGLang

This document describes how to use LMCache as the KV cache management backend for the SGLang engine.
For more details about LMCache, please refer to: https://lmcache.ai

## Install LMCache

### Method 1: with pip

```bash
pip install lmcache
```

### Method 2: from source

Clone the LMCache project:

```bash
git clone https://github.com/LMCache/LMCache
```

Install:

```bash
cd LMCache
pip install -e . --no-build-isolation
```

## Use LMCache

First, set up the LMCache configuration. An example config is provided in `example_config.yaml`. For more settings, please refer to https://docs.lmcache.ai/api_reference/configurations.html.

Second, launch the SGLang serving engine with LMCache enabled:

```bash
export LMCACHE_USE_EXPERIMENTAL=True
export LMCACHE_CONFIG_FILE=example_config.yaml

python -m sglang.launch_server \
    --model-path MODEL \
    --enable-lmcache
```
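As a quick sanity check of a deployment like the one above, sending the same prompt twice should show the second request served largely from cache. A minimal sketch, assuming the server is listening on SGLang's default port 30000 and using its standard `/generate` endpoint:

```python
import time

import requests

# Send the same prompt twice; with LMCache enabled, the second request's
# prefill should hit the KV cache and return noticeably faster.
payload = {
    "text": "Explain the difference between paging and segmentation.",
    "sampling_params": {"temperature": 0.0, "max_new_tokens": 64},
}
for attempt in range(2):
    start = time.perf_counter()
    resp = requests.post("http://localhost:30000/generate", json=payload)
    resp.raise_for_status()
    print(f"attempt {attempt}: {time.perf_counter() - start:.3f}s")
```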
python/sglang/srt/mem_cache/storage/lmcache/example_config.yaml (new file, 7 lines)
@@ -0,0 +1,7 @@
# Basic configurations
chunk_size: 256

# CPU offloading configurations
local_cpu: true
use_layerwise: true
max_local_cpu_size: 10  # maximum size of the CPU cache, in GB
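To size `max_local_cpu_size`, it helps to estimate how many tokens fit in the CPU cache. A back-of-the-envelope sketch, assuming an MHA/GQA KV layout like the buffers in the unit test below; the model dimensions are illustrative assumptions, not values read from the config:

```python
# Rough capacity estimate for the CPU cache configured above.
layer_num = 36        # hidden layers (assumed)
kv_head_num = 8       # KV heads under GQA (assumed)
head_dim = 128        # per-head dimension (assumed)
dtype_bytes = 2       # bf16

# K and V each store layer_num * kv_head_num * head_dim values per token.
bytes_per_token = 2 * layer_num * kv_head_num * head_dim * dtype_bytes
max_local_cpu_gb = 10  # matches max_local_cpu_size above

capacity_tokens = max_local_cpu_gb * (1 << 30) // bytes_per_token
print(f"{bytes_per_token} B/token -> ~{capacity_tokens} cacheable tokens")
# 147456 B/token -> ~72817 cacheable tokens
```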
python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py (new file, 280 lines)
@@ -0,0 +1,280 @@
from __future__ import annotations

import logging
import threading
from typing import TYPE_CHECKING, List, Optional

import torch

from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

try:
    from lmcache.integration.sglang.sglang_adapter import (
        LMCacheLayerwiseConnector,
        LoadMetadata,
        StoreMetadata,
    )
except ImportError as e:
    raise RuntimeError(
        "LMCache is not installed. Please install it by running `pip install lmcache`"
    ) from e

if TYPE_CHECKING:
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.managers.schedule_batch import Req

logger = logging.getLogger(__name__)


class LayerTransferCounter:
    """Minimal adapter that lets the memory pool notify LMCache per layer.

    The KV pool calls `wait_until(layer_id)` after finishing a layer, which we
    translate into a `load_kv_layerwise(layer_id)` call on the LMCache
    connector within the provided CUDA stream.
    """

    def __init__(
        self,
        num_layers: int,
        load_stream: torch.cuda.Stream,
        lmc_connector: LMCacheLayerwiseConnector,
        printable: bool = False,
    ):
        self.num_layers = num_layers
        self.load_stream = load_stream
        self.lmc_connector = lmc_connector

    def wait_until(self, layer_id: int):
        # Ensure ordering of the async loads w.r.t. the compute stream(s).
        self.load_stream.synchronize()
        with self.load_stream:
            self.lmc_connector.load_kv_layerwise(layer_id)


class LMCRadixCache(RadixCache):
    """RadixCache + LMCache IO.

    This subclass adds:
    - LMCache connector setup (device/host buffers, TP rank/size)
    - Two CUDA streams for async load/store
    - Layer-wise transfer executor wiring to the KV cache
    - Overridden `match_prefix` to fetch missing prefix chunks from LMCache
    - Extended cache-finalization paths to store back into LMCache
    - An eviction barrier that respects any in-flight host->device stores
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        page_size: int,
        disable: bool = False,
        enable_kv_cache_events: bool = False,
        model_config: Optional["ModelConfig"] = None,
        tp_size: int = 1,
        rank: int = 0,
        tp_group: Optional[torch.distributed.ProcessGroup] = None,
    ):
        super().__init__(
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            page_size=page_size,
            disable=disable,
            enable_kv_cache_events=enable_kv_cache_events,
        )

        kvcache = self.token_to_kv_pool_allocator.get_kvcache()
        self.lmcache_connector = LMCacheLayerwiseConnector(
            sgl_config=model_config,
            tp_size=tp_size,
            rank=rank,
            # NOTE: The original implementation accessed private buffers via
            # `_kvcache.k_buffer` / `.v_buffer`. We prefer public accessors
            # when available and fall back to the private fields if needed.
            k_pool=getattr(
                kvcache,
                "k_buffer",
                getattr(self.token_to_kv_pool_allocator._kvcache, "k_buffer"),
            ),
            v_pool=getattr(
                kvcache,
                "v_buffer",
                getattr(self.token_to_kv_pool_allocator._kvcache, "v_buffer"),
            ),
            tp_group=tp_group,
        )

        self.load_stream = torch.cuda.Stream()
        self.store_stream = torch.cuda.Stream()

        self.layer_done_executor = LayerTransferCounter(
            num_layers=(
                model_config.num_hidden_layers if model_config is not None else 0
            ),
            load_stream=self.load_stream,
            lmc_connector=self.lmcache_connector,
        )
        kvcache.register_layer_transfer_counter(self.layer_done_executor)

        self._in_flight_nodes: list[TreeNode] = []
        self._node_lock = threading.Lock()

    def reset(self):  # type: ignore[override]
        super().reset()
        if hasattr(self, "_in_flight_nodes"):
            with self._node_lock:
                self._in_flight_nodes.clear()

    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:  # type: ignore[override]
        """Match the cached prefix; on a tail miss, prefetch from LMCache.

        Reuses the base matching logic to obtain (value, last_node). If a
        *page-aligned* uncached suffix remains and there is room (or room can
        be made by eviction), allocate token slots, trigger an async LMCache
        load into those slots, and materialize a new child node for the
        retrieved chunk.
        """
        if self.disable or not key:
            return super().match_prefix(key, **kwargs)

        if self.page_size != 1:
            aligned_len = len(key) // self.page_size * self.page_size
            key = key[:aligned_len]

        base_res = super().match_prefix(key, **kwargs)
        value: torch.Tensor = base_res.device_indices
        last_node: TreeNode = base_res.last_device_node

        if value.numel() == len(key):
            return base_res

        uncached_len = len(key) - value.numel()
        if uncached_len == 0:
            return base_res

        chunk_size = self.lmcache_connector.chunk_size()
        prefix_pad = value.numel() % chunk_size

        if self.token_to_kv_pool_allocator.available_size() < uncached_len:
            self.evict(uncached_len)

        token_slots = self.token_to_kv_pool_allocator.alloc(uncached_len)
        if token_slots is None:
            return base_res

        slot_mapping = torch.cat(
            [
                torch.full((value.numel(),), -1, dtype=torch.int64, device=self.device),
                token_slots.detach().clone().to(torch.int64).to(self.device),
            ]
        )

        with torch.cuda.stream(self.load_stream):
            num_retrieved = self.lmcache_connector.start_load_kv(
                LoadMetadata(
                    token_ids=key,  # full page-aligned key
                    slot_mapping=slot_mapping,
                    offset=value.numel() - prefix_pad,  # LMCache offset convention
                )
            )
        logger.debug("num_retrieved_tokens: %s", num_retrieved)

        if num_retrieved > 0:
            self.token_to_kv_pool_allocator.free(
                token_slots[(num_retrieved - prefix_pad) :]
            )
        else:
            self.token_to_kv_pool_allocator.free(token_slots)

        if num_retrieved > 0:
            fetched = num_retrieved - prefix_pad
            new_node = TreeNode()
            start = value.numel()
            end = start + fetched
            new_node.key = key[start:end]
            new_node.value = token_slots[:fetched]
            new_node.parent = last_node
            last_node.children[self.get_child_key_fn(new_node.key)] = new_node
            last_node = new_node

            value = torch.cat([value, token_slots[:fetched]])
            self.evictable_size_ += fetched

            self._record_store_event(new_node.parent)
            self._record_store_event(new_node)

            return MatchResult(
                device_indices=value,
                last_device_node=last_node,
                last_host_node=last_node,
            )

        return base_res

    def cache_finished_req(self, req: "Req") -> None:  # type: ignore[override]
        """On request completion, insert device KV into the radix tree and store it to LMCache."""
        super().cache_finished_req(req)

        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]

        _, new_last_node, _, _ = self.match_prefix(token_ids)
        assert new_last_node is not None

        self.inc_lock_ref(new_last_node)
        store_md = StoreMetadata(
            last_node=new_last_node,
            token_ids=token_ids,
            kv_indices=kv_indices,
            offset=0,
        )
        with torch.cuda.stream(self.store_stream):
            self.lmcache_connector.store_kv(store_md)
        with self._node_lock:
            self._in_flight_nodes.append(new_last_node)

    def evict(self, num_tokens: int) -> None:  # type: ignore[override]
        """Before base eviction, wait for any outstanding stores and release their locks."""
        if self.disable:
            return

        self.store_stream.synchronize()
        with self._node_lock:
            for node in self._in_flight_nodes:
                self.dec_lock_ref(node)
            self._in_flight_nodes.clear()

        super().evict(num_tokens)

    def pretty_print(self):  # type: ignore[override]
        super().pretty_print()
        try:
            logger.debug(
                "evictable=%d protected=%d", self.evictable_size_, self.protected_size_
            )
        except Exception:  # pragma: no cover
            pass


if __name__ == "__main__":
    cache = LMCRadixCache(
        req_to_token_pool=None,
        token_to_kv_pool_allocator=None,
        page_size=1,
        disable=False,
        enable_kv_cache_events=False,
        model_config=None,
        tp_size=1,
        rank=0,
        tp_group=None,
    )
    cache.insert([1, 2, 3], torch.tensor([10, 11, 12], dtype=torch.int64))
    cache.insert([1, 2, 3, 4], torch.tensor([10, 11, 12, 13], dtype=torch.int64))
    cache.pretty_print()
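The slot bookkeeping in `match_prefix` is easiest to follow with concrete numbers. A standalone sketch of the alignment arithmetic, using made-up values (`chunk_size` mirrors `example_config.yaml`; the "LMCache has chunks up to token 768" scenario is invented):

```python
# Made-up numbers illustrating match_prefix's chunk alignment.
chunk_size = 256          # LMCache chunk granularity
key_len = 1000            # page-aligned prompt length
cached = 300              # tokens already matched in the radix tree

uncached = key_len - cached          # 700 slots allocated for the load
prefix_pad = cached % chunk_size     # 300 % 256 = 44 tokens re-fetched so the
                                     # load starts on a chunk boundary
load_offset = cached - prefix_pad    # LMCache starts loading at token 256

# Suppose LMCache only holds chunks covering tokens [256, 768):
num_retrieved = 768 - load_offset    # 512 tokens come back
fetched = num_retrieved - prefix_pad # 468 genuinely new tokens are kept
freed = uncached - fetched           # 232 unused slots returned to the pool

assert cached + fetched == 768       # cache now covers a chunk-aligned prefix
print(uncached, prefix_pad, load_offset, num_retrieved, fetched, freed)
```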
python/sglang/srt/mem_cache/storage/lmcache/unit_test.py (new file, 121 lines)
@@ -0,0 +1,121 @@
try:
    from lmcache.integration.sglang.sglang_adapter import (
        LMCacheLayerwiseConnector,
        LoadMetadata,
        StoreMetadata,
    )
except ImportError:
    raise RuntimeError(
        "LMCache is not installed. Please install it by running `pip install lmcache` in the root directory of LMCache"
    )

import os

import torch

from sglang.srt.configs.model_config import ModelConfig

os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
os.environ["LMCACHE_CONFIG_FILE"] = "example_config.yaml"


def test_load_store_metadata():
    model_config = ModelConfig(
        model_path="Qwen/Qwen3-4B",
    )

    # Generate a dummy KV cache.
    head_num = model_config.num_key_value_heads
    head_dim = model_config.head_dim
    layer_num = model_config.num_hidden_layers
    buffer_size = 256
    input_id_len = 16

    k_buffer = [
        torch.randn(buffer_size, head_num, head_dim, dtype=torch.bfloat16).cuda()
        for _ in range(layer_num)
    ]
    v_buffer = [
        torch.randn(buffer_size, head_num, head_dim, dtype=torch.bfloat16).cuda()
        for _ in range(layer_num)
    ]

    connector = LMCacheLayerwiseConnector(model_config, 1, 0, k_buffer, v_buffer)

    fake_token_ids = torch.randint(0, model_config.vocab_size, (input_id_len,)).tolist()
    fake_kv_indices = torch.randint(0, buffer_size, (input_id_len,))
    offset = 0

    store_metadata = StoreMetadata(
        last_node=None,
        token_ids=fake_token_ids,
        kv_indices=fake_kv_indices,
        offset=offset,
    )

    load_metadata = LoadMetadata(
        token_ids=fake_token_ids,
        slot_mapping=fake_kv_indices,
        offset=offset,
    )

    current_stream = torch.cuda.current_stream()

    # Nothing has been stored yet, so the first load retrieves zero tokens.
    retrieve_token_num = connector.start_load_kv(load_metadata)
    assert retrieve_token_num == 0

    connector.store_kv(store_metadata)
    current_stream.synchronize()

    # Check retrieval: snapshot the ground-truth KV entries.
    gt_key_buffer = [
        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
        for _ in range(layer_num)
    ]
    gt_value_buffer = [
        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
        for _ in range(layer_num)
    ]

    for i in range(layer_num):
        gt_key_buffer[i] = k_buffer[i][fake_kv_indices]
        gt_value_buffer[i] = v_buffer[i][fake_kv_indices]

    # Clear the k_buffer and v_buffer.
    for i in range(layer_num):
        k_buffer[i].zero_()
        v_buffer[i].zero_()

    retrieve_token_num = connector.start_load_kv(load_metadata)
    assert retrieve_token_num == input_id_len

    for i in range(layer_num):
        current_stream.synchronize()
        connector.load_kv_layerwise(i)

    current_stream.synchronize()
    test_key_buffer = [
        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
        for _ in range(layer_num)
    ]
    test_value_buffer = [
        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
        for _ in range(layer_num)
    ]

    for i in range(layer_num):
        test_key_buffer[i] = k_buffer[i][fake_kv_indices]
        test_value_buffer[i] = v_buffer[i][fake_kv_indices]

    for i in range(layer_num):
        assert torch.allclose(test_key_buffer[i], gt_key_buffer[i])
        assert torch.allclose(test_value_buffer[i], gt_value_buffer[i])

    print("================================================")
    print("TEST_LOAD_STORE_METADATA PASSED!")
    print("================================================")
    connector.close()


if __name__ == "__main__":
    test_load_store_metadata()
@@ -303,6 +303,8 @@ class ServerArgs:
    hicache_storage_backend: Optional[str] = None
    hicache_storage_prefetch_policy: str = "best_effort"
    hicache_storage_backend_extra_config: Optional[str] = None
    # LMCache
    enable_lmcache: bool = False

    # Double Sparsity
    enable_double_sparsity: bool = False

@@ -1735,6 +1737,12 @@ class ServerArgs:
            default=ServerArgs.hicache_storage_backend_extra_config,
            help="A dictionary in JSON string format containing extra configuration for the storage backend.",
        )
        # LMCache
        parser.add_argument(
            "--enable-lmcache",
            action="store_true",
            help="Use LMCache as an alternative hierarchical cache solution.",
        )

        # Double Sparsity
        parser.add_argument(
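The new flag is also reachable programmatically. A minimal sketch, assuming SGLang's offline `Engine` entry point, which forwards keyword arguments into `ServerArgs`; the model path is a placeholder:

```python
import os

# LMCache still needs its config, as in the README above.
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
os.environ["LMCACHE_CONFIG_FILE"] = "example_config.yaml"

import sglang as sgl

# Keyword arguments are forwarded into ServerArgs, so this mirrors
# `--enable-lmcache` on the CLI.
llm = sgl.Engine(model_path="Qwen/Qwen3-4B", enable_lmcache=True)
print(llm.generate("The capital of France is", {"max_new_tokens": 8}))
llm.shutdown()
```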