From 9a7ced4e4dea7647e4b5aee098b8b19b96cd2c8b Mon Sep 17 00:00:00 2001 From: Yuwei An Date: Sat, 6 Sep 2025 20:14:55 -0700 Subject: [PATCH] [Feature] LMCache Connector Integration (#9741) Signed-off-by: Oasis-Git Signed-off-by: YuhanLiu11 Co-authored-by: Zhiqiang Xie --- python/sglang/srt/managers/scheduler.py | 21 +- python/sglang/srt/mem_cache/memory_pool.py | 1 - .../srt/mem_cache/storage/lmcache/README.md | 43 +++ .../storage/lmcache/example_config.yaml | 7 + .../storage/lmcache/lmc_radix_cache.py | 280 ++++++++++++++++++ .../mem_cache/storage/lmcache/unit_test.py | 121 ++++++++ python/sglang/srt/server_args.py | 8 + 7 files changed, 478 insertions(+), 3 deletions(-) create mode 100644 python/sglang/srt/mem_cache/storage/lmcache/README.md create mode 100644 python/sglang/srt/mem_cache/storage/lmcache/example_config.yaml create mode 100644 python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py create mode 100644 python/sglang/srt/mem_cache/storage/lmcache/unit_test.py diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 2dbc63191..8daa8afe2 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -656,6 +656,21 @@ class Scheduler( page_size=self.page_size, disable=server_args.disable_radix_cache, ) + elif server_args.enable_lmcache: + from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import ( + LMCRadixCache, + ) + + self.tree_cache = LMCRadixCache( + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + page_size=self.page_size, + disable=server_args.disable_radix_cache, + model_config=self.model_config, + tp_size=self.tp_size, + rank=self.tp_rank, + tp_group=self.tp_group, + ) else: self.tree_cache = RadixCache( req_to_token_pool=self.req_to_token_pool, @@ -1411,9 +1426,11 @@ class Scheduler( _, _, available_size, evictable_size = self._get_token_info() protected_size = self.tree_cache.protected_size() memory_leak = (available_size + evictable_size) != ( + # self.max_total_num_tokens + # if not self.enable_hierarchical_cache + # else self.max_total_num_tokens - protected_size self.max_total_num_tokens - if not self.enable_hierarchical_cache - else self.max_total_num_tokens - protected_size + - protected_size ) token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n" diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index af56c580a..fab917a81 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -369,7 +369,6 @@ class MHATokenToKVPool(KVCache): # same applies to get_value_buffer and get_kv_buffer if self.layer_transfer_counter is not None: self.layer_transfer_counter.wait_until(layer_id - self.start_layer) - return self._get_key_buffer(layer_id) def _get_value_buffer(self, layer_id: int): diff --git a/python/sglang/srt/mem_cache/storage/lmcache/README.md b/python/sglang/srt/mem_cache/storage/lmcache/README.md new file mode 100644 index 000000000..7177e21e5 --- /dev/null +++ b/python/sglang/srt/mem_cache/storage/lmcache/README.md @@ -0,0 +1,43 @@ +# LMCache Connector for SGLang + +This document describes how to use LMCache as KV Cache Management Backend for SGLang engine. 
+For more details about LMCache, please refer to: https://lmcache.ai
+
+## Install LMCache
+
+### Method 1: with pip
+
+```bash
+pip install lmcache
+```
+
+### Method 2: from source
+
+Clone the LMCache project:
+
+```bash
+git clone https://github.com/LMCache/LMCache
+```
+
+Install:
+
+```bash
+cd LMCache
+pip install -e . --no-build-isolation
+```
+
+
+## Use LMCache
+
+First, set up the LMCache config. An example config is provided at `example_config.yaml`; for more settings, see https://docs.lmcache.ai/api_reference/configurations.html.
+
+Second, launch the SGLang serving engine with LMCache enabled:
+
+```bash
+export LMCACHE_USE_EXPERIMENTAL=True
+export LMCACHE_CONFIG_FILE=example_config.yaml
+
+python -m sglang.launch_server \
+    --model-path MODEL \
+    --enable-lmcache
+```
diff --git a/python/sglang/srt/mem_cache/storage/lmcache/example_config.yaml b/python/sglang/srt/mem_cache/storage/lmcache/example_config.yaml
new file mode 100644
index 000000000..549110b7c
--- /dev/null
+++ b/python/sglang/srt/mem_cache/storage/lmcache/example_config.yaml
@@ -0,0 +1,7 @@
+# Basic configurations
+chunk_size: 256
+
+# CPU offloading configurations
+local_cpu: true
+use_layerwise: true
+max_local_cpu_size: 10 # maximum CPU cache size, in GB
diff --git a/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py b/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py
new file mode 100644
index 000000000..f8690aec4
--- /dev/null
+++ b/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py
@@ -0,0 +1,280 @@
+from __future__ import annotations
+
+import logging
+import threading
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
+
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import MatchResult
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
+
+try:
+    from lmcache.integration.sglang.sglang_adapter import (
+        LMCacheLayerwiseConnector,
+        LoadMetadata,
+        StoreMetadata,
+    )
+except ImportError as e:
+    raise RuntimeError(
+        "LMCache is not installed. Please install it by running `pip install lmcache`"
+    ) from e
+
+if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.managers.schedule_batch import Req
+
+logger = logging.getLogger(__name__)
+
+
+class LayerTransferCounter:
+    """Minimal adapter that lets the memory pool notify LMCache per layer.
+
+    The KV pool calls `wait_until(layer_id)` after finishing a layer, which we
+    translate into a `load_kv_layerwise(layer_id)` call on the LMCache connector
+    within the provided CUDA stream.
+    """
+
+    def __init__(
+        self,
+        num_layers: int,
+        load_stream: torch.cuda.Stream,
+        lmc_connector: LMCacheLayerwiseConnector,
+        printable: bool = False,
+    ):
+        self.num_layers = num_layers
+        self.load_stream = load_stream
+        self.lmc_connector = lmc_connector
+
+    def wait_until(self, layer_id: int):
+        # Ensure ordering of the async loads wrt compute stream(s).
+        self.load_stream.synchronize()
+        with self.load_stream:
+            self.lmc_connector.load_kv_layerwise(layer_id)
+
+
+class LMCRadixCache(RadixCache):
+    """RadixCache + LMCache IO.
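+
+    Assumes LMCache is installed and configured as described in the README in
+    this directory (LMCACHE_USE_EXPERIMENTAL=True and LMCACHE_CONFIG_FILE
+    pointing at a config such as example_config.yaml).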
+ + This subclass adds: + - LMCache connector setup (device/host buffers, TP rank/size) + - Two CUDA streams for async load/store + - Layer-wise transfer executor wiring to the KV cache + - Overridden `match_prefix` to fetch missing prefix chunks from LMCache + - Extended cache_finalization paths to store back into LMCache + - Eviction barrier that respects any in-flight host->device stores + """ + + def __init__( + self, + req_to_token_pool: ReqToTokenPool, + token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, + page_size: int, + disable: bool = False, + enable_kv_cache_events: bool = False, + model_config: Optional["ModelConfig"] = None, + tp_size: int = 1, + rank: int = 0, + tp_group: Optional[torch.distributed.ProcessGroup] = None, + ): + super().__init__( + req_to_token_pool=req_to_token_pool, + token_to_kv_pool_allocator=token_to_kv_pool_allocator, + page_size=page_size, + disable=disable, + enable_kv_cache_events=enable_kv_cache_events, + ) + + kvcache = self.token_to_kv_pool_allocator.get_kvcache() + self.lmcache_connector = LMCacheLayerwiseConnector( + sgl_config=model_config, + tp_size=tp_size, + rank=rank, + # NOTE: The original implementation accessed private buffers via + # `_kvcache.k_buffer` / `.v_buffer`. We prefer public accessors when + # available; fall back to private fields if needed. + k_pool=getattr( + kvcache, + "k_buffer", + getattr(self.token_to_kv_pool_allocator._kvcache, "k_buffer"), + ), + v_pool=getattr( + kvcache, + "v_buffer", + getattr(self.token_to_kv_pool_allocator._kvcache, "v_buffer"), + ), + tp_group=tp_group, + ) + + self.load_stream = torch.cuda.Stream() + self.store_stream = torch.cuda.Stream() + + self.layer_done_executor = LayerTransferCounter( + num_layers=( + model_config.num_hidden_layers if model_config is not None else 0 + ), + load_stream=self.load_stream, + lmc_connector=self.lmcache_connector, + ) + kvcache.register_layer_transfer_counter(self.layer_done_executor) + + self._in_flight_nodes: list[TreeNode] = [] + self._node_lock = threading.Lock() + + def reset(self): # type: ignore[override] + super().reset() + if hasattr(self, "_in_flight_nodes"): + with self._node_lock: + self._in_flight_nodes.clear() + + def match_prefix(self, key: List[int], **kwargs) -> MatchResult: # type: ignore[override] + """Match cached prefix; if there's a tail miss, prefetch from LMCache. + + Reuses the base matching logic to obtain (value, last_node). If there + remains a *page-aligned* uncached suffix and there is room (or after + eviction), we allocate token slots and trigger an async LMCache load + into those slots, then materialize a new child node for the retrieved + chunk. 
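+
+        Worked example (hypothetical numbers, assuming chunk_size=256): with
+        300 tokens already cached on device and a 512-token page-aligned key,
+        prefix_pad = 300 % 256 = 44 and the LMCache load starts at offset
+        300 - 44 = 256. The first 44 positions of the reloaded span map to
+        slot -1 (already on device), so only num_retrieved - prefix_pad tokens
+        are kept in newly allocated slots and the unused tail of the
+        allocation is freed.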
+ """ + if self.disable or not key: + return super().match_prefix(key, **kwargs) + + if self.page_size != 1: + aligned_len = len(key) // self.page_size * self.page_size + key = key[:aligned_len] + + base_res = super().match_prefix(key, **kwargs) + value: torch.Tensor = base_res.device_indices + last_node: TreeNode = base_res.last_device_node + + if value.numel() == len(key): + return base_res + + uncached_len = len(key) - value.numel() + if uncached_len == 0: + return base_res + + chunk_size = self.lmcache_connector.chunk_size() + prefix_pad = value.numel() % chunk_size + + if self.token_to_kv_pool_allocator.available_size() < uncached_len: + self.evict(uncached_len) + + token_slots = self.token_to_kv_pool_allocator.alloc(uncached_len) + if token_slots is None: + return base_res + + slot_mapping = torch.cat( + [ + torch.full((value.numel(),), -1, dtype=torch.int64, device=self.device), + token_slots.detach().clone().to(torch.int64).to(self.device), + ] + ) + + with torch.cuda.stream(self.load_stream): + num_retrieved = self.lmcache_connector.start_load_kv( + LoadMetadata( + token_ids=key, # full page-aligned key + slot_mapping=slot_mapping, + offset=value.numel() - prefix_pad, # LMCache offset convention + ) + ) + logger.debug("num_retrieved_tokens: %s", num_retrieved) + + if num_retrieved > 0: + self.token_to_kv_pool_allocator.free( + token_slots[(num_retrieved - prefix_pad) :] + ) + else: + self.token_to_kv_pool_allocator.free(token_slots) + + if num_retrieved > 0: + fetched = num_retrieved - prefix_pad + new_node = TreeNode() + start = value.numel() + end = start + fetched + new_node.key = key[start:end] + new_node.value = token_slots[:fetched] + new_node.parent = last_node + last_node.children[self.get_child_key_fn(new_node.key)] = new_node + last_node = new_node + + value = torch.cat([value, token_slots[:fetched]]) + self.evictable_size_ += fetched + + self._record_store_event(new_node.parent) + self._record_store_event(new_node) + + return MatchResult( + device_indices=value, + last_device_node=last_node, + last_host_node=last_node, + ) + + return base_res + + def cache_finished_req(self, req: "Req") -> None: # type: ignore[override] + """On request completion, insert device KV into radix and store to LMCache.""" + + super().cache_finished_req(req) + + token_ids = (req.origin_input_ids + req.output_ids)[:-1] + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : len(token_ids) + ] + + _, new_last_node, _, _ = self.match_prefix(token_ids) + assert new_last_node is not None + + self.inc_lock_ref(new_last_node) + store_md = StoreMetadata( + last_node=new_last_node, + token_ids=token_ids, + kv_indices=kv_indices, + offset=0, + ) + with torch.cuda.stream(self.store_stream): + self.lmcache_connector.store_kv(store_md) + with self._node_lock: + self._in_flight_nodes.append(new_last_node) + + def evict(self, num_tokens: int) -> None: # type: ignore[override] + """Before base eviction, wait for any outstanding stores and release locks.""" + if self.disable: + return + + self.store_stream.synchronize() + with self._node_lock: + for node in self._in_flight_nodes: + self.dec_lock_ref(node) + self._in_flight_nodes.clear() + + super().evict(num_tokens) + + def pretty_print(self): # type: ignore[override] + super().pretty_print() + try: + logger.debug( + "evictable=%d protected=%d", self.evictable_size_, self.protected_size_ + ) + except Exception: # pragma: no cover + pass + + +if __name__ == "__main__": + cache = LMCRadixCache( + req_to_token_pool=None, + 
token_to_kv_pool_allocator=None,
+        page_size=1,
+        disable=False,
+        enable_kv_cache_events=False,
+        model_config=None,
+        tp_size=1,
+        rank=0,
+        tp_group=None,
+    )
+    cache.insert([1, 2, 3], torch.tensor([10, 11, 12], dtype=torch.int64))
+    cache.insert([1, 2, 3, 4], torch.tensor([10, 11, 12, 13], dtype=torch.int64))
+    cache.pretty_print()
diff --git a/python/sglang/srt/mem_cache/storage/lmcache/unit_test.py b/python/sglang/srt/mem_cache/storage/lmcache/unit_test.py
new file mode 100644
index 000000000..68dfe939d
--- /dev/null
+++ b/python/sglang/srt/mem_cache/storage/lmcache/unit_test.py
@@ -0,0 +1,121 @@
+try:
+    from lmcache.integration.sglang.sglang_adapter import (
+        LMCacheLayerwiseConnector,
+        LoadMetadata,
+        StoreMetadata,
+    )
+except ImportError as e:
+    raise RuntimeError(
+        "LMCache is not installed. Please install it by running `pip install lmcache` (or `pip install -e .` from the LMCache repository root)."
+    ) from e
+
+import os
+
+import torch
+
+from sglang.srt.configs.model_config import ModelConfig
+
+os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
+os.environ["LMCACHE_CONFIG_FILE"] = "example_config.yaml"
+
+
+def test_load_store_metadata():
+    model_config = ModelConfig(
+        model_path="Qwen/Qwen3-4B",
+    )
+
+    # Generate Dummy KV Cache
+    head_num = model_config.num_key_value_heads
+    head_dim = model_config.head_dim
+    layer_num = model_config.num_hidden_layers
+    buffer_size = 256
+    input_id_len = 16
+
+    k_buffer = [
+        torch.randn(buffer_size, head_num, head_dim, dtype=torch.bfloat16).cuda()
+        for _ in range(layer_num)
+    ]
+    v_buffer = [
+        torch.randn(buffer_size, head_num, head_dim, dtype=torch.bfloat16).cuda()
+        for _ in range(layer_num)
+    ]
+
+    connector = LMCacheLayerwiseConnector(model_config, 1, 0, k_buffer, v_buffer)
+
+    fake_token_ids = torch.randint(0, model_config.vocab_size, (input_id_len,)).tolist()
+    fake_kv_indices = torch.randint(0, buffer_size, (input_id_len,))
+    offset = 0
+
+    store_metadata = StoreMetadata(
+        last_node=None,
+        token_ids=fake_token_ids,
+        kv_indices=fake_kv_indices,
+        offset=offset,
+    )
+
+    load_metadata = LoadMetadata(
+        token_ids=fake_token_ids,
+        slot_mapping=fake_kv_indices,
+        offset=offset,
+    )
+
+    current_stream = torch.cuda.current_stream()
+
+    retrieve_token_num = connector.start_load_kv(load_metadata)
+    assert retrieve_token_num == 0
+
+    connector.store_kv(store_metadata)
+    current_stream.synchronize()
+
+    # check retrieve
+    gt_key_buffer = [
+        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
+        for _ in range(layer_num)
+    ]
+    gt_value_buffer = [
+        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
+        for _ in range(layer_num)
+    ]
+
+    for i in range(layer_num):
+        gt_key_buffer[i] = k_buffer[i][fake_kv_indices]
+        gt_value_buffer[i] = v_buffer[i][fake_kv_indices]
+
+    # clear the k_buffer and v_buffer
+    for i in range(layer_num):
+        k_buffer[i].zero_()
+        v_buffer[i].zero_()
+
+    retrieve_token_num = connector.start_load_kv(load_metadata)
+    assert retrieve_token_num == input_id_len
+
+    for i in range(layer_num):
+        current_stream.synchronize()
+        connector.load_kv_layerwise(i)
+
+    current_stream.synchronize()
+    test_key_buffer = [
+        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
+        for _ in range(layer_num)
+    ]
+    test_value_buffer = [
+        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
+        for _ in range(layer_num)
+    ]
+
+    for i in range(layer_num):
+        test_key_buffer[i] = k_buffer[i][fake_kv_indices]
+        test_value_buffer[i] = v_buffer[i][fake_kv_indices]
+
+    for i in range(layer_num):
+        assert torch.allclose(test_key_buffer[i], gt_key_buffer[i])
+        assert torch.allclose(test_value_buffer[i], gt_value_buffer[i])
+
+    print("================================================")
+    print("TEST_LOAD_STORE_METADATA PASSED!")
+    print("================================================")
+    connector.close()
+
+
+if __name__ == "__main__":
+    test_load_store_metadata()
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 733c88da8..c7f5a69a1 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -303,6 +303,8 @@ class ServerArgs:
     hicache_storage_backend: Optional[str] = None
     hicache_storage_prefetch_policy: str = "best_effort"
     hicache_storage_backend_extra_config: Optional[str] = None
+    # LMCache
+    enable_lmcache: bool = False
 
     # Double Sparsity
     enable_double_sparsity: bool = False
@@ -1735,6 +1737,12 @@
             default=ServerArgs.hicache_storage_backend_extra_config,
             help="A dictionary in JSON string format containing extra configuration for the storage backend.",
         )
+        # LMCache
+        parser.add_argument(
+            "--enable-lmcache",
+            action="store_true",
+            help="Use LMCache as an alternative hierarchical cache solution.",
+        )
 
         # Double Sparsity
         parser.add_argument(