Hierarchical Caching Refactoring and Fixing TP issue (#4082)
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
import heapq
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.cache_controller import HiCacheController
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPoolHost,
|
||||
ReqToTokenPool,
|
||||
@@ -22,12 +24,18 @@ class HiRadixCache(RadixCache):
|
||||
self,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
tp_cache_group: torch.distributed.ProcessGroup,
|
||||
):
|
||||
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
|
||||
token_to_kv_pool_allocator.get_kvcache()
|
||||
)
|
||||
self.tp_group = tp_cache_group
|
||||
|
||||
self.load_cache_event = threading.Event()
|
||||
self.cache_controller = HiCacheController(
|
||||
token_to_kv_pool_allocator, self.token_to_kv_pool_host
|
||||
token_to_kv_pool_allocator,
|
||||
self.token_to_kv_pool_host,
|
||||
load_cache_event=self.load_cache_event,
|
||||
)
|
||||
|
||||
# record the nodes with ongoing write through
|
||||
@@ -55,14 +63,12 @@ class HiRadixCache(RadixCache):
|
||||
def write_backup(self, node: TreeNode):
|
||||
host_indices = self.cache_controller.write(
|
||||
device_indices=node.value,
|
||||
priority=-self.get_height(node),
|
||||
node_id=node.id,
|
||||
)
|
||||
if host_indices is None:
|
||||
self.evict_host(len(node.value))
|
||||
host_indices = self.cache_controller.write(
|
||||
device_indices=node.value,
|
||||
priority=-self.get_height(node),
|
||||
node_id=node.id,
|
||||
)
|
||||
if host_indices is not None:
|
||||
@@ -83,14 +89,20 @@ class HiRadixCache(RadixCache):
|
||||
node.hit_count = 0
|
||||
|
||||
def writing_check(self):
|
||||
while not self.cache_controller.ack_write_queue.empty():
|
||||
try:
|
||||
ack_id = self.cache_controller.ack_write_queue.get_nowait()
|
||||
self.dec_lock_ref(self.ongoing_write_through[ack_id])
|
||||
# clear the reference
|
||||
del self.ongoing_write_through[ack_id]
|
||||
except Exception:
|
||||
break
|
||||
queue_size = torch.tensor(
|
||||
self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
|
||||
)
|
||||
if torch.distributed.get_world_size(group=self.tp_group) > 1:
|
||||
# synchrnoize TP workers to make the same update to radix cache
|
||||
torch.distributed.all_reduce(
|
||||
queue_size,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
group=self.tp_group,
|
||||
)
|
||||
for _ in range(queue_size.item()):
|
||||
ack_id = self.cache_controller.ack_write_queue.get()
|
||||
self.dec_lock_ref(self.ongoing_write_through[ack_id])
|
||||
del self.ongoing_write_through[ack_id]
|
||||
|
||||
def loading_check(self):
|
||||
while not self.cache_controller.ack_load_queue.empty():
|
||||
@@ -108,8 +120,6 @@ class HiRadixCache(RadixCache):
|
||||
break
|
||||
|
||||
def evictable_size(self):
|
||||
self.writing_check()
|
||||
self.loading_check()
|
||||
return self.evictable_size_
|
||||
|
||||
def evict(self, num_tokens: int, evict_callback=None):
|
||||
@@ -242,10 +252,6 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
return device_indices
|
||||
|
||||
def loading_complete(self, node: TreeNode):
|
||||
self.loading_check()
|
||||
return node.loading == False
|
||||
|
||||
def init_load_back(
|
||||
self,
|
||||
last_node: TreeNode,
|
||||
@@ -272,6 +278,28 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
return last_node, prefix_indices
|
||||
|
||||
def read_to_load_cache(self):
|
||||
self.load_cache_event.set()
|
||||
|
||||
def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
|
||||
if self.disable:
|
||||
return [], self.root_node
|
||||
|
||||
value, last_node = self._match_prefix_helper(self.root_node, key)
|
||||
if value:
|
||||
value = torch.concat(value)
|
||||
else:
|
||||
value = torch.tensor([], dtype=torch.int32)
|
||||
|
||||
last_node_global = last_node
|
||||
while last_node.evicted:
|
||||
last_node = last_node.parent
|
||||
|
||||
if include_evicted:
|
||||
return value, last_node, last_node_global
|
||||
else:
|
||||
return value, last_node
|
||||
|
||||
def _match_prefix_helper(self, node: TreeNode, key: List):
|
||||
node.last_access_time = time.time()
|
||||
value = []
|
||||
|
||||
@@ -206,6 +206,8 @@ class MHATokenToKVPool(KVCache):
|
||||
self.layer_num = layer_num
|
||||
self._create_buffers()
|
||||
|
||||
self.layer_transfer_counter = None
|
||||
|
||||
k_size, v_size = self.get_kv_size_bytes()
|
||||
logger.info(
|
||||
f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
|
||||
@@ -267,12 +269,28 @@ class MHATokenToKVPool(KVCache):
|
||||
self.k_buffer[i][indices] = k_data[i]
|
||||
self.v_buffer[i][indices] = v_data[i]
|
||||
|
||||
def register_layer_transfer_counter(self, layer_transfer_counter):
|
||||
self.layer_transfer_counter = layer_transfer_counter
|
||||
|
||||
def transfer_per_layer(self, indices, flat_data, layer_id):
|
||||
# transfer prepared data from host to device
|
||||
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
||||
k_data, v_data = flat_data[0], flat_data[1]
|
||||
self.k_buffer[layer_id][indices] = k_data
|
||||
self.v_buffer[layer_id][indices] = v_data
|
||||
|
||||
def get_key_buffer(self, layer_id: int):
|
||||
if self.layer_transfer_counter is not None:
|
||||
self.layer_transfer_counter.wait_until(layer_id)
|
||||
|
||||
if self.store_dtype != self.dtype:
|
||||
return self.k_buffer[layer_id].view(self.dtype)
|
||||
return self.k_buffer[layer_id]
|
||||
|
||||
def get_value_buffer(self, layer_id: int):
|
||||
if self.layer_transfer_counter is not None:
|
||||
self.layer_transfer_counter.wait_until(layer_id)
|
||||
|
||||
if self.store_dtype != self.dtype:
|
||||
return self.v_buffer[layer_id].view(self.dtype)
|
||||
return self.v_buffer[layer_id]
|
||||
@@ -530,6 +548,9 @@ class MHATokenToKVPoolHost:
|
||||
def get_flat_data(self, indices):
|
||||
return self.kv_buffer[:, :, indices]
|
||||
|
||||
def get_flat_data_by_layer(self, indices, layer_id):
|
||||
return self.kv_buffer[:, layer_id, indices]
|
||||
|
||||
def assign_flat_data(self, indices, flat_data):
|
||||
self.kv_buffer[:, :, indices] = flat_data
|
||||
|
||||
|
||||
Reference in New Issue
Block a user