[Refactor] Remove Hicache Load & Write threads (#10127)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -201,41 +201,57 @@ class HiRadixCache(RadixCache):
|
||||
if write_back:
|
||||
# blocking till all write back complete
|
||||
while len(self.ongoing_write_through) > 0:
|
||||
ack_id = self.cache_controller.ack_write_queue.get()
|
||||
del self.ongoing_write_through[ack_id]
|
||||
for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
|
||||
finish_event.synchronize()
|
||||
for ack_id in ack_list:
|
||||
del self.ongoing_write_through[ack_id]
|
||||
self.cache_controller.ack_write_queue.clear()
|
||||
assert len(self.ongoing_write_through) == 0
|
||||
return
|
||||
queue_size = torch.tensor(
|
||||
self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
|
||||
)
|
||||
|
||||
# NOTE: all ranks have the same ongoing_write_through, so the sync can be skipped if it is empty
|
||||
if len(self.ongoing_write_through) == 0:
|
||||
return
|
||||
|
||||
finish_count = 0
|
||||
for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
|
||||
if not finish_event.query():
|
||||
break
|
||||
finish_count += 1
|
||||
queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu")
|
||||
if self.tp_world_size > 1:
|
||||
# synchrnoize TP workers to make the same update to radix cache
|
||||
# synchronize TP workers to make the same update to radix cache
|
||||
torch.distributed.all_reduce(
|
||||
queue_size,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
group=self.tp_group,
|
||||
)
|
||||
for _ in range(queue_size.item()):
|
||||
ack_id = self.cache_controller.ack_write_queue.get()
|
||||
backuped_node = self.ongoing_write_through[ack_id]
|
||||
self.dec_lock_ref(backuped_node)
|
||||
del self.ongoing_write_through[ack_id]
|
||||
if self.enable_storage:
|
||||
self.write_backup_storage(backuped_node)
|
||||
|
||||
finish_count = int(queue_size.item())
|
||||
while finish_count > 0:
|
||||
_, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
|
||||
finish_event.synchronize()
|
||||
for ack_id in ack_list:
|
||||
backuped_node = self.ongoing_write_through.pop(ack_id)
|
||||
self.dec_lock_ref(backuped_node)
|
||||
if self.enable_storage:
|
||||
self.write_backup_storage(backuped_node)
|
||||
finish_count -= 1
|
||||
|
||||
def loading_check(self):
|
||||
while not self.cache_controller.ack_load_queue.empty():
|
||||
try:
|
||||
ack_id = self.cache_controller.ack_load_queue.get_nowait()
|
||||
start_node, end_node = self.ongoing_load_back[ack_id]
|
||||
self.dec_lock_ref(end_node)
|
||||
while end_node != start_node:
|
||||
assert end_node.loading
|
||||
end_node.loading = False
|
||||
end_node = end_node.parent
|
||||
# clear the reference
|
||||
del self.ongoing_load_back[ack_id]
|
||||
except Exception:
|
||||
finish_count = 0
|
||||
for _, finish_event, ack_list in self.cache_controller.ack_load_queue:
|
||||
if not finish_event.query():
|
||||
# the KV cache loading is still ongoing
|
||||
break
|
||||
finish_count += 1
|
||||
# no need to sync across TP workers as batch forwarding is synced
|
||||
for ack_id in ack_list:
|
||||
end_node = self.ongoing_load_back.pop(ack_id)
|
||||
self.dec_lock_ref(end_node)
|
||||
|
||||
# ACK until all events are processed
|
||||
del self.cache_controller.ack_load_queue[:finish_count]
|
||||
|
||||
def evictable_size(self):
|
||||
return self.evictable_size_
|
||||
@@ -360,12 +376,11 @@ class HiRadixCache(RadixCache):
|
||||
# no sufficient GPU memory to load back KV caches
|
||||
return None
|
||||
|
||||
self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node)
|
||||
self.ongoing_load_back[last_hit_node.id] = last_hit_node
|
||||
offset = 0
|
||||
for node in nodes_to_load:
|
||||
node.value = device_indices[offset : offset + len(node.host_value)]
|
||||
offset += len(node.host_value)
|
||||
node.loading = True
|
||||
self.evictable_size_ += len(device_indices)
|
||||
self.inc_lock_ref(last_hit_node)
|
||||
|
||||
@@ -394,10 +409,12 @@ class HiRadixCache(RadixCache):
|
||||
last_node,
|
||||
)
|
||||
|
||||
def ready_to_load_host_cache(self):
|
||||
producer_index = self.cache_controller.layer_done_counter.next_producer()
|
||||
self.load_cache_event.set()
|
||||
return producer_index
|
||||
def ready_to_load_host_cache(self) -> int:
|
||||
"""
|
||||
Notify the cache controller to start the KV cache loading.
|
||||
Return the consumer index for the schedule batch manager to track.
|
||||
"""
|
||||
return self.cache_controller.start_loading()
|
||||
|
||||
def check_hicache_events(self):
|
||||
self.writing_check()
|
||||
@@ -702,7 +719,6 @@ class HiRadixCache(RadixCache):
|
||||
new_node.parent = child.parent
|
||||
new_node.lock_ref = child.lock_ref
|
||||
new_node.key = child.key[:split_len]
|
||||
new_node.loading = child.loading
|
||||
new_node.hit_count = child.hit_count
|
||||
|
||||
# split value and host value if exists
|
||||
|
||||
Reference in New Issue
Block a user