Upstreaming hicache bug fixes (#7267)

This commit is contained in:
Zhiqiang Xie
2025-06-17 17:44:57 -07:00
committed by GitHub
parent c26d7349d3
commit e56685ac1b
7 changed files with 76 additions and 24 deletions

View File

@@ -30,22 +30,37 @@ logger = logging.getLogger(__name__)
class LayerDoneCounter:
    """Per-layer progress counters shared between a producer and a consumer.

    The object holds ``num_counters`` independent slots, each an integer
    counter guarded by its own ``threading.Condition``. The producer side
    (``reset`` / ``increment``) operates on the slot selected by
    ``producer_index``; the consumer side (``wait_until``) operates on the
    slot selected by ``consumer_index``. The extra slots exist for overlap
    mode, where producer and consumer may be working on different slots.
    """

    def __init__(self, num_layers: int) -> None:
        """Create the counter slots.

        Args:
            num_layers: Initial value of every slot. Until a slot is reset,
                ``wait_until`` on it returns immediately for any threshold
                below ``num_layers``.
        """
        self.num_layers = num_layers
        # extra producer and consumer counters for overlap mode
        self.num_counters = 3
        self.counters = [num_layers] * self.num_counters
        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
        self.producer_index = 0
        self.consumer_index = 0

    def next_producer(self) -> int:
        """Return the slot index that follows ``producer_index``, wrapping around."""
        return (self.producer_index + 1) % self.num_counters

    def update_producer(self) -> int:
        """Advance ``producer_index`` to the next slot and return the new index."""
        self.producer_index = self.next_producer()
        return self.producer_index

    def set_consumer(self, index: int) -> None:
        """Select the slot that subsequent ``wait_until`` calls observe."""
        self.consumer_index = index

    def increment(self) -> None:
        """Add one to the producer slot's counter and wake all its waiters."""
        # Read the index once so the lock we hold, the counter we bump and the
        # condition we notify are guaranteed to be the same slot.
        slot = self.producer_index
        condition = self.conditions[slot]
        with condition:
            self.counters[slot] += 1
            condition.notify_all()

    def wait_until(self, threshold: int) -> None:
        """Block until the consumer slot's counter is strictly greater than ``threshold``."""
        # Pin the slot for the whole wait: calling wait() on a condition other
        # than the one acquired would raise RuntimeError.
        slot = self.consumer_index
        condition = self.conditions[slot]
        with condition:
            while self.counters[slot] <= threshold:
                condition.wait()

    def reset(self) -> None:
        """Set the producer slot's counter back to zero."""
        slot = self.producer_index
        with self.conditions[slot]:
            self.counters[slot] = 0
class CacheOperation:
@@ -296,7 +311,6 @@ class HiCacheController:
while not self.stop_event.is_set():
try:
operation = self.load_queue.get(block=True, timeout=1)
# time.sleep(18e-6 * len(operation.host_indices))
operation.data = self.mem_pool_host.get_flat_data(
operation.host_indices
)
@@ -320,6 +334,7 @@ class HiCacheController:
if not self.load_cache_event.is_set():
continue
self.load_cache_event.clear()
self.layer_done_counter.update_producer()
batch_operation = None
while self.load_queue.qsize() > 0:
@@ -331,6 +346,7 @@ class HiCacheController:
if batch_operation is None:
continue
# start layer-wise KV cache transfer from CPU to GPU
self.layer_done_counter.reset()
for i in range(self.mem_pool_host.layer_num):
if self.page_size == 1:
@@ -466,6 +482,7 @@ class HiCacheController:
except Exception as e:
logger.error(e)
# TODO(zhiqiang): double buffering is to be deprecated
def write_thread_func_buffer(self):
aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
aux_thread.start()