Large page size aligned hierarchical caching (#4581)
This commit is contained in:
@@ -149,6 +149,7 @@ class HiCacheController:
|
|||||||
self,
|
self,
|
||||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||||
mem_pool_host: HostKVCache,
|
mem_pool_host: HostKVCache,
|
||||||
|
page_size: int,
|
||||||
load_cache_event: threading.Event = None,
|
load_cache_event: threading.Event = None,
|
||||||
write_policy: str = "write_through_selective",
|
write_policy: str = "write_through_selective",
|
||||||
):
|
):
|
||||||
@@ -156,6 +157,7 @@ class HiCacheController:
|
|||||||
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
|
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
|
||||||
self.mem_pool_host = mem_pool_host
|
self.mem_pool_host = mem_pool_host
|
||||||
self.write_policy = write_policy
|
self.write_policy = write_policy
|
||||||
|
self.page_size = page_size
|
||||||
|
|
||||||
self.load_cache_event = load_cache_event
|
self.load_cache_event = load_cache_event
|
||||||
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
|
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
|
||||||
@@ -184,7 +186,12 @@ class HiCacheController:
|
|||||||
self.load_stream = torch.cuda.Stream()
|
self.load_stream = torch.cuda.Stream()
|
||||||
|
|
||||||
self.write_thread = threading.Thread(
|
self.write_thread = threading.Thread(
|
||||||
target=self.write_thread_func_buffer, daemon=True
|
target=(
|
||||||
|
self.write_thread_func_buffer
|
||||||
|
if self.page_size == 1
|
||||||
|
else self.write_thread_func_direct
|
||||||
|
),
|
||||||
|
daemon=True,
|
||||||
)
|
)
|
||||||
self.load_thread = threading.Thread(
|
self.load_thread = threading.Thread(
|
||||||
target=self.load_thread_func_layer_by_layer, daemon=True
|
target=self.load_thread_func_layer_by_layer, daemon=True
|
||||||
@@ -205,7 +212,12 @@ class HiCacheController:
|
|||||||
self.ack_load_queue.queue.clear()
|
self.ack_load_queue.queue.clear()
|
||||||
|
|
||||||
self.write_thread = threading.Thread(
|
self.write_thread = threading.Thread(
|
||||||
target=self.write_thread_func_buffer, daemon=True
|
target=(
|
||||||
|
self.write_thread_func_buffer
|
||||||
|
if self.page_size == 1
|
||||||
|
else self.write_thread_func_direct
|
||||||
|
),
|
||||||
|
daemon=True,
|
||||||
)
|
)
|
||||||
self.load_thread = threading.Thread(
|
self.load_thread = threading.Thread(
|
||||||
target=self.load_thread_func_layer_by_layer, daemon=True
|
target=self.load_thread_func_layer_by_layer, daemon=True
|
||||||
@@ -260,10 +272,12 @@ class HiCacheController:
|
|||||||
while not self.stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
operation = self.write_queue.get(block=True, timeout=1)
|
operation = self.write_queue.get(block=True, timeout=1)
|
||||||
operation.data = self.mem_pool_device.get_flat_data(
|
self.mem_pool_host.write_page_all_layers(
|
||||||
operation.device_indices
|
operation.host_indices,
|
||||||
|
operation.device_indices,
|
||||||
|
self.mem_pool_device,
|
||||||
)
|
)
|
||||||
self.mem_pool_host.transfer(operation.host_indices, operation.data)
|
self.write_stream.synchronize()
|
||||||
self.mem_pool_host.complete_io(operation.host_indices)
|
self.mem_pool_host.complete_io(operation.host_indices)
|
||||||
for node_id in operation.node_ids:
|
for node_id in operation.node_ids:
|
||||||
if node_id != 0:
|
if node_id != 0:
|
||||||
@@ -320,12 +334,21 @@ class HiCacheController:
|
|||||||
|
|
||||||
self.layer_done_counter.reset()
|
self.layer_done_counter.reset()
|
||||||
for i in range(self.mem_pool_host.layer_num):
|
for i in range(self.mem_pool_host.layer_num):
|
||||||
flat_data = self.mem_pool_host.get_flat_data_by_layer(
|
if self.page_size == 1:
|
||||||
batch_operation.host_indices, i
|
flat_data = self.mem_pool_host.get_flat_data_by_layer(
|
||||||
)
|
batch_operation.host_indices, i
|
||||||
self.mem_pool_device.transfer_per_layer(
|
)
|
||||||
batch_operation.device_indices, flat_data, i
|
self.mem_pool_device.transfer_per_layer(
|
||||||
)
|
batch_operation.device_indices, flat_data, i
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.mem_pool_host.load_page_per_layer(
|
||||||
|
batch_operation.host_indices,
|
||||||
|
batch_operation.device_indices,
|
||||||
|
self.mem_pool_device,
|
||||||
|
i,
|
||||||
|
)
|
||||||
|
self.load_stream.synchronize()
|
||||||
self.layer_done_counter.increment()
|
self.layer_done_counter.increment()
|
||||||
|
|
||||||
self.mem_pool_host.complete_io(batch_operation.host_indices)
|
self.mem_pool_host.complete_io(batch_operation.host_indices)
|
||||||
|
|||||||
@@ -1282,7 +1282,7 @@ class Scheduler(
|
|||||||
]
|
]
|
||||||
|
|
||||||
if self.enable_hierarchical_cache:
|
if self.enable_hierarchical_cache:
|
||||||
self.tree_cache.read_to_load_cache()
|
self.tree_cache.ready_to_load_cache()
|
||||||
|
|
||||||
if adder.new_chunked_req is not None:
|
if adder.new_chunked_req is not None:
|
||||||
assert self.chunked_req is None
|
assert self.chunked_req is None
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ from sglang.srt.mem_cache.memory_pool import (
|
|||||||
TokenToKVPoolAllocator,
|
TokenToKVPoolAllocator,
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||||
from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -31,29 +30,25 @@ class HiRadixCache(RadixCache):
|
|||||||
page_size: int,
|
page_size: int,
|
||||||
hicache_ratio: float,
|
hicache_ratio: float,
|
||||||
):
|
):
|
||||||
if page_size != 1:
|
|
||||||
raise ValueError(
|
|
||||||
"Page size larger than 1 is not yet supported in HiRadixCache."
|
|
||||||
)
|
|
||||||
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
||||||
if isinstance(self.kv_cache, MHATokenToKVPool):
|
if isinstance(self.kv_cache, MHATokenToKVPool):
|
||||||
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
|
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
|
||||||
self.kv_cache, hicache_ratio
|
self.kv_cache, hicache_ratio, page_size
|
||||||
)
|
)
|
||||||
elif isinstance(self.kv_cache, MLATokenToKVPool):
|
elif isinstance(self.kv_cache, MLATokenToKVPool):
|
||||||
self.token_to_kv_pool_host = MLATokenToKVPoolHost(
|
self.token_to_kv_pool_host = MLATokenToKVPoolHost(
|
||||||
self.kv_cache, hicache_ratio
|
self.kv_cache, hicache_ratio, page_size
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
|
raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
|
||||||
|
|
||||||
self.tp_group = tp_cache_group
|
self.tp_group = tp_cache_group
|
||||||
self.page_size = page_size
|
|
||||||
|
|
||||||
self.load_cache_event = threading.Event()
|
self.load_cache_event = threading.Event()
|
||||||
self.cache_controller = HiCacheController(
|
self.cache_controller = HiCacheController(
|
||||||
token_to_kv_pool_allocator,
|
token_to_kv_pool_allocator,
|
||||||
self.token_to_kv_pool_host,
|
self.token_to_kv_pool_host,
|
||||||
|
page_size,
|
||||||
load_cache_event=self.load_cache_event,
|
load_cache_event=self.load_cache_event,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -65,7 +60,7 @@ class HiRadixCache(RadixCache):
|
|||||||
self.write_through_threshold = 1
|
self.write_through_threshold = 1
|
||||||
self.load_back_threshold = 10
|
self.load_back_threshold = 10
|
||||||
super().__init__(
|
super().__init__(
|
||||||
req_to_token_pool, token_to_kv_pool_allocator, self.page_size, disable=False
|
req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
|
||||||
)
|
)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
@@ -299,18 +294,26 @@ class HiRadixCache(RadixCache):
|
|||||||
|
|
||||||
return last_node, prefix_indices
|
return last_node, prefix_indices
|
||||||
|
|
||||||
def read_to_load_cache(self):
|
def ready_to_load_cache(self):
|
||||||
self.load_cache_event.set()
|
self.load_cache_event.set()
|
||||||
|
|
||||||
def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
|
def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
|
||||||
if self.disable:
|
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||||
return [], self.root_node
|
if self.disable or len(key) == 0:
|
||||||
|
if include_evicted:
|
||||||
|
return empty_value, self.root_node, self.root_node
|
||||||
|
else:
|
||||||
|
return empty_value, self.root_node
|
||||||
|
|
||||||
|
if self.page_size != 1:
|
||||||
|
page_aligned_len = len(key) // self.page_size * self.page_size
|
||||||
|
key = key[:page_aligned_len]
|
||||||
|
|
||||||
value, last_node = self._match_prefix_helper(self.root_node, key)
|
value, last_node = self._match_prefix_helper(self.root_node, key)
|
||||||
if value:
|
if value:
|
||||||
value = torch.cat(value)
|
value = torch.cat(value)
|
||||||
else:
|
else:
|
||||||
value = torch.tensor([], dtype=torch.int64)
|
value = empty_value
|
||||||
|
|
||||||
last_node_global = last_node
|
last_node_global = last_node
|
||||||
while last_node.evicted:
|
while last_node.evicted:
|
||||||
@@ -323,11 +326,13 @@ class HiRadixCache(RadixCache):
|
|||||||
|
|
||||||
def _match_prefix_helper(self, node: TreeNode, key: List):
|
def _match_prefix_helper(self, node: TreeNode, key: List):
|
||||||
node.last_access_time = time.time()
|
node.last_access_time = time.time()
|
||||||
|
child_key = self.get_child_key_fn(key)
|
||||||
value = []
|
value = []
|
||||||
while len(key) > 0 and key[0] in node.children.keys():
|
|
||||||
child = node.children[key[0]]
|
while len(key) > 0 and child_key in node.children.keys():
|
||||||
|
child = node.children[child_key]
|
||||||
child.last_access_time = time.time()
|
child.last_access_time = time.time()
|
||||||
prefix_len = _key_match(child.key, key)
|
prefix_len = self.key_match_fn(child.key, key)
|
||||||
if prefix_len < len(child.key):
|
if prefix_len < len(child.key):
|
||||||
new_node = self._split_node(child.key, child, prefix_len)
|
new_node = self._split_node(child.key, child, prefix_len)
|
||||||
if not new_node.evicted:
|
if not new_node.evicted:
|
||||||
@@ -339,12 +344,16 @@ class HiRadixCache(RadixCache):
|
|||||||
value.append(child.value)
|
value.append(child.value)
|
||||||
node = child
|
node = child
|
||||||
key = key[prefix_len:]
|
key = key[prefix_len:]
|
||||||
|
|
||||||
|
if len(key):
|
||||||
|
child_key = self.get_child_key_fn(key)
|
||||||
|
|
||||||
return value, node
|
return value, node
|
||||||
|
|
||||||
def _split_node(self, key, child: TreeNode, split_len: int):
|
def _split_node(self, key, child: TreeNode, split_len: int):
|
||||||
# child node split into new_node -> child
|
# child node split into new_node -> child
|
||||||
new_node = TreeNode()
|
new_node = TreeNode()
|
||||||
new_node.children = {key[split_len]: child}
|
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
|
||||||
new_node.parent = child.parent
|
new_node.parent = child.parent
|
||||||
new_node.lock_ref = child.lock_ref
|
new_node.lock_ref = child.lock_ref
|
||||||
new_node.key = child.key[:split_len]
|
new_node.key = child.key[:split_len]
|
||||||
@@ -361,7 +370,7 @@ class HiRadixCache(RadixCache):
|
|||||||
child.host_value = child.host_value[split_len:]
|
child.host_value = child.host_value[split_len:]
|
||||||
child.parent = new_node
|
child.parent = new_node
|
||||||
child.key = child.key[split_len:]
|
child.key = child.key[split_len:]
|
||||||
new_node.parent.children[key[0]] = new_node
|
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
||||||
return new_node
|
return new_node
|
||||||
|
|
||||||
def _insert_helper(self, node: TreeNode, key: List, value):
|
def _insert_helper(self, node: TreeNode, key: List, value):
|
||||||
@@ -369,52 +378,53 @@ class HiRadixCache(RadixCache):
|
|||||||
if len(key) == 0:
|
if len(key) == 0:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
if key[0] in node.children.keys():
|
child_key = self.get_child_key_fn(key)
|
||||||
child = node.children[key[0]]
|
total_prefix_length = 0
|
||||||
prefix_len = _key_match(child.key, key)
|
|
||||||
|
|
||||||
if prefix_len == len(child.key):
|
while len(key) > 0 and child_key in node.children.keys():
|
||||||
if child.evicted:
|
node = node.children[child_key]
|
||||||
|
node.last_access_time = time.time()
|
||||||
|
prefix_len = self.key_match_fn(node.key, key)
|
||||||
|
|
||||||
|
if prefix_len == len(node.key):
|
||||||
|
if node.evicted:
|
||||||
# change the reference if the node is evicted
|
# change the reference if the node is evicted
|
||||||
# this often happens in the case of KV cache recomputation
|
# this often happens in the case of KV cache recomputation
|
||||||
child.value = value[:prefix_len]
|
node.value = value[:prefix_len]
|
||||||
self.token_to_kv_pool_host.update_synced(child.host_value)
|
self.token_to_kv_pool_host.update_synced(node.host_value)
|
||||||
self.evictable_size_ += len(value[:prefix_len])
|
self.evictable_size_ += len(node.value)
|
||||||
return self._insert_helper(
|
|
||||||
child, key[prefix_len:], value[prefix_len:]
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.inc_hit_count(child)
|
self.inc_hit_count(node)
|
||||||
return prefix_len + self._insert_helper(
|
total_prefix_length += prefix_len
|
||||||
child, key[prefix_len:], value[prefix_len:]
|
|
||||||
)
|
|
||||||
|
|
||||||
# partial match, split the node
|
|
||||||
new_node = self._split_node(child.key, child, prefix_len)
|
|
||||||
if new_node.evicted:
|
|
||||||
new_node.value = value[:prefix_len]
|
|
||||||
self.token_to_kv_pool_host.update_synced(new_node.host_value)
|
|
||||||
self.evictable_size_ += len(new_node.value)
|
|
||||||
return self._insert_helper(
|
|
||||||
new_node, key[prefix_len:], value[prefix_len:]
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.inc_hit_count(new_node)
|
# partial match, split the node
|
||||||
return prefix_len + self._insert_helper(
|
new_node = self._split_node(node.key, node, prefix_len)
|
||||||
new_node, key[prefix_len:], value[prefix_len:]
|
if new_node.evicted:
|
||||||
)
|
new_node.value = value[:prefix_len]
|
||||||
|
self.token_to_kv_pool_host.update_synced(new_node.host_value)
|
||||||
|
self.evictable_size_ += len(new_node.value)
|
||||||
|
else:
|
||||||
|
self.inc_hit_count(new_node)
|
||||||
|
total_prefix_length += prefix_len
|
||||||
|
node = new_node
|
||||||
|
|
||||||
|
key = key[prefix_len:]
|
||||||
|
value = value[prefix_len:]
|
||||||
|
|
||||||
|
if len(key):
|
||||||
|
child_key = self.get_child_key_fn(key)
|
||||||
|
|
||||||
if len(key):
|
if len(key):
|
||||||
new_node = TreeNode()
|
new_node = TreeNode()
|
||||||
new_node.parent = node
|
new_node.parent = node
|
||||||
new_node.key = key
|
new_node.key = key
|
||||||
new_node.value = value
|
new_node.value = value
|
||||||
node.children[key[0]] = new_node
|
node.children[child_key] = new_node
|
||||||
self.evictable_size_ += len(value)
|
self.evictable_size_ += len(value)
|
||||||
|
|
||||||
if self.cache_controller.write_policy == "write_through":
|
if self.cache_controller.write_policy == "write_through":
|
||||||
self.write_backup(new_node)
|
self.write_backup(new_node)
|
||||||
return 0
|
return total_prefix_length
|
||||||
|
|
||||||
def _collect_leaves_device(self):
|
def _collect_leaves_device(self):
|
||||||
def is_leaf(node):
|
def is_leaf(node):
|
||||||
|
|||||||
@@ -608,8 +608,9 @@ class HostKVCache(abc.ABC):
|
|||||||
self,
|
self,
|
||||||
device_pool: MHATokenToKVPool,
|
device_pool: MHATokenToKVPool,
|
||||||
host_to_device_ratio: float,
|
host_to_device_ratio: float,
|
||||||
pin_memory: bool = False, # no need to use pin memory with the double buffering
|
pin_memory: bool,
|
||||||
device: str = "cpu",
|
device: str,
|
||||||
|
page_size: int,
|
||||||
):
|
):
|
||||||
assert (
|
assert (
|
||||||
host_to_device_ratio >= 1
|
host_to_device_ratio >= 1
|
||||||
@@ -620,8 +621,11 @@ class HostKVCache(abc.ABC):
|
|||||||
self.host_to_device_ratio = host_to_device_ratio
|
self.host_to_device_ratio = host_to_device_ratio
|
||||||
self.pin_memory = pin_memory
|
self.pin_memory = pin_memory
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.page_size = page_size
|
||||||
|
|
||||||
self.size = int(device_pool.size * host_to_device_ratio)
|
self.size = int(device_pool.size * host_to_device_ratio)
|
||||||
|
# Align the host memory pool size to the page size
|
||||||
|
self.size = self.size - (self.size % self.page_size)
|
||||||
self.dtype = device_pool.store_dtype
|
self.dtype = device_pool.store_dtype
|
||||||
self.size_per_token = self.get_size_per_token()
|
self.size_per_token = self.get_size_per_token()
|
||||||
|
|
||||||
@@ -775,10 +779,13 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|||||||
self,
|
self,
|
||||||
device_pool: MHATokenToKVPool,
|
device_pool: MHATokenToKVPool,
|
||||||
host_to_device_ratio: float,
|
host_to_device_ratio: float,
|
||||||
pin_memory: bool = False, # no need to use pin memory with the double buffering
|
page_size: int,
|
||||||
|
pin_memory: bool = True,
|
||||||
device: str = "cpu",
|
device: str = "cpu",
|
||||||
):
|
):
|
||||||
super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
|
super().__init__(
|
||||||
|
device_pool, host_to_device_ratio, pin_memory, device, page_size
|
||||||
|
)
|
||||||
|
|
||||||
def get_size_per_token(self):
|
def get_size_per_token(self):
|
||||||
self.head_num = self.device_pool.head_num
|
self.head_num = self.device_pool.head_num
|
||||||
@@ -811,16 +818,48 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|||||||
def assign_flat_data(self, indices, flat_data):
|
def assign_flat_data(self, indices, flat_data):
|
||||||
self.kv_buffer[:, :, indices] = flat_data
|
self.kv_buffer[:, :, indices] = flat_data
|
||||||
|
|
||||||
|
def write_page_all_layers(self, host_indices, device_indices, device_pool):
    """Copy whole KV pages from the device pool into the host buffer, all layers.

    Only the first index of every page is read from ``device_indices`` and
    ``host_indices``; a contiguous run of ``self.page_size`` slots starting
    there is then copied, so both index sequences are expected to be
    page-aligned and page-contiguous.

    Args:
        host_indices: Destination slot indices in ``self.kv_buffer``, one per
            token, laid out page by page.
        device_indices: Source slot indices in the device pool (a tensor),
            one per token, laid out page by page.
        device_pool: Object exposing per-layer ``k_buffer`` and ``v_buffer``
            sequences to read from.

    Note:
        Copies are issued with ``non_blocking=True``; completion is not
        guaranteed when this method returns.
    """
    page_size = self.page_size
    # One Python int per page: avoids re-coercing a 0-dim tensor into an
    # index for every slice bound inside the per-layer loop below.
    device_page_starts = device_indices[::page_size].tolist()

    for page_no, d_start in enumerate(device_page_starts):
        h_start = int(host_indices[page_no * page_size])
        h_end = h_start + page_size
        d_end = d_start + page_size
        for layer_id in range(self.layer_num):
            # kv_buffer[0] receives keys, kv_buffer[1] receives values.
            self.kv_buffer[0, layer_id, h_start:h_end].copy_(
                device_pool.k_buffer[layer_id][d_start:d_end],
                non_blocking=True,
            )
            self.kv_buffer[1, layer_id, h_start:h_end].copy_(
                device_pool.v_buffer[layer_id][d_start:d_end],
                non_blocking=True,
            )
|
||||||
|
|
||||||
|
def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
    """Copy whole KV pages of one layer from the host buffer to the device pool.

    Only the first index of every page is read from ``device_indices`` and
    ``host_indices``; a contiguous run of ``self.page_size`` slots starting
    there is then copied, so both index sequences are expected to be
    page-aligned and page-contiguous.

    Args:
        host_indices: Source slot indices in ``self.kv_buffer``, one per
            token, laid out page by page.
        device_indices: Destination slot indices in the device pool (a
            tensor), one per token, laid out page by page.
        device_pool: Object exposing per-layer ``k_buffer`` and ``v_buffer``
            sequences to write into.
        layer_id: Index of the layer whose pages are loaded.

    Note:
        Copies are issued with ``non_blocking=True``; completion is not
        guaranteed when this method returns.
    """
    page_size = self.page_size
    # Resolve the destination buffers once; they do not change per page.
    k_dst = device_pool.k_buffer[layer_id]
    v_dst = device_pool.v_buffer[layer_id]
    # One Python int per page instead of a 0-dim tensor per slice bound.
    device_page_starts = device_indices[::page_size].tolist()

    for page_no, d_start in enumerate(device_page_starts):
        h_start = int(host_indices[page_no * page_size])
        h_end = h_start + page_size
        d_end = d_start + page_size
        # kv_buffer[0] holds keys, kv_buffer[1] holds values.
        k_dst[d_start:d_end].copy_(
            self.kv_buffer[0, layer_id, h_start:h_end],
            non_blocking=True,
        )
        v_dst[d_start:d_end].copy_(
            self.kv_buffer[1, layer_id, h_start:h_end],
            non_blocking=True,
        )
|
||||||
|
|
||||||
|
|
||||||
class MLATokenToKVPoolHost(HostKVCache):
|
class MLATokenToKVPoolHost(HostKVCache):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
device_pool: MLATokenToKVPool,
|
device_pool: MLATokenToKVPool,
|
||||||
host_to_device_ratio: float,
|
host_to_device_ratio: float,
|
||||||
pin_memory: bool = False, # no need to use pin memory with the double buffering
|
page_size: int,
|
||||||
|
pin_memory: bool = True,
|
||||||
device: str = "cpu",
|
device: str = "cpu",
|
||||||
):
|
):
|
||||||
super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
|
super().__init__(
|
||||||
|
device_pool, host_to_device_ratio, pin_memory, device, page_size
|
||||||
|
)
|
||||||
|
|
||||||
def get_size_per_token(self):
|
def get_size_per_token(self):
|
||||||
self.kv_lora_rank = self.device_pool.kv_lora_rank
|
self.kv_lora_rank = self.device_pool.kv_lora_rank
|
||||||
@@ -857,3 +896,24 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|||||||
|
|
||||||
def assign_flat_data(self, indices, flat_data):
|
def assign_flat_data(self, indices, flat_data):
|
||||||
self.kv_buffer[:, indices] = flat_data
|
self.kv_buffer[:, indices] = flat_data
|
||||||
|
|
||||||
|
def write_page_all_layers(self, host_indices, device_indices, device_pool):
    """Copy whole KV pages from the device pool into the host buffer, all layers.

    Only the first index of every page is read from ``device_indices`` and
    ``host_indices``; a contiguous run of ``self.page_size`` slots starting
    there is then copied, so both index sequences are expected to be
    page-aligned and page-contiguous.

    Args:
        host_indices: Destination slot indices in ``self.kv_buffer``, one per
            token, laid out page by page.
        device_indices: Source slot indices in the device pool (a tensor),
            one per token, laid out page by page.
        device_pool: Object exposing a per-layer ``kv_buffer`` sequence to
            read from.

    Note:
        Copies are issued with ``non_blocking=True``; completion is not
        guaranteed when this method returns.
    """
    page_size = self.page_size
    # One Python int per page: avoids re-coercing a 0-dim tensor into an
    # index for every slice bound inside the per-layer loop below.
    device_page_starts = device_indices[::page_size].tolist()

    for page_no, d_start in enumerate(device_page_starts):
        h_start = int(host_indices[page_no * page_size])
        h_end = h_start + page_size
        d_end = d_start + page_size
        for layer_id in range(self.layer_num):
            self.kv_buffer[layer_id, h_start:h_end].copy_(
                device_pool.kv_buffer[layer_id][d_start:d_end],
                non_blocking=True,
            )
|
||||||
|
|
||||||
|
def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
    """Copy whole KV pages of one layer from the host buffer to the device pool.

    Only the first index of every page is read from ``device_indices`` and
    ``host_indices``; a contiguous run of ``self.page_size`` slots starting
    there is then copied, so both index sequences are expected to be
    page-aligned and page-contiguous.

    Args:
        host_indices: Source slot indices in ``self.kv_buffer``, one per
            token, laid out page by page.
        device_indices: Destination slot indices in the device pool (a
            tensor), one per token, laid out page by page.
        device_pool: Object exposing a per-layer ``kv_buffer`` sequence to
            write into.
        layer_id: Index of the layer whose pages are loaded.

    Note:
        Copies are issued with ``non_blocking=True``; completion is not
        guaranteed when this method returns.
    """
    page_size = self.page_size
    # Resolve the destination buffer once; it does not change per page.
    dst = device_pool.kv_buffer[layer_id]
    # One Python int per page instead of a 0-dim tensor per slice bound.
    device_page_starts = device_indices[::page_size].tolist()

    for page_no, d_start in enumerate(device_page_starts):
        h_start = int(host_indices[page_no * page_size])
        dst[d_start : d_start + page_size].copy_(
            self.kv_buffer[layer_id, h_start : h_start + page_size],
            non_blocking=True,
        )
|
||||||
|
|||||||
@@ -190,6 +190,30 @@ class PagedTokenToKVPoolAllocator:
|
|||||||
def available_size(self):
|
def available_size(self):
|
||||||
return len(self.free_pages) * self.page_size
|
return len(self.free_pages) * self.page_size
|
||||||
|
|
||||||
|
def get_kvcache(self):
    """Return the KV cache object stored on this allocator (``self._kvcache``)."""
    kvcache = self._kvcache
    return kvcache
|
||||||
|
|
||||||
|
def alloc(self, need_size: int):
    """Allocate ``need_size`` token slots as whole pages.

    Pages are taken from the front of ``self.free_pages``. Each page ``p``
    expands to the slot indices ``p * page_size .. p * page_size + page_size - 1``,
    so indices are contiguous within a page (pages themselves need not be
    adjacent).

    Args:
        need_size: Number of token slots requested. Expected to be a
            multiple of ``self.page_size``; this is only asserted when
            ``self.debug_mode`` is set, otherwise the request is rounded
            down to whole pages.

    Returns:
        A flat tensor of slot indices, or ``None`` when fewer free pages
        remain than requested (in which case nothing is consumed).

    Raises:
        ValueError: If ``need_size`` is negative.
    """
    # A negative size would yield a negative page count, and the slices
    # below would then silently hand out almost the entire free list.
    if need_size < 0:
        raise ValueError(f"need_size must be non-negative, got {need_size}")

    page_size = self.page_size
    if self.debug_mode:
        assert (
            need_size % page_size == 0
        ), "The allocation size should be page-aligned"

    num_pages = need_size // page_size
    if num_pages > len(self.free_pages):
        return None

    out_pages = self.free_pages[:num_pages]
    self.free_pages = self.free_pages[num_pages:]

    # Broadcast (num_pages, 1) page bases against (page_size,) offsets,
    # then flatten to one slot index per token.
    offsets = torch.arange(page_size, device=self.device)
    out_indices = (out_pages[:, None] * page_size + offsets).reshape(-1)

    return out_indices
|
||||||
|
|
||||||
def alloc_extend(
|
def alloc_extend(
|
||||||
self,
|
self,
|
||||||
prefix_lens: torch.Tensor,
|
prefix_lens: torch.Tensor,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from sglang.test.test_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestPageSize(CustomTestCase):
|
class TestHiCache(CustomTestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
@@ -21,7 +21,9 @@ class TestPageSize(CustomTestCase):
|
|||||||
cls.model,
|
cls.model,
|
||||||
cls.base_url,
|
cls.base_url,
|
||||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
other_args=["--enable-hierarchical-cache"],
|
other_args=[
|
||||||
|
"--enable-hierarchical-cache",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -21,7 +21,10 @@ class TestHierarchicalMLA(CustomTestCase):
|
|||||||
cls.model,
|
cls.model,
|
||||||
cls.base_url,
|
cls.base_url,
|
||||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
other_args=["--trust-remote-code", "--enable-hierarchical-cache"],
|
other_args=[
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--enable-hierarchical-cache",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
49
test/srt/test_hicache_page.py
Normal file
49
test/srt/test_hicache_page.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.run_eval import run_eval
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestHiCachePage(CustomTestCase):
    """MMLU smoke test against a server with hierarchical cache and page size 32."""

    @classmethod
    def setUpClass(cls):
        """Launch one shared server process for every test in this class."""
        server_flags = [
            "--enable-hierarchical-cache",
            "--page-size",
            "32",
        ]
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=server_flags,
        )

    @classmethod
    def tearDownClass(cls):
        """Terminate the server process tree started in ``setUpClass``."""
        server_pid = cls.process.pid
        kill_process_tree(server_pid)

    def test_mmlu(self):
        """Run a 64-example MMLU eval and require a score of at least 0.65."""
        eval_config = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        results = run_eval(eval_config)
        self.assertGreaterEqual(results["score"], 0.65)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user