Hierarchical Caching supports MLA (#4009)
Signed-off-by: Changqi Lu <luchangqi.123@bytedance.com> Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -22,10 +22,7 @@ from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPoolHost,
|
||||
TokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool import HostKVCache, TokenToKVPoolAllocator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -151,7 +148,7 @@ class HiCacheController:
|
||||
def __init__(
|
||||
self,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
mem_pool_host: MHATokenToKVPoolHost,
|
||||
mem_pool_host: HostKVCache,
|
||||
load_cache_event: threading.Event = None,
|
||||
write_policy: str = "write_through_selective",
|
||||
):
|
||||
|
||||
@@ -8,7 +8,10 @@ import torch
|
||||
|
||||
from sglang.srt.managers.cache_controller import HiCacheController
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPool,
|
||||
MHATokenToKVPoolHost,
|
||||
MLATokenToKVPool,
|
||||
MLATokenToKVPoolHost,
|
||||
ReqToTokenPool,
|
||||
TokenToKVPoolAllocator,
|
||||
)
|
||||
@@ -31,9 +34,14 @@ class HiRadixCache(RadixCache):
|
||||
raise ValueError(
|
||||
"Page size larger than 1 is not yet supported in HiRadixCache."
|
||||
)
|
||||
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
|
||||
token_to_kv_pool_allocator.get_kvcache()
|
||||
)
|
||||
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
||||
if isinstance(self.kv_cache, MHATokenToKVPool):
|
||||
self.token_to_kv_pool_host = MHATokenToKVPoolHost(self.kv_cache)
|
||||
elif isinstance(self.kv_cache, MLATokenToKVPool):
|
||||
self.token_to_kv_pool_host = MLATokenToKVPoolHost(self.kv_cache)
|
||||
else:
|
||||
raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
|
||||
|
||||
self.tp_group = tp_cache_group
|
||||
self.page_size = page_size
|
||||
|
||||
@@ -317,13 +325,11 @@ class HiRadixCache(RadixCache):
|
||||
prefix_len = _key_match(child.key, key)
|
||||
if prefix_len < len(child.key):
|
||||
new_node = self._split_node(child.key, child, prefix_len)
|
||||
self.inc_hit_count(new_node)
|
||||
if not new_node.evicted:
|
||||
value.append(new_node.value)
|
||||
node = new_node
|
||||
break
|
||||
else:
|
||||
self.inc_hit_count(child)
|
||||
if not child.evicted:
|
||||
value.append(child.value)
|
||||
node = child
|
||||
|
||||
@@ -115,6 +115,21 @@ class KVCache(abc.ABC):
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_flat_data(self, indices):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def transfer(self, indices, flat_data):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def transfer_per_layer(self, indices, flat_data, layer_id):
|
||||
raise NotImplementedError()
|
||||
|
||||
def register_layer_transfer_counter(self, layer_transfer_counter):
|
||||
self.layer_transfer_counter = layer_transfer_counter
|
||||
|
||||
|
||||
class TokenToKVPoolAllocator:
|
||||
"""An allocator managing the indices to kv cache data."""
|
||||
@@ -275,9 +290,6 @@ class MHATokenToKVPool(KVCache):
|
||||
self.k_buffer[i][indices] = k_data[i]
|
||||
self.v_buffer[i][indices] = v_data[i]
|
||||
|
||||
def register_layer_transfer_counter(self, layer_transfer_counter):
|
||||
self.layer_transfer_counter = layer_transfer_counter
|
||||
|
||||
def transfer_per_layer(self, indices, flat_data, layer_id):
|
||||
# transfer prepared data from host to device
|
||||
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
||||
@@ -388,6 +400,8 @@ class MLATokenToKVPool(KVCache):
|
||||
else:
|
||||
self.store_dtype = dtype
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.layer_num = layer_num
|
||||
|
||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=enable_memory_saver
|
||||
@@ -404,12 +418,20 @@ class MLATokenToKVPool(KVCache):
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
|
||||
self.layer_transfer_counter = None
|
||||
|
||||
def get_key_buffer(self, layer_id: int):
|
||||
if self.layer_transfer_counter is not None:
|
||||
self.layer_transfer_counter.wait_until(layer_id)
|
||||
|
||||
if self.store_dtype != self.dtype:
|
||||
return self.kv_buffer[layer_id].view(self.dtype)
|
||||
return self.kv_buffer[layer_id]
|
||||
|
||||
def get_value_buffer(self, layer_id: int):
|
||||
if self.layer_transfer_counter is not None:
|
||||
self.layer_transfer_counter.wait_until(layer_id)
|
||||
|
||||
if self.store_dtype != self.dtype:
|
||||
return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
|
||||
return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
|
||||
@@ -432,6 +454,22 @@ class MLATokenToKVPool(KVCache):
|
||||
else:
|
||||
self.kv_buffer[layer_id][loc] = cache_k
|
||||
|
||||
def get_flat_data(self, indices):
|
||||
# prepare a large chunk of contiguous data for efficient transfer
|
||||
return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
|
||||
|
||||
@debug_timing
|
||||
def transfer(self, indices, flat_data):
|
||||
# transfer prepared data from host to device
|
||||
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
||||
for i in range(self.layer_num):
|
||||
self.kv_buffer[i][indices] = flat_data[i]
|
||||
|
||||
def transfer_per_layer(self, indices, flat_data, layer_id):
|
||||
# transfer prepared data from host to device
|
||||
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
||||
self.kv_buffer[layer_id][indices] = flat_data
|
||||
|
||||
|
||||
class DoubleSparseTokenToKVPool(KVCache):
|
||||
def __init__(
|
||||
@@ -508,6 +546,15 @@ class DoubleSparseTokenToKVPool(KVCache):
|
||||
self.v_buffer[layer_id][loc] = cache_v
|
||||
self.label_buffer[layer_id][loc] = cache_label
|
||||
|
||||
def get_flat_data(self, indices):
|
||||
pass
|
||||
|
||||
def transfer(self, indices, flat_data):
|
||||
pass
|
||||
|
||||
def transfer_per_layer(self, indices, flat_data, layer_id):
|
||||
pass
|
||||
|
||||
|
||||
class MemoryStateInt(IntEnum):
|
||||
IDLE = 0
|
||||
@@ -526,7 +573,7 @@ def synchronized(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
class MHATokenToKVPoolHost:
|
||||
class HostKVCache(abc.ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -547,12 +594,7 @@ class MHATokenToKVPoolHost:
|
||||
|
||||
self.size = int(device_pool.size * host_to_device_ratio)
|
||||
self.dtype = device_pool.store_dtype
|
||||
self.head_num = device_pool.head_num
|
||||
self.head_dim = device_pool.head_dim
|
||||
self.layer_num = device_pool.layer_num
|
||||
self.size_per_token = (
|
||||
self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
|
||||
)
|
||||
self.size_per_token = self.get_size_per_token()
|
||||
|
||||
# Verify there is enough available host memory.
|
||||
host_mem = psutil.virtual_memory()
|
||||
@@ -571,12 +613,7 @@ class MHATokenToKVPoolHost:
|
||||
f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
|
||||
)
|
||||
|
||||
self.kv_buffer = torch.zeros(
|
||||
(2, self.layer_num, self.size, self.head_num, self.head_dim),
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
self.kv_buffer = self.init_kv_buffer()
|
||||
|
||||
# Initialize memory states and tracking structures.
|
||||
self.mem_state = torch.zeros(
|
||||
@@ -588,21 +625,29 @@ class MHATokenToKVPoolHost:
|
||||
# A lock for synchronized operations on memory allocation and state transitions.
|
||||
self.lock = threading.RLock()
|
||||
|
||||
def get_flat_data(self, indices):
|
||||
return self.kv_buffer[:, :, indices]
|
||||
@abc.abstractmethod
|
||||
def get_size_per_token(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_flat_data_by_layer(self, indices, layer_id):
|
||||
return self.kv_buffer[:, layer_id, indices]
|
||||
@abc.abstractmethod
|
||||
def init_kv_buffer(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def assign_flat_data(self, indices, flat_data):
|
||||
self.kv_buffer[:, :, indices] = flat_data
|
||||
|
||||
@debug_timing
|
||||
@abc.abstractmethod
|
||||
def transfer(self, indices, flat_data):
|
||||
# backup prepared data from device to host
|
||||
self.kv_buffer[:, :, indices] = flat_data.to(
|
||||
device=self.device, non_blocking=False
|
||||
)
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_flat_data(self, indices):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_flat_data_by_layer(self, indices, layer_id):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def assign_flat_data(self, indices, flat_data):
|
||||
raise NotImplementedError()
|
||||
|
||||
@synchronized
|
||||
def clear(self):
|
||||
@@ -694,3 +739,92 @@ class MHATokenToKVPoolHost:
|
||||
self.free_slots = torch.concat([self.free_slots, indices])
|
||||
self.can_use_mem_size += len(indices)
|
||||
return len(indices)
|
||||
|
||||
|
||||
class MHATokenToKVPoolHost(HostKVCache):
|
||||
def __init__(
|
||||
self,
|
||||
device_pool: MHATokenToKVPool,
|
||||
host_to_device_ratio: float = 3.0,
|
||||
pin_memory: bool = False, # no need to use pin memory with the double buffering
|
||||
device: str = "cpu",
|
||||
):
|
||||
super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
|
||||
|
||||
def get_size_per_token(self):
|
||||
self.head_num = self.device_pool.head_num
|
||||
self.head_dim = self.device_pool.head_dim
|
||||
self.layer_num = self.device_pool.layer_num
|
||||
|
||||
return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
|
||||
|
||||
def init_kv_buffer(self):
|
||||
return torch.empty(
|
||||
(2, self.layer_num, self.size, self.head_num, self.head_dim),
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
|
||||
@debug_timing
|
||||
def transfer(self, indices, flat_data):
|
||||
# backup prepared data from device to host
|
||||
self.kv_buffer[:, :, indices] = flat_data.to(
|
||||
device=self.device, non_blocking=False
|
||||
)
|
||||
|
||||
def get_flat_data(self, indices):
|
||||
return self.kv_buffer[:, :, indices]
|
||||
|
||||
def get_flat_data_by_layer(self, indices, layer_id):
|
||||
return self.kv_buffer[:, layer_id, indices]
|
||||
|
||||
def assign_flat_data(self, indices, flat_data):
|
||||
self.kv_buffer[:, :, indices] = flat_data
|
||||
|
||||
|
||||
class MLATokenToKVPoolHost(HostKVCache):
|
||||
def __init__(
|
||||
self,
|
||||
device_pool: MLATokenToKVPool,
|
||||
host_to_device_ratio: float = 4.0,
|
||||
pin_memory: bool = False, # no need to use pin memory with the double buffering
|
||||
device: str = "cpu",
|
||||
):
|
||||
super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
|
||||
|
||||
def get_size_per_token(self):
|
||||
self.kv_lora_rank = self.device_pool.kv_lora_rank
|
||||
self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
|
||||
self.layer_num = self.device_pool.layer_num
|
||||
|
||||
return (self.kv_lora_rank + self.qk_rope_head_dim) * 1 * self.dtype.itemsize
|
||||
|
||||
def init_kv_buffer(self):
|
||||
return torch.empty(
|
||||
(
|
||||
self.layer_num,
|
||||
self.size,
|
||||
1,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
),
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
|
||||
@debug_timing
|
||||
def transfer(self, indices, flat_data):
|
||||
# backup prepared data from device to host
|
||||
self.kv_buffer[:, indices] = flat_data.to(
|
||||
device=self.device, non_blocking=False
|
||||
)
|
||||
|
||||
def get_flat_data(self, indices):
|
||||
return self.kv_buffer[:, indices]
|
||||
|
||||
def get_flat_data_by_layer(self, indices, layer_id):
|
||||
return self.kv_buffer[layer_id, indices]
|
||||
|
||||
def assign_flat_data(self, indices, flat_data):
|
||||
self.kv_buffer[:, indices] = flat_data
|
||||
|
||||
Reference in New Issue
Block a user