Hot fix for hicache with new page aligned radixtree (#4397)
This commit is contained in:
@@ -248,6 +248,8 @@ class HiCacheController:
|
||||
if device_indices is None:
|
||||
return None
|
||||
self.mem_pool_host.protect_load(host_indices)
|
||||
# to ensure the device indices are ready before accessed by another CUDA stream
|
||||
torch.cuda.current_stream().synchronize()
|
||||
self.load_queue.put(
|
||||
CacheOperation(host_indices, device_indices, node_id, priority)
|
||||
)
|
||||
|
||||
@@ -434,6 +434,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
tp_cache_group=self.tp_worker.get_tp_cpu_group(),
|
||||
page_size=self.page_size,
|
||||
)
|
||||
else:
|
||||
self.tree_cache = RadixCache(
|
||||
|
||||
@@ -25,11 +25,17 @@ class HiRadixCache(RadixCache):
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
tp_cache_group: torch.distributed.ProcessGroup,
|
||||
page_size: int,
|
||||
):
|
||||
if page_size != 1:
|
||||
raise ValueError(
|
||||
"Page size larger than 1 is not yet supported in HiRadixCache."
|
||||
)
|
||||
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
|
||||
token_to_kv_pool_allocator.get_kvcache()
|
||||
)
|
||||
self.tp_group = tp_cache_group
|
||||
self.page_size = page_size
|
||||
|
||||
self.load_cache_event = threading.Event()
|
||||
self.cache_controller = HiCacheController(
|
||||
@@ -45,7 +51,9 @@ class HiRadixCache(RadixCache):
|
||||
# todo: dynamically adjust the threshold
|
||||
self.write_through_threshold = 1
|
||||
self.load_back_threshold = 10
|
||||
super().__init__(req_to_token_pool, token_to_kv_pool_allocator, disable=False)
|
||||
super().__init__(
|
||||
req_to_token_pool, token_to_kv_pool_allocator, self.page_size, disable=False
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
TreeNode.counter = 0
|
||||
|
||||
Reference in New Issue
Block a user