Upstreaming hicache bug fixes (#7267)
This commit is contained in:
@@ -30,22 +30,37 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class LayerDoneCounter:
|
||||
def __init__(self, num_layers):
|
||||
self.counter = num_layers
|
||||
self.condition = threading.Condition()
|
||||
self.num_layers = num_layers
|
||||
# extra producer and consumer counters for overlap mode
|
||||
self.num_counters = 3
|
||||
self.counters = [num_layers] * self.num_counters
|
||||
self.conditions = [threading.Condition() for _ in range(self.num_counters)]
|
||||
self.producer_index = 0
|
||||
self.consumer_index = 0
|
||||
|
||||
def next_producer(self):
|
||||
return (self.producer_index + 1) % self.num_counters
|
||||
|
||||
def update_producer(self):
|
||||
self.producer_index = self.next_producer()
|
||||
return self.producer_index
|
||||
|
||||
def set_consumer(self, index):
|
||||
self.consumer_index = index
|
||||
|
||||
def increment(self):
|
||||
with self.condition:
|
||||
self.counter += 1
|
||||
self.condition.notify_all()
|
||||
with self.conditions[self.producer_index]:
|
||||
self.counters[self.producer_index] += 1
|
||||
self.conditions[self.producer_index].notify_all()
|
||||
|
||||
def wait_until(self, threshold):
|
||||
with self.condition:
|
||||
while self.counter <= threshold:
|
||||
self.condition.wait()
|
||||
with self.conditions[self.consumer_index]:
|
||||
while self.counters[self.consumer_index] <= threshold:
|
||||
self.conditions[self.consumer_index].wait()
|
||||
|
||||
def reset(self):
|
||||
with self.condition:
|
||||
self.counter = 0
|
||||
with self.conditions[self.producer_index]:
|
||||
self.counters[self.producer_index] = 0
|
||||
|
||||
|
||||
class CacheOperation:
|
||||
@@ -296,7 +311,6 @@ class HiCacheController:
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.load_queue.get(block=True, timeout=1)
|
||||
# time.sleep(18e-6 * len(operation.host_indices))
|
||||
operation.data = self.mem_pool_host.get_flat_data(
|
||||
operation.host_indices
|
||||
)
|
||||
@@ -320,6 +334,7 @@ class HiCacheController:
|
||||
if not self.load_cache_event.is_set():
|
||||
continue
|
||||
self.load_cache_event.clear()
|
||||
self.layer_done_counter.update_producer()
|
||||
|
||||
batch_operation = None
|
||||
while self.load_queue.qsize() > 0:
|
||||
@@ -331,6 +346,7 @@ class HiCacheController:
|
||||
if batch_operation is None:
|
||||
continue
|
||||
|
||||
# start layer-wise KV cache transfer from CPU to GPU
|
||||
self.layer_done_counter.reset()
|
||||
for i in range(self.mem_pool_host.layer_num):
|
||||
if self.page_size == 1:
|
||||
@@ -466,6 +482,7 @@ class HiCacheController:
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
# todo (zhiqiang): double buffering to be deprecated
|
||||
def write_thread_func_buffer(self):
|
||||
aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
|
||||
aux_thread.start()
|
||||
|
||||
Reference in New Issue
Block a user