Hierarchical Caching supports MLA (#4009)
Signed-off-by: Changqi Lu <luchangqi.123@bytedance.com> Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -22,10 +22,7 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.mem_cache.memory_pool import (
|
from sglang.srt.mem_cache.memory_pool import HostKVCache, TokenToKVPoolAllocator
|
||||||
MHATokenToKVPoolHost,
|
|
||||||
TokenToKVPoolAllocator,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -151,7 +148,7 @@ class HiCacheController:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||||
mem_pool_host: MHATokenToKVPoolHost,
|
mem_pool_host: HostKVCache,
|
||||||
load_cache_event: threading.Event = None,
|
load_cache_event: threading.Event = None,
|
||||||
write_policy: str = "write_through_selective",
|
write_policy: str = "write_through_selective",
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -8,7 +8,10 @@ import torch
|
|||||||
|
|
||||||
from sglang.srt.managers.cache_controller import HiCacheController
|
from sglang.srt.managers.cache_controller import HiCacheController
|
||||||
from sglang.srt.mem_cache.memory_pool import (
|
from sglang.srt.mem_cache.memory_pool import (
|
||||||
|
MHATokenToKVPool,
|
||||||
MHATokenToKVPoolHost,
|
MHATokenToKVPoolHost,
|
||||||
|
MLATokenToKVPool,
|
||||||
|
MLATokenToKVPoolHost,
|
||||||
ReqToTokenPool,
|
ReqToTokenPool,
|
||||||
TokenToKVPoolAllocator,
|
TokenToKVPoolAllocator,
|
||||||
)
|
)
|
||||||
@@ -31,9 +34,14 @@ class HiRadixCache(RadixCache):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Page size larger than 1 is not yet supported in HiRadixCache."
|
"Page size larger than 1 is not yet supported in HiRadixCache."
|
||||||
)
|
)
|
||||||
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
|
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
||||||
token_to_kv_pool_allocator.get_kvcache()
|
if isinstance(self.kv_cache, MHATokenToKVPool):
|
||||||
)
|
self.token_to_kv_pool_host = MHATokenToKVPoolHost(self.kv_cache)
|
||||||
|
elif isinstance(self.kv_cache, MLATokenToKVPool):
|
||||||
|
self.token_to_kv_pool_host = MLATokenToKVPoolHost(self.kv_cache)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
|
||||||
|
|
||||||
self.tp_group = tp_cache_group
|
self.tp_group = tp_cache_group
|
||||||
self.page_size = page_size
|
self.page_size = page_size
|
||||||
|
|
||||||
@@ -317,13 +325,11 @@ class HiRadixCache(RadixCache):
|
|||||||
prefix_len = _key_match(child.key, key)
|
prefix_len = _key_match(child.key, key)
|
||||||
if prefix_len < len(child.key):
|
if prefix_len < len(child.key):
|
||||||
new_node = self._split_node(child.key, child, prefix_len)
|
new_node = self._split_node(child.key, child, prefix_len)
|
||||||
self.inc_hit_count(new_node)
|
|
||||||
if not new_node.evicted:
|
if not new_node.evicted:
|
||||||
value.append(new_node.value)
|
value.append(new_node.value)
|
||||||
node = new_node
|
node = new_node
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
self.inc_hit_count(child)
|
|
||||||
if not child.evicted:
|
if not child.evicted:
|
||||||
value.append(child.value)
|
value.append(child.value)
|
||||||
node = child
|
node = child
|
||||||
|
|||||||
@@ -115,6 +115,21 @@ class KVCache(abc.ABC):
|
|||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
def get_flat_data(self, indices):
    """Gather the cache entries stored at ``indices``.

    Abstract hook: every concrete KV pool supplies its own implementation
    and decides the layout of the returned data.

    Raises:
        NotImplementedError: always, when the base implementation is reached.
    """
    raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
def transfer(self, indices, flat_data):
    """Store ``flat_data`` into the cache slots selected by ``indices``.

    Abstract hook: every concrete KV pool supplies its own implementation.

    Raises:
        NotImplementedError: always, when the base implementation is reached.
    """
    raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
def transfer_per_layer(self, indices, flat_data, layer_id):
    """Store ``flat_data`` into layer ``layer_id`` at the slots ``indices``.

    Abstract hook: every concrete KV pool supplies its own implementation.

    Raises:
        NotImplementedError: always, when the base implementation is reached.
    """
    raise NotImplementedError
|
||||||
|
|
||||||
|
def register_layer_transfer_counter(self, layer_transfer_counter):
    """Attach ``layer_transfer_counter`` to this pool.

    The object is stored as ``self.layer_transfer_counter``; nothing else is
    touched and nothing is returned.
    """
    self.layer_transfer_counter = layer_transfer_counter
|
||||||
|
|
||||||
|
|
||||||
class TokenToKVPoolAllocator:
|
class TokenToKVPoolAllocator:
|
||||||
"""An allocator managing the indices to kv cache data."""
|
"""An allocator managing the indices to kv cache data."""
|
||||||
@@ -275,9 +290,6 @@ class MHATokenToKVPool(KVCache):
|
|||||||
self.k_buffer[i][indices] = k_data[i]
|
self.k_buffer[i][indices] = k_data[i]
|
||||||
self.v_buffer[i][indices] = v_data[i]
|
self.v_buffer[i][indices] = v_data[i]
|
||||||
|
|
||||||
def register_layer_transfer_counter(self, layer_transfer_counter):
|
|
||||||
self.layer_transfer_counter = layer_transfer_counter
|
|
||||||
|
|
||||||
def transfer_per_layer(self, indices, flat_data, layer_id):
|
def transfer_per_layer(self, indices, flat_data, layer_id):
|
||||||
# transfer prepared data from host to device
|
# transfer prepared data from host to device
|
||||||
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
||||||
@@ -388,6 +400,8 @@ class MLATokenToKVPool(KVCache):
|
|||||||
else:
|
else:
|
||||||
self.store_dtype = dtype
|
self.store_dtype = dtype
|
||||||
self.kv_lora_rank = kv_lora_rank
|
self.kv_lora_rank = kv_lora_rank
|
||||||
|
self.qk_rope_head_dim = qk_rope_head_dim
|
||||||
|
self.layer_num = layer_num
|
||||||
|
|
||||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||||
enable=enable_memory_saver
|
enable=enable_memory_saver
|
||||||
@@ -404,12 +418,20 @@ class MLATokenToKVPool(KVCache):
|
|||||||
for _ in range(layer_num)
|
for _ in range(layer_num)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
self.layer_transfer_counter = None
|
||||||
|
|
||||||
def get_key_buffer(self, layer_id: int):
    """Return the full KV buffer of layer ``layer_id``.

    When a layer transfer counter has been registered, ``wait_until`` is
    called on it first (presumably blocking until that layer's data has
    been loaded -- confirm against the counter implementation).

    The buffer is reinterpreted as ``self.dtype`` when it is stored under a
    different ``self.store_dtype``; otherwise it is returned as-is.
    """
    counter = self.layer_transfer_counter
    if counter is not None:
        counter.wait_until(layer_id)

    layer_buffer = self.kv_buffer[layer_id]
    if self.store_dtype == self.dtype:
        return layer_buffer
    return layer_buffer.view(self.dtype)
|
||||||
|
|
||||||
def get_value_buffer(self, layer_id: int):
    """Return the leading ``kv_lora_rank`` features of layer ``layer_id``.

    When a layer transfer counter has been registered, ``wait_until`` is
    called on it first (presumably blocking until that layer's data has
    been loaded -- confirm against the counter implementation).

    The slice covers the first ``self.kv_lora_rank`` entries of the last
    dimension and is reinterpreted as ``self.dtype`` when the pool stores
    data under a different ``self.store_dtype``.
    """
    counter = self.layer_transfer_counter
    if counter is not None:
        counter.wait_until(layer_id)

    latent = self.kv_buffer[layer_id][..., : self.kv_lora_rank]
    if self.store_dtype == self.dtype:
        return latent
    return latent.view(self.dtype)
|
||||||
@@ -432,6 +454,22 @@ class MLATokenToKVPool(KVCache):
|
|||||||
else:
|
else:
|
||||||
self.kv_buffer[layer_id][loc] = cache_k
|
self.kv_buffer[layer_id][loc] = cache_k
|
||||||
|
|
||||||
|
def get_flat_data(self, indices):
    """Collect the entries at ``indices`` from every layer into one tensor.

    The per-layer selections are stacked along a new leading axis, so the
    result has ``self.layer_num`` entries on dimension 0. Producing one
    large chunk keeps the subsequent transfer efficient.
    """
    per_layer = [
        self.kv_buffer[layer_id][indices] for layer_id in range(self.layer_num)
    ]
    return torch.stack(per_layer)
|
||||||
|
|
||||||
|
@debug_timing
def transfer(self, indices, flat_data):
    """Write ``flat_data`` into every layer's buffer at ``indices``.

    ``flat_data`` is first moved to ``self.device`` with a blocking copy;
    its leading axis is indexed by layer, and slice ``i`` is written into
    ``self.kv_buffer[i][indices]``.
    """
    # Move the prepared data from host to device in one blocking copy.
    device_data = flat_data.to(device=self.device, non_blocking=False)
    for layer_id in range(self.layer_num):
        self.kv_buffer[layer_id][indices] = device_data[layer_id]
|
||||||
|
|
||||||
|
def transfer_per_layer(self, indices, flat_data, layer_id):
    """Write ``flat_data`` into layer ``layer_id`` at the slots ``indices``.

    ``flat_data`` is moved to ``self.device`` with a blocking copy before
    being assigned into that layer's buffer.
    """
    # Move the prepared data from host to device.
    device_data = flat_data.to(device=self.device, non_blocking=False)
    target_layer = self.kv_buffer[layer_id]
    target_layer[indices] = device_data
|
||||||
|
|
||||||
|
|
||||||
class DoubleSparseTokenToKVPool(KVCache):
|
class DoubleSparseTokenToKVPool(KVCache):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -508,6 +546,15 @@ class DoubleSparseTokenToKVPool(KVCache):
|
|||||||
self.v_buffer[layer_id][loc] = cache_v
|
self.v_buffer[layer_id][loc] = cache_v
|
||||||
self.label_buffer[layer_id][loc] = cache_label
|
self.label_buffer[layer_id][loc] = cache_label
|
||||||
|
|
||||||
|
def get_flat_data(self, indices):
    """Reject host-offload reads, which this pool does not implement.

    The previous placeholder body silently returned ``None``, which a
    caller would only notice later as an unrelated failure. Failing here
    makes the unsupported path explicit.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "DoubleSparseTokenToKVPool does not support get_flat_data."
    )
|
||||||
|
|
||||||
|
def transfer(self, indices, flat_data):
    """Reject host-to-device transfers, which this pool does not implement.

    The previous placeholder body silently discarded ``flat_data``; raising
    prevents a caller from believing the data was written.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "DoubleSparseTokenToKVPool does not support transfer."
    )
|
||||||
|
|
||||||
|
def transfer_per_layer(self, indices, flat_data, layer_id):
    """Reject per-layer transfers, which this pool does not implement.

    The previous placeholder body silently discarded ``flat_data``; raising
    prevents a caller from believing the layer was populated.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "DoubleSparseTokenToKVPool does not support transfer_per_layer."
    )
|
||||||
|
|
||||||
|
|
||||||
class MemoryStateInt(IntEnum):
|
class MemoryStateInt(IntEnum):
|
||||||
IDLE = 0
|
IDLE = 0
|
||||||
@@ -526,7 +573,7 @@ def synchronized(func):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class MHATokenToKVPoolHost:
|
class HostKVCache(abc.ABC):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -547,12 +594,7 @@ class MHATokenToKVPoolHost:
|
|||||||
|
|
||||||
self.size = int(device_pool.size * host_to_device_ratio)
|
self.size = int(device_pool.size * host_to_device_ratio)
|
||||||
self.dtype = device_pool.store_dtype
|
self.dtype = device_pool.store_dtype
|
||||||
self.head_num = device_pool.head_num
|
self.size_per_token = self.get_size_per_token()
|
||||||
self.head_dim = device_pool.head_dim
|
|
||||||
self.layer_num = device_pool.layer_num
|
|
||||||
self.size_per_token = (
|
|
||||||
self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify there is enough available host memory.
|
# Verify there is enough available host memory.
|
||||||
host_mem = psutil.virtual_memory()
|
host_mem = psutil.virtual_memory()
|
||||||
@@ -571,12 +613,7 @@ class MHATokenToKVPoolHost:
|
|||||||
f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
|
f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.kv_buffer = torch.zeros(
|
self.kv_buffer = self.init_kv_buffer()
|
||||||
(2, self.layer_num, self.size, self.head_num, self.head_dim),
|
|
||||||
dtype=self.dtype,
|
|
||||||
device=self.device,
|
|
||||||
pin_memory=self.pin_memory,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize memory states and tracking structures.
|
# Initialize memory states and tracking structures.
|
||||||
self.mem_state = torch.zeros(
|
self.mem_state = torch.zeros(
|
||||||
@@ -588,21 +625,29 @@ class MHATokenToKVPoolHost:
|
|||||||
# A lock for synchronized operations on memory allocation and state transitions.
|
# A lock for synchronized operations on memory allocation and state transitions.
|
||||||
self.lock = threading.RLock()
|
self.lock = threading.RLock()
|
||||||
|
|
||||||
def get_flat_data(self, indices):
|
@abc.abstractmethod
|
||||||
return self.kv_buffer[:, :, indices]
|
def get_size_per_token(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_flat_data_by_layer(self, indices, layer_id):
|
@abc.abstractmethod
|
||||||
return self.kv_buffer[:, layer_id, indices]
|
def init_kv_buffer(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def assign_flat_data(self, indices, flat_data):
|
@abc.abstractmethod
|
||||||
self.kv_buffer[:, :, indices] = flat_data
|
|
||||||
|
|
||||||
@debug_timing
|
|
||||||
def transfer(self, indices, flat_data):
|
def transfer(self, indices, flat_data):
|
||||||
# backup prepared data from device to host
|
raise NotImplementedError()
|
||||||
self.kv_buffer[:, :, indices] = flat_data.to(
|
|
||||||
device=self.device, non_blocking=False
|
@abc.abstractmethod
|
||||||
)
|
def get_flat_data(self, indices):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_flat_data_by_layer(self, indices, layer_id):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def assign_flat_data(self, indices, flat_data):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
@synchronized
|
@synchronized
|
||||||
def clear(self):
|
def clear(self):
|
||||||
@@ -694,3 +739,92 @@ class MHATokenToKVPoolHost:
|
|||||||
self.free_slots = torch.concat([self.free_slots, indices])
|
self.free_slots = torch.concat([self.free_slots, indices])
|
||||||
self.can_use_mem_size += len(indices)
|
self.can_use_mem_size += len(indices)
|
||||||
return len(indices)
|
return len(indices)
|
||||||
|
|
||||||
|
|
||||||
|
class MHATokenToKVPoolHost(HostKVCache):
    """Host-side KV cache backing an ``MHATokenToKVPool``.

    The data lives in a single tensor of shape
    ``(2, layer_num, size, head_num, head_dim)``; the leading axis of size 2
    presumably separates keys from values (confirm against the device pool).
    """

    def __init__(
        self,
        device_pool: MHATokenToKVPool,
        host_to_device_ratio: float = 3.0,
        pin_memory: bool = False,  # no need to use pin memory with the double buffering
        device: str = "cpu",
    ):
        """Forward all arguments unchanged to ``HostKVCache.__init__``."""
        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)

    def get_size_per_token(self):
        """Return the bytes one token occupies across all layers.

        Side effect: copies ``head_num``, ``head_dim`` and ``layer_num`` from
        the device pool onto this object; ``init_kv_buffer`` reads them.
        """
        pool = self.device_pool
        self.head_num = pool.head_num
        self.head_dim = pool.head_dim
        self.layer_num = pool.layer_num

        elems_per_layer = self.head_dim * self.head_num
        return elems_per_layer * self.layer_num * self.dtype.itemsize * 2

    def init_kv_buffer(self):
        """Allocate the (uninitialised) host buffer."""
        shape = (2, self.layer_num, self.size, self.head_num, self.head_dim)
        return torch.empty(
            shape,
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )

    @debug_timing
    def transfer(self, indices, flat_data):
        """Back up ``flat_data`` from the device into the slots ``indices``."""
        # Blocking copy onto the host device before the indexed write.
        host_data = flat_data.to(device=self.device, non_blocking=False)
        self.kv_buffer[:, :, indices] = host_data

    def get_flat_data(self, indices):
        """Return the entries at ``indices`` for every layer."""
        return self.kv_buffer[:, :, indices]

    def get_flat_data_by_layer(self, indices, layer_id):
        """Return the entries at ``indices`` for layer ``layer_id`` only."""
        return self.kv_buffer[:, layer_id, indices]

    def assign_flat_data(self, indices, flat_data):
        """Write ``flat_data`` into the slots ``indices`` of every layer."""
        self.kv_buffer[:, :, indices] = flat_data
|
||||||
|
|
||||||
|
|
||||||
|
class MLATokenToKVPoolHost(HostKVCache):
    """Host-side KV cache backing an ``MLATokenToKVPool``.

    The data lives in a single tensor of shape
    ``(layer_num, size, 1, kv_lora_rank + qk_rope_head_dim)``.
    """

    def __init__(
        self,
        device_pool: MLATokenToKVPool,
        host_to_device_ratio: float = 4.0,
        pin_memory: bool = False,  # no need to use pin memory with the double buffering
        device: str = "cpu",
    ):
        """Forward all arguments unchanged to ``HostKVCache.__init__``."""
        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)

    def get_size_per_token(self):
        """Return the bytes one token occupies across all layers.

        Side effect: copies ``kv_lora_rank``, ``qk_rope_head_dim`` and
        ``layer_num`` from the device pool onto this object;
        ``init_kv_buffer`` reads them.
        """
        self.kv_lora_rank = self.device_pool.kv_lora_rank
        self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
        self.layer_num = self.device_pool.layer_num

        token_dim = self.kv_lora_rank + self.qk_rope_head_dim
        # The buffer has one (1, token_dim) row per token *per layer*, so the
        # layer count must be part of the per-token footprint. Without it the
        # reported size is too small by a factor of ``layer_num``.
        return token_dim * 1 * self.dtype.itemsize * self.layer_num

    def init_kv_buffer(self):
        """Allocate the (uninitialised) host buffer."""
        return torch.empty(
            (
                self.layer_num,
                self.size,
                1,
                self.kv_lora_rank + self.qk_rope_head_dim,
            ),
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )

    @debug_timing
    def transfer(self, indices, flat_data):
        """Back up ``flat_data`` from the device into the slots ``indices``."""
        # backup prepared data from device to host
        self.kv_buffer[:, indices] = flat_data.to(
            device=self.device, non_blocking=False
        )

    def get_flat_data(self, indices):
        """Return the entries at ``indices`` for every layer."""
        return self.kv_buffer[:, indices]

    def get_flat_data_by_layer(self, indices, layer_id):
        """Return the entries at ``indices`` for layer ``layer_id`` only."""
        return self.kv_buffer[layer_id, indices]

    def assign_flat_data(self, indices, flat_data):
        """Write ``flat_data`` into the slots ``indices`` of every layer."""
        self.kv_buffer[:, indices] = flat_data
|
||||||
|
|||||||
56
test/srt/test_hierarchical_mla.py
Normal file
56
test/srt/test_hierarchical_mla.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.run_eval import run_eval
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestHierarchicalMLA(unittest.TestCase):
    """Accuracy checks against a server launched with the hierarchical cache.

    One server process using the default MLA test model is started for the
    whole class and killed afterwards.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        launch_flags = ["--trust-remote-code", "--enable-hierarchical-cache"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def _score(self, eval_name, num_examples, num_threads):
        """Run one evaluation against the shared server and return its score."""
        eval_args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name=eval_name,
            num_examples=num_examples,
            num_threads=num_threads,
        )
        return run_eval(eval_args)["score"]

    def test_mmlu(self):
        score = self._score("mmlu", num_examples=64, num_threads=32)
        self.assertGreater(score, 0.5)

    def test_mgsm_en(self):
        score = self._score("mgsm_en", num_examples=None, num_threads=1024)
        self.assertGreater(score, 0.8)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user