[Feature] LMCache Connector Integration (#9741)
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
Signed-off-by: YuhanLiu11 <yliu738@wisc.edu>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
@@ -656,6 +656,21 @@ class Scheduler(
                page_size=self.page_size,
                disable=server_args.disable_radix_cache,
            )
        elif server_args.enable_lmcache:
            from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
                LMCRadixCache,
            )

            self.tree_cache = LMCRadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
                disable=server_args.disable_radix_cache,
                model_config=self.model_config,
                tp_size=self.tp_size,
                rank=self.tp_rank,
                tp_group=self.tp_group,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
@@ -1411,9 +1426,11 @@ class Scheduler(
        _, _, available_size, evictable_size = self._get_token_info()
        protected_size = self.tree_cache.protected_size()
        memory_leak = (available_size + evictable_size) != (
            # self.max_total_num_tokens
            # if not self.enable_hierarchical_cache
            # else self.max_total_num_tokens - protected_size
            self.max_total_num_tokens - protected_size
        )
        token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
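With LMCache enabled, finished-request nodes stay lock-protected while their KV is stored asynchronously, so protected tokens are now always excluded from the leak check rather than only in hierarchical-cache mode. A small illustrative sanity check with made-up numbers (not part of the diff):

```python
# Illustrative numbers only: every token slot should be either available,
# evictable, or protected (locked by in-flight work).
max_total_num_tokens = 1000
protected_size = 100   # e.g. nodes locked while LMCache stores are in flight
available_size = 650
evictable_size = 250

memory_leak = (available_size + evictable_size) != (max_total_num_tokens - protected_size)
assert not memory_leak  # 650 + 250 == 1000 - 100
```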
@@ -369,7 +369,6 @@ class MHATokenToKVPool(KVCache):
        # same applies to get_value_buffer and get_kv_buffer
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

        return self._get_key_buffer(layer_id)

    def _get_value_buffer(self, layer_id: int):
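This `layer_transfer_counter` hook is what the LMCache integration attaches to: the KV pool calls `wait_until(layer_id)` before handing out a layer's buffer, and the counter registered by `LMCRadixCache` (see lmc_radix_cache.py below) turns that into a per-layer load on the connector. A stripped-down, CPU-only sketch of that call pattern; the `_Mock*` names are illustrative and not part of the diff:

```python
# Illustrative only: mirrors the pool -> counter -> connector call chain.
# The real LayerTransferCounter synchronizes a CUDA load stream instead of
# calling the connector directly on the current thread.
class _MockConnector:
    def load_kv_layerwise(self, layer_id: int) -> None:
        print(f"load KV for layer {layer_id}")


class _MockLayerTransferCounter:
    def __init__(self, connector: _MockConnector):
        self.connector = connector

    def wait_until(self, layer_id: int) -> None:
        # The real adapter first waits on the load stream, then issues the load.
        self.connector.load_kv_layerwise(layer_id)


counter = _MockLayerTransferCounter(_MockConnector())
for layer_id in range(4):  # the KV pool calls this once per layer it serves
    counter.wait_until(layer_id)
```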
python/sglang/srt/mem_cache/storage/lmcache/README.md (new file, 43 lines)
@@ -0,0 +1,43 @@
# LMCache Connector for SGLang

This document describes how to use LMCache as the KV cache management backend for the SGLang engine.
For more details about LMCache, please refer to: https://lmcache.ai

## Install LMCache

### Method 1: with pip

```bash
pip install lmcache
```

### Method 2: from source

Clone the LMCache project:

```bash
git clone https://github.com/LMCache/LMCache
```

Install:

```bash
cd LMCache
pip install -e . --no-build-isolation
```

## Use LMCache

First, set up the LMCache config. An example config is provided at `example_config.yaml`. For more settings, please refer to https://docs.lmcache.ai/api_reference/configurations.html.

Second, launch the SGLang serving engine with LMCache enabled:

```bash
export LMCACHE_USE_EXPERIMENTAL=True
export LMCACHE_CONFIG_FILE=example_config.yaml

python -m sglang.launch_server \
    --model-path MODEL \
    --enable-lmcache
```
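Once the server from the usage section above is running, requests are served as usual. As a quick smoke test, a minimal client call might look like the sketch below; it assumes the default port 30000 and SGLang's native `/generate` route, so adjust it to your deployment:

```python
import requests

# Assumes the launch command from the README above, listening on localhost:30000.
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0},
    },
)
print(resp.json())
```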
example_config.yaml (new file, 7 lines)
@@ -0,0 +1,7 @@
# Basic configurations
chunk_size: 256

# CPU offloading configurations
local_cpu: true
use_layerwise: true
max_local_cpu_size: 10  # size of the CPU backend in GB
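This config file is picked up through environment variables, the same mechanism the unit test below uses. The Python-side equivalent of the README's bash exports is simply:

```python
import os

# Must be set before LMCache is imported/initialized.
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
os.environ["LMCACHE_CONFIG_FILE"] = "example_config.yaml"
```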
python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py (new file, 280 lines)
@@ -0,0 +1,280 @@
from __future__ import annotations

import logging
import threading
from typing import TYPE_CHECKING, List, Optional

import torch

from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

try:
    from lmcache.integration.sglang.sglang_adapter import (
        LMCacheLayerwiseConnector,
        LoadMetadata,
        StoreMetadata,
    )
except ImportError as e:
    raise RuntimeError(
        "LMCache is not installed. Please install it by running `pip install lmcache`"
    ) from e

if TYPE_CHECKING:
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.managers.schedule_batch import Req

logger = logging.getLogger(__name__)


class LayerTransferCounter:
    """Minimal adapter that lets the memory pool notify LMCache per layer.

    The KV pool calls `wait_until(layer_id)` after finishing a layer, which we
    translate into a `load_kv_layerwise(layer_id)` call on the LMCache connector
    within the provided CUDA stream.
    """

    def __init__(
        self,
        num_layers: int,
        load_stream: torch.cuda.Stream,
        lmc_connector: LMCacheLayerwiseConnector,
        printable: bool = False,
    ):
        self.num_layers = num_layers
        self.load_stream = load_stream
        self.lmc_connector = lmc_connector

    def wait_until(self, layer_id: int):
        # Ensure ordering of the async loads w.r.t. the compute stream(s).
        self.load_stream.synchronize()
        with self.load_stream:
            self.lmc_connector.load_kv_layerwise(layer_id)


class LMCRadixCache(RadixCache):
    """RadixCache + LMCache IO.

    This subclass adds:
    - LMCache connector setup (device/host buffers, TP rank/size)
    - Two CUDA streams for async load/store
    - Layer-wise transfer executor wiring to the KV cache
    - Overridden `match_prefix` to fetch missing prefix chunks from LMCache
    - Extended cache-finalization paths to store back into LMCache
    - Eviction barrier that respects any in-flight host->device stores
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        page_size: int,
        disable: bool = False,
        enable_kv_cache_events: bool = False,
        model_config: Optional["ModelConfig"] = None,
        tp_size: int = 1,
        rank: int = 0,
        tp_group: Optional[torch.distributed.ProcessGroup] = None,
    ):
        super().__init__(
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            page_size=page_size,
            disable=disable,
            enable_kv_cache_events=enable_kv_cache_events,
        )

        kvcache = self.token_to_kv_pool_allocator.get_kvcache()
        self.lmcache_connector = LMCacheLayerwiseConnector(
            sgl_config=model_config,
            tp_size=tp_size,
            rank=rank,
            # NOTE: The original implementation accessed private buffers via
            # `_kvcache.k_buffer` / `.v_buffer`. We prefer public accessors when
            # available and fall back to the private fields if needed.
            k_pool=getattr(
                kvcache,
                "k_buffer",
                getattr(self.token_to_kv_pool_allocator._kvcache, "k_buffer"),
            ),
            v_pool=getattr(
                kvcache,
                "v_buffer",
                getattr(self.token_to_kv_pool_allocator._kvcache, "v_buffer"),
            ),
            tp_group=tp_group,
        )

        self.load_stream = torch.cuda.Stream()
        self.store_stream = torch.cuda.Stream()

        self.layer_done_executor = LayerTransferCounter(
            num_layers=(
                model_config.num_hidden_layers if model_config is not None else 0
            ),
            load_stream=self.load_stream,
            lmc_connector=self.lmcache_connector,
        )
        kvcache.register_layer_transfer_counter(self.layer_done_executor)

        self._in_flight_nodes: list[TreeNode] = []
        self._node_lock = threading.Lock()

    def reset(self):  # type: ignore[override]
        super().reset()
        if hasattr(self, "_in_flight_nodes"):
            with self._node_lock:
                self._in_flight_nodes.clear()

    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:  # type: ignore[override]
        """Match cached prefix; if there is a tail miss, prefetch from LMCache.

        Reuses the base matching logic to obtain (value, last_node). If there
        remains a *page-aligned* uncached suffix and there is room (or after
        eviction), we allocate token slots and trigger an async LMCache load
        into those slots, then materialize a new child node for the retrieved
        chunk.
        """
        if self.disable or not key:
            return super().match_prefix(key, **kwargs)

        if self.page_size != 1:
            aligned_len = len(key) // self.page_size * self.page_size
            key = key[:aligned_len]

        base_res = super().match_prefix(key, **kwargs)
        value: torch.Tensor = base_res.device_indices
        last_node: TreeNode = base_res.last_device_node

        if value.numel() == len(key):
            return base_res

        uncached_len = len(key) - value.numel()
        if uncached_len == 0:
            return base_res

        chunk_size = self.lmcache_connector.chunk_size()
        prefix_pad = value.numel() % chunk_size

        if self.token_to_kv_pool_allocator.available_size() < uncached_len:
            self.evict(uncached_len)

        token_slots = self.token_to_kv_pool_allocator.alloc(uncached_len)
        if token_slots is None:
            return base_res

        slot_mapping = torch.cat(
            [
                torch.full((value.numel(),), -1, dtype=torch.int64, device=self.device),
                token_slots.detach().clone().to(torch.int64).to(self.device),
            ]
        )

        with torch.cuda.stream(self.load_stream):
            num_retrieved = self.lmcache_connector.start_load_kv(
                LoadMetadata(
                    token_ids=key,  # full page-aligned key
                    slot_mapping=slot_mapping,
                    offset=value.numel() - prefix_pad,  # LMCache offset convention
                )
            )
        logger.debug("num_retrieved_tokens: %s", num_retrieved)

        if num_retrieved > 0:
            self.token_to_kv_pool_allocator.free(
                token_slots[(num_retrieved - prefix_pad) :]
            )
        else:
            self.token_to_kv_pool_allocator.free(token_slots)

        if num_retrieved > 0:
            fetched = num_retrieved - prefix_pad
            new_node = TreeNode()
            start = value.numel()
            end = start + fetched
            new_node.key = key[start:end]
            new_node.value = token_slots[:fetched]
            new_node.parent = last_node
            last_node.children[self.get_child_key_fn(new_node.key)] = new_node
            last_node = new_node

            value = torch.cat([value, token_slots[:fetched]])
            self.evictable_size_ += fetched

            self._record_store_event(new_node.parent)
            self._record_store_event(new_node)

            return MatchResult(
                device_indices=value,
                last_device_node=last_node,
                last_host_node=last_node,
            )

        return base_res

    def cache_finished_req(self, req: "Req") -> None:  # type: ignore[override]
        """On request completion, insert device KV into radix and store to LMCache."""
        super().cache_finished_req(req)

        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]

        _, new_last_node, _, _ = self.match_prefix(token_ids)
        assert new_last_node is not None

        self.inc_lock_ref(new_last_node)
        store_md = StoreMetadata(
            last_node=new_last_node,
            token_ids=token_ids,
            kv_indices=kv_indices,
            offset=0,
        )
        with torch.cuda.stream(self.store_stream):
            self.lmcache_connector.store_kv(store_md)
        with self._node_lock:
            self._in_flight_nodes.append(new_last_node)

    def evict(self, num_tokens: int) -> None:  # type: ignore[override]
        """Before base eviction, wait for any outstanding stores and release locks."""
        if self.disable:
            return

        self.store_stream.synchronize()
        with self._node_lock:
            for node in self._in_flight_nodes:
                self.dec_lock_ref(node)
            self._in_flight_nodes.clear()

        super().evict(num_tokens)

    def pretty_print(self):  # type: ignore[override]
        super().pretty_print()
        try:
            logger.debug(
                "evictable=%d protected=%d", self.evictable_size_, self.protected_size_
            )
        except Exception:  # pragma: no cover
            pass


if __name__ == "__main__":
    cache = LMCRadixCache(
        req_to_token_pool=None,
        token_to_kv_pool_allocator=None,
        page_size=1,
        disable=False,
        enable_kv_cache_events=False,
        model_config=None,
        tp_size=1,
        rank=0,
        tp_group=None,
    )
    cache.insert([1, 2, 3], torch.tensor([10, 11, 12], dtype=torch.int64))
    cache.insert([1, 2, 3, 4], torch.tensor([10, 11, 12, 13], dtype=torch.int64))
    cache.pretty_print()
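To make the chunk-alignment arithmetic in `match_prefix` concrete: LMCache stores KV in fixed-size chunks, so the load starts at the last chunk boundary at or below the cached prefix, and any allocated slots that were not filled are freed again. A small standalone illustration with made-up sizes (not part of the new file):

```python
# Illustrative arithmetic only; mirrors the offset/padding logic in match_prefix.
chunk_size = 256          # from the connector / example_config.yaml
key_len = 900             # page-aligned request prefix length
cached_len = 300          # tokens already present in the device radix tree

uncached_len = key_len - cached_len    # 600 slots to allocate
prefix_pad = cached_len % chunk_size   # 300 % 256 = 44
offset = cached_len - prefix_pad       # load starts at 256, a chunk boundary

# Suppose LMCache reports 556 retrieved tokens counted from `offset`:
num_retrieved = 556
fetched = num_retrieved - prefix_pad   # 512 newly usable device tokens
freed_tail = uncached_len - fetched    # 88 allocated slots are freed again
print(offset, fetched, freed_tail)     # 256 512 88
```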
python/sglang/srt/mem_cache/storage/lmcache/unit_test.py (new file, 121 lines)
@@ -0,0 +1,121 @@
try:
    from lmcache.integration.sglang.sglang_adapter import (
        LMCacheLayerwiseConnector,
        LoadMetadata,
        StoreMetadata,
    )
except ImportError:
    raise RuntimeError(
        "LMCache is not installed. Please install it by running `pip install lmcache` in the root directory of LMCache"
    )

import os

import torch

from sglang.srt.configs.model_config import ModelConfig

os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
os.environ["LMCACHE_CONFIG_FILE"] = "example_config.yaml"


def test_load_store_metadata():
    model_config = ModelConfig(
        model_path="Qwen/Qwen3-4B",
    )

    # Generate a dummy KV cache.
    head_num = model_config.num_key_value_heads
    head_dim = model_config.head_dim
    layer_num = model_config.num_hidden_layers
    buffer_size = 256
    input_id_len = 16

    k_buffer = [
        torch.randn(buffer_size, head_num, head_dim, dtype=torch.bfloat16).cuda()
        for _ in range(layer_num)
    ]
    v_buffer = [
        torch.randn(buffer_size, head_num, head_dim, dtype=torch.bfloat16).cuda()
        for _ in range(layer_num)
    ]

    connector = LMCacheLayerwiseConnector(model_config, 1, 0, k_buffer, v_buffer)

    fake_token_ids = torch.randint(0, model_config.vocab_size, (input_id_len,)).tolist()
    fake_kv_indices = torch.randint(0, buffer_size, (input_id_len,))
    offset = 0

    store_metadata = StoreMetadata(
        last_node=None,
        token_ids=fake_token_ids,
        kv_indices=fake_kv_indices,
        offset=offset,
    )

    load_metadata = LoadMetadata(
        token_ids=fake_token_ids,
        slot_mapping=fake_kv_indices,
        offset=offset,
    )

    current_stream = torch.cuda.current_stream()

    retrieve_token_num = connector.start_load_kv(load_metadata)
    assert retrieve_token_num == 0

    connector.store_kv(store_metadata)
    current_stream.synchronize()

    # Check retrieval against the ground-truth slices captured before clearing.
    gt_key_buffer = [
        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
        for _ in range(layer_num)
    ]
    gt_value_buffer = [
        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
        for _ in range(layer_num)
    ]

    for i in range(layer_num):
        gt_key_buffer[i] = k_buffer[i][fake_kv_indices]
        gt_value_buffer[i] = v_buffer[i][fake_kv_indices]

    # Clear the k_buffer and v_buffer.
    for i in range(layer_num):
        k_buffer[i].zero_()
        v_buffer[i].zero_()

    retrieve_token_num = connector.start_load_kv(load_metadata)
    assert retrieve_token_num == input_id_len

    for i in range(layer_num):
        current_stream.synchronize()
        connector.load_kv_layerwise(i)

    current_stream.synchronize()
    test_key_buffer = [
        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
        for _ in range(layer_num)
    ]
    test_value_buffer = [
        torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
        for _ in range(layer_num)
    ]

    for i in range(layer_num):
        test_key_buffer[i] = k_buffer[i][fake_kv_indices]
        test_value_buffer[i] = v_buffer[i][fake_kv_indices]

    for i in range(layer_num):
        assert torch.allclose(test_key_buffer[i], gt_key_buffer[i])
        assert torch.allclose(test_value_buffer[i], gt_value_buffer[i])

    print("================================================")
    print("TEST_LOAD_STORE_METADATA PASSED!")
    print("================================================")
    connector.close()


if __name__ == "__main__":
    test_load_store_metadata()
@@ -303,6 +303,8 @@ class ServerArgs:
    hicache_storage_backend: Optional[str] = None
    hicache_storage_prefetch_policy: str = "best_effort"
    hicache_storage_backend_extra_config: Optional[str] = None
    # LMCache
    enable_lmcache: bool = False

    # Double Sparsity
    enable_double_sparsity: bool = False
@@ -1735,6 +1737,12 @@ class ServerArgs:
            default=ServerArgs.hicache_storage_backend_extra_config,
            help="A dictionary in JSON string format containing extra configuration for the storage backend.",
        )
        # LMCache
        parser.add_argument(
            "--enable-lmcache",
            action="store_true",
            help="Use LMCache as an alternative hierarchical cache solution",
        )

        # Double Sparsity
        parser.add_argument(