Hierarchical Caching Refactoring and Fixing TP issue (#4082)
@@ -1,11 +1,13 @@
 import heapq
 import logging
+import threading
 import time
 from typing import List, Optional

 import torch

 from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPoolHost,
     ReqToTokenPool,
@@ -22,12 +24,18 @@ class HiRadixCache(RadixCache):
         self,
         req_to_token_pool: ReqToTokenPool,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        tp_cache_group: torch.distributed.ProcessGroup,
     ):
         self.token_to_kv_pool_host = MHATokenToKVPoolHost(
             token_to_kv_pool_allocator.get_kvcache()
         )
+        self.tp_group = tp_cache_group
+
+        self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
-            token_to_kv_pool_allocator, self.token_to_kv_pool_host
+            token_to_kv_pool_allocator,
+            self.token_to_kv_pool_host,
+            load_cache_event=self.load_cache_event,
         )

         # record the nodes with ongoing write through
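The constructor change above threads a `threading.Event` into `HiCacheController`, which `read_to_load_cache()` (added further down in this diff) sets to wake the controller's load thread. A minimal sketch of that producer/consumer handshake, assuming controller internals not shown in this diff:

```python
import threading

load_cache_event = threading.Event()

def load_worker():
    # hypothetical stand-in for the HiCacheController load thread
    while True:
        load_cache_event.wait()   # block until the scheduler signals
        load_cache_event.clear()  # consume the signal before working
        # ... issue host-to-device KV cache copies here ...

threading.Thread(target=load_worker, daemon=True).start()
```

The scheduler side then only needs `load_cache_event.set()`, which is exactly what `read_to_load_cache` does.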
@@ -55,14 +63,12 @@ class HiRadixCache(RadixCache):
     def write_backup(self, node: TreeNode):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
             priority=-self.get_height(node),
             node_id=node.id,
         )
         if host_indices is None:
             self.evict_host(len(node.value))
             host_indices = self.cache_controller.write(
                 device_indices=node.value,
                 priority=-self.get_height(node),
                 node_id=node.id,
             )
         if host_indices is not None:
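`write_backup` treats a `None` return from `cache_controller.write` as a full host pool: it evicts enough host entries to make room, then retries the write once. The same allocate-evict-retry shape in isolation (`try_allocate` and `evict_lru` are hypothetical stand-ins for the pool API):

```python
from typing import List, Optional

def allocate_with_eviction(pool, num_tokens: int) -> Optional[List[int]]:
    # first attempt; returns None when the host pool is exhausted
    indices = pool.try_allocate(num_tokens)
    if indices is None:
        # free at least num_tokens host slots, then retry exactly once
        pool.evict_lru(num_tokens)
        indices = pool.try_allocate(num_tokens)
    return indices  # may still be None if eviction freed too little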
@@ -83,14 +89,20 @@ class HiRadixCache(RadixCache):
             node.hit_count = 0

     def writing_check(self):
-        while not self.cache_controller.ack_write_queue.empty():
-            try:
-                ack_id = self.cache_controller.ack_write_queue.get_nowait()
-                self.dec_lock_ref(self.ongoing_write_through[ack_id])
-                # clear the reference
-                del self.ongoing_write_through[ack_id]
-            except Exception:
-                break
+        queue_size = torch.tensor(
+            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
+        )
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchronize TP workers to make the same update to radix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            ack_id = self.cache_controller.ack_write_queue.get()
+            self.dec_lock_ref(self.ongoing_write_through[ack_id])
+            del self.ongoing_write_through[ack_id]

     def loading_check(self):
         while not self.cache_controller.ack_load_queue.empty():
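This hunk is the TP fix named in the commit title. The old `get_nowait()` drain let each tensor-parallel rank pop however many acks happened to be ready locally, so ranks could release lock references at different times and their radix trees would diverge. The new code first agrees on a count: every rank reduces its ack-queue size with `ReduceOp.MIN`, then pops exactly that many entries, so all ranks apply identical updates. A standalone sketch of the same min-consensus pattern, assuming an initialized `torch.distributed` process group whose backend supports CPU tensors (e.g. gloo); `ack_queue` stands in for `cache_controller.ack_write_queue`:

```python
import torch
import torch.distributed as dist

def drain_in_lockstep(ack_queue, group=None):
    # every rank proposes its local queue size ...
    n = torch.tensor(ack_queue.qsize(), dtype=torch.int)
    if dist.get_world_size(group=group) > 1:
        # ... and all agree on the minimum, i.e. what the slowest rank has
        dist.all_reduce(n, op=dist.ReduceOp.MIN, group=group)
    for _ in range(n.item()):
        # popping n items is now safe and identical on every rank
        yield ack_queue.get()
```

Acks beyond the agreed minimum simply stay in the queue and are drained on a later call, once the slower ranks have caught up.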
@@ -108,8 +120,6 @@ class HiRadixCache(RadixCache):
                 break

     def evictable_size(self):
-        self.writing_check()
-        self.loading_check()
         return self.evictable_size_

     def evict(self, num_tokens: int, evict_callback=None):
@@ -242,10 +252,6 @@ class HiRadixCache(RadixCache):

         return device_indices

-    def loading_complete(self, node: TreeNode):
-        self.loading_check()
-        return node.loading == False
-
     def init_load_back(
         self,
         last_node: TreeNode,
@@ -272,6 +278,28 @@ class HiRadixCache(RadixCache):

         return last_node, prefix_indices

+    def read_to_load_cache(self):
+        self.load_cache_event.set()
+
+    def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
+        if self.disable:
+            return [], self.root_node
+
+        value, last_node = self._match_prefix_helper(self.root_node, key)
+        if value:
+            value = torch.concat(value)
+        else:
+            value = torch.tensor([], dtype=torch.int32)
+
+        last_node_global = last_node
+        while last_node.evicted:
+            last_node = last_node.parent
+
+        if include_evicted:
+            return value, last_node, last_node_global
+        else:
+            return value, last_node
+
     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.time()
         value = []
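The overridden `match_prefix` can now report two frontier nodes: `last_node`, the deepest match still resident on device, and `last_node_global`, the deepest match overall, which may be evicted to host memory. A hypothetical caller-side sketch (`tree_cache` and `token_ids` are placeholders; the second `init_load_back` argument is assumed from the `return last_node, prefix_indices` context line above):

```python
prefix_indices, last_node, last_node_global = tree_cache.match_prefix(
    key=token_ids, include_evicted=True
)
if last_node_global is not last_node:
    # part of the matched prefix lives only in host memory;
    # init_load_back brings it back onto the device
    last_node, prefix_indices = tree_cache.init_load_back(
        last_node_global, prefix_indices
    )
```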