Hierarchical Caching Refactoring and Fixing TP issue (#4082)

This commit is contained in:
Zhiqiang Xie
2025-03-12 11:22:35 -07:00
committed by GitHub
parent 01090e8ac3
commit 10b544ae9b
6 changed files with 194 additions and 56 deletions

View File

@@ -1,11 +1,13 @@
import heapq
import logging
import threading
import time
from typing import List, Optional
import torch
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPoolHost,
ReqToTokenPool,
@@ -22,12 +24,18 @@ class HiRadixCache(RadixCache):
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
tp_cache_group: torch.distributed.ProcessGroup,
):
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
token_to_kv_pool_allocator.get_kvcache()
)
self.tp_group = tp_cache_group
self.load_cache_event = threading.Event()
self.cache_controller = HiCacheController(
token_to_kv_pool_allocator, self.token_to_kv_pool_host
token_to_kv_pool_allocator,
self.token_to_kv_pool_host,
load_cache_event=self.load_cache_event,
)
# record the nodes with ongoing write through
@@ -55,14 +63,12 @@ class HiRadixCache(RadixCache):
def write_backup(self, node: TreeNode):
host_indices = self.cache_controller.write(
device_indices=node.value,
priority=-self.get_height(node),
node_id=node.id,
)
if host_indices is None:
self.evict_host(len(node.value))
host_indices = self.cache_controller.write(
device_indices=node.value,
priority=-self.get_height(node),
node_id=node.id,
)
if host_indices is not None:
@@ -83,14 +89,20 @@ class HiRadixCache(RadixCache):
node.hit_count = 0
def writing_check(self):
    """Apply completed write-through acknowledgements to the radix tree.

    For each acknowledged node id, release the lock reference taken when
    the device->host write-through was issued and drop the bookkeeping
    entry. Under tensor parallelism every rank must mutate the radix cache
    identically, so the number of acks consumed this round is all-reduced
    with MIN across the TP group before draining; any ack not yet visible
    on every rank is left in the queue for a later round.
    """
    # Snapshot the queue size BEFORE draining — consuming acks first
    # (e.g. with get_nowait) would let ranks observe different counts
    # and desynchronize their radix trees.
    queue_size = torch.tensor(
        self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
    )
    if torch.distributed.get_world_size(group=self.tp_group) > 1:
        # Synchronize TP workers so they all make the same update to the
        # radix cache: only acks visible on *all* ranks (MIN) are consumed.
        torch.distributed.all_reduce(
            queue_size,
            op=torch.distributed.ReduceOp.MIN,
            group=self.tp_group,
        )
    for _ in range(queue_size.item()):
        ack_id = self.cache_controller.ack_write_queue.get()
        self.dec_lock_ref(self.ongoing_write_through[ack_id])
        # Clear the reference so the node can be garbage-collected/evicted.
        del self.ongoing_write_through[ack_id]
def loading_check(self):
while not self.cache_controller.ack_load_queue.empty():
@@ -108,8 +120,6 @@ class HiRadixCache(RadixCache):
break
def evictable_size(self):
    # Reconcile pending write-through / load-back acknowledgements before
    # reporting, so evictable_size_ reflects transfers that have completed.
    # NOTE(review): this is a diff view — confirm against the final revision
    # whether these two check calls were kept here or moved to the caller.
    self.writing_check()
    self.loading_check()
    return self.evictable_size_
def evict(self, num_tokens: int, evict_callback=None):
@@ -242,10 +252,6 @@ class HiRadixCache(RadixCache):
return device_indices
def loading_complete(self, node: "TreeNode") -> bool:
    """Return True once ``node``'s host->device load-back has finished.

    Drains any pending load acknowledgements first so the node's
    ``loading`` flag reflects the latest cache-controller state.
    """
    self.loading_check()
    # `loading` is a boolean flag; use `not` rather than comparing to False.
    return not node.loading
def init_load_back(
self,
last_node: TreeNode,
@@ -272,6 +278,28 @@ class HiRadixCache(RadixCache):
return last_node, prefix_indices
def read_to_load_cache(self):
    """Wake the cache controller so it starts loading cached data back."""
    event = self.load_cache_event
    event.set()
def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
    """Match ``key`` against the radix tree and return the cached prefix.

    Returns the concatenated device indices of the matched prefix together
    with the deepest matched node that still resides on device. With
    ``include_evicted=True`` the deepest matched node overall — which may
    have been evicted to host — is returned as a third element.
    """
    if self.disable:
        return [], self.root_node

    fragments, deepest = self._match_prefix_helper(self.root_node, key)
    if fragments:
        indices = torch.concat(fragments)
    else:
        indices = torch.tensor([], dtype=torch.int32)

    deepest_any = deepest
    # Walk up past evicted ancestors to the closest node still on device.
    while deepest.evicted:
        deepest = deepest.parent

    if include_evicted:
        return indices, deepest, deepest_any
    return indices, deepest
def _match_prefix_helper(self, node: TreeNode, key: List):
node.last_access_time = time.time()
value = []

View File

@@ -206,6 +206,8 @@ class MHATokenToKVPool(KVCache):
self.layer_num = layer_num
self._create_buffers()
self.layer_transfer_counter = None
k_size, v_size = self.get_kv_size_bytes()
logger.info(
f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
@@ -267,12 +269,28 @@ class MHATokenToKVPool(KVCache):
self.k_buffer[i][indices] = k_data[i]
self.v_buffer[i][indices] = v_data[i]
def register_layer_transfer_counter(self, layer_transfer_counter):
    # Counter later consulted via wait_until(layer_id) by get_key_buffer /
    # get_value_buffer to block until a layer's host->device transfer lands.
    self.layer_transfer_counter = layer_transfer_counter
def transfer_per_layer(self, indices, flat_data, layer_id):
    """Copy prepared host KV data for one layer into the device buffers.

    ``flat_data`` stacks K at index 0 and V at index 1; rows are scattered
    into ``k_buffer``/``v_buffer`` at ``indices`` for ``layer_id``.
    """
    # Blocking copy: callers rely on the data being resident on return.
    on_device = flat_data.to(device=self.device, non_blocking=False)
    self.k_buffer[layer_id][indices] = on_device[0]
    self.v_buffer[layer_id][indices] = on_device[1]
def get_key_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id)
if self.store_dtype != self.dtype:
return self.k_buffer[layer_id].view(self.dtype)
return self.k_buffer[layer_id]
def get_value_buffer(self, layer_id: int):
    """Return the value buffer for ``layer_id``.

    Waits for any pending host->device transfer of this layer, and
    reinterprets the storage dtype as the logical dtype when they differ.
    """
    counter = self.layer_transfer_counter
    if counter is not None:
        counter.wait_until(layer_id)

    buf = self.v_buffer[layer_id]
    if self.store_dtype == self.dtype:
        return buf
    return buf.view(self.dtype)
@@ -530,6 +548,9 @@ class MHATokenToKVPoolHost:
def get_flat_data(self, indices):
    """Gather host KV entries at ``indices`` across all layers.

    Indexes the last-visible axis of ``kv_buffer``; the leading two axes
    (presumably K/V and layer — confirm against the buffer allocation)
    are returned in full.
    """
    buffer = self.kv_buffer
    return buffer[:, :, indices]
def get_flat_data_by_layer(self, indices, layer_id):
    """Gather host KV entries at ``indices`` for a single layer.

    Selects ``layer_id`` along the second axis of ``kv_buffer`` so the
    per-layer loader can transfer one layer at a time.
    """
    buffer = self.kv_buffer
    return buffer[:, layer_id, indices]
def assign_flat_data(self, indices, flat_data):
    """Scatter ``flat_data`` into the host buffer at ``indices`` (all layers)."""
    buffer = self.kv_buffer
    buffer[:, :, indices] = flat_data