Upstreaming hicache bug fixes (#7267)
This commit is contained in:
@@ -30,22 +30,37 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class LayerDoneCounter:
    """Synchronizes layer-wise KV-cache transfers between a producer (the
    hicache load thread) and consumers (forward passes that must wait for a
    layer's KV data before executing it).

    Three independent counter/condition slots are kept so that, in overlap
    mode, one batch can still be consuming a finished transfer while the next
    batch's transfer is being produced; each transfer gets its own slot, so a
    new producer never clobbers the counter a consumer is waiting on.
    """

    def __init__(self, num_layers):
        # Number of layers a complete transfer reports via increment().
        self.num_layers = num_layers
        # Extra producer and consumer slots for overlap mode.
        self.num_counters = 3
        # Slots start "full" (== num_layers) so a consumer never blocks on a
        # slot that has not had a transfer scheduled; reset() rearms a slot
        # to 0 right before its transfer starts.
        self.counters = [num_layers] * self.num_counters
        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
        self.producer_index = 0
        self.consumer_index = 0

    def next_producer(self):
        """Return the slot index the next transfer will use, without advancing."""
        return (self.producer_index + 1) % self.num_counters

    def update_producer(self):
        """Advance to the next producer slot and return its index."""
        self.producer_index = self.next_producer()
        return self.producer_index

    def set_consumer(self, index):
        """Point consumers at the slot whose transfer they should wait on."""
        self.consumer_index = index

    def increment(self):
        """Record one more finished layer on the current producer slot."""
        with self.conditions[self.producer_index]:
            self.counters[self.producer_index] += 1
            self.conditions[self.producer_index].notify_all()

    def wait_until(self, threshold):
        """Block until the consumer slot has completed more than *threshold* layers."""
        with self.conditions[self.consumer_index]:
            while self.counters[self.consumer_index] <= threshold:
                self.conditions[self.consumer_index].wait()

    def reset(self):
        """Rearm the current producer slot to 0 for a new transfer."""
        with self.conditions[self.producer_index]:
            self.counters[self.producer_index] = 0
|
||||
|
||||
|
||||
class CacheOperation:
|
||||
@@ -296,7 +311,6 @@ class HiCacheController:
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.load_queue.get(block=True, timeout=1)
|
||||
# time.sleep(18e-6 * len(operation.host_indices))
|
||||
operation.data = self.mem_pool_host.get_flat_data(
|
||||
operation.host_indices
|
||||
)
|
||||
@@ -320,6 +334,7 @@ class HiCacheController:
|
||||
if not self.load_cache_event.is_set():
|
||||
continue
|
||||
self.load_cache_event.clear()
|
||||
self.layer_done_counter.update_producer()
|
||||
|
||||
batch_operation = None
|
||||
while self.load_queue.qsize() > 0:
|
||||
@@ -331,6 +346,7 @@ class HiCacheController:
|
||||
if batch_operation is None:
|
||||
continue
|
||||
|
||||
# start layer-wise KV cache transfer from CPU to GPU
|
||||
self.layer_done_counter.reset()
|
||||
for i in range(self.mem_pool_host.layer_num):
|
||||
if self.page_size == 1:
|
||||
@@ -466,6 +482,7 @@ class HiCacheController:
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
# todo (zhiqiang): double buffering to be deprecated
|
||||
def write_thread_func_buffer(self):
|
||||
aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
|
||||
aux_thread.start()
|
||||
|
||||
@@ -659,14 +659,6 @@ class Req:
|
||||
self.prefix_indices, self.last_node = tree_cache.match_prefix(
|
||||
rid=self.rid, key=self.adjust_max_prefix_ids()
|
||||
)
|
||||
elif enable_hierarchical_cache:
|
||||
# in case last_node is evicted during scheduling, we need to update the prefix_indices
|
||||
while self.last_node.evicted:
|
||||
self.prefix_indices = self.prefix_indices[
|
||||
: -len(self.last_node.host_value)
|
||||
]
|
||||
self.last_node = self.last_node.parent
|
||||
|
||||
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
|
||||
|
||||
def adjust_max_prefix_ids(self):
|
||||
@@ -909,6 +901,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
# Whether to return hidden states
|
||||
return_hidden_states: bool = False
|
||||
|
||||
# hicache pointer for synchronizing data loading from CPU to GPU
|
||||
hicache_consumer_index: int = 0
|
||||
|
||||
@classmethod
|
||||
def init_new(
|
||||
cls,
|
||||
@@ -1735,6 +1730,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
token_type_ids=self.token_type_ids,
|
||||
spec_algorithm=self.spec_algorithm,
|
||||
spec_info=self.spec_info,
|
||||
hicache_consumer_index=self.hicache_consumer_index,
|
||||
capture_hidden_mode=(
|
||||
CaptureHiddenMode.FULL
|
||||
if self.return_hidden_states
|
||||
@@ -1839,6 +1835,7 @@ class ModelWorkerBatch:
|
||||
# If set, the output of the batch contains the hidden states of the run.
|
||||
capture_hidden_mode: CaptureHiddenMode = None
|
||||
spec_num_draft_tokens: Optional[int] = None
|
||||
hicache_consumer_index: int = 0
|
||||
|
||||
# Overlap event
|
||||
launch_done: Optional[threading.Event] = None
|
||||
|
||||
@@ -565,6 +565,10 @@ class Scheduler(
|
||||
hicache_size=server_args.hicache_size,
|
||||
hicache_write_policy=server_args.hicache_write_policy,
|
||||
)
|
||||
self.tp_worker.register_hicache_layer_transfer_counter(
|
||||
self.tree_cache.cache_controller.layer_done_counter
|
||||
)
|
||||
|
||||
else:
|
||||
self.tree_cache = RadixCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
@@ -1514,8 +1518,13 @@ class Scheduler(
|
||||
self.running_batch.batch_is_full = True
|
||||
break
|
||||
|
||||
# bypass prefix_computed if enable_hierarchical_cache
|
||||
req.init_next_round_input(
|
||||
None if prefix_computed else self.tree_cache,
|
||||
(
|
||||
None
|
||||
if (prefix_computed and not self.enable_hierarchical_cache)
|
||||
else self.tree_cache
|
||||
),
|
||||
self.enable_hierarchical_cache,
|
||||
)
|
||||
|
||||
@@ -1548,9 +1557,6 @@ class Scheduler(
|
||||
x for x in self.waiting_queue if x not in set(can_run_list)
|
||||
]
|
||||
|
||||
if self.enable_hierarchical_cache:
|
||||
self.tree_cache.ready_to_load_cache()
|
||||
|
||||
if adder.new_chunked_req is not None:
|
||||
assert self.chunked_req is None
|
||||
self.chunked_req = adder.new_chunked_req
|
||||
@@ -1574,6 +1580,10 @@ class Scheduler(
|
||||
self.server_args.enable_custom_logit_processor,
|
||||
chunked_req=self.chunked_req,
|
||||
)
|
||||
if self.enable_hierarchical_cache:
|
||||
# todo (zhiqiang): disable cuda graph execution if hicache loading triggered
|
||||
new_batch.hicache_consumer_index = self.tree_cache.ready_to_load_cache()
|
||||
|
||||
new_batch.prepare_for_extend()
|
||||
|
||||
# Mixed-style chunked prefill
|
||||
@@ -1649,6 +1659,11 @@ class Scheduler(
|
||||
if self.is_generation:
|
||||
if self.spec_algorithm.is_none():
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
|
||||
# update the consumer index of hicache to the running batch
|
||||
self.tp_worker.set_hicache_consumer(
|
||||
model_worker_batch.hicache_consumer_index
|
||||
)
|
||||
if self.pp_group.is_last_rank:
|
||||
logits_output, next_token_ids, can_run_cuda_graph = (
|
||||
self.tp_worker.forward_batch_generation(model_worker_batch)
|
||||
|
||||
@@ -147,6 +147,15 @@ class TpModelWorker:
|
||||
# A reference make this class has the same member as TpModelWorkerClient
|
||||
self.worker = self
|
||||
|
||||
self.hicache_layer_transfer_counter = None
|
||||
|
||||
def register_hicache_layer_transfer_counter(self, counter):
    """Attach the hicache LayerDoneCounter used to synchronize layer-wise
    KV-cache loading with forward execution on this worker."""
    self.hicache_layer_transfer_counter = counter
|
||||
|
||||
def set_hicache_consumer(self, consumer_index):
    """Point the registered transfer counter at the slot the current batch
    should consume; no-op when no counter is registered (hicache disabled)."""
    if self.hicache_layer_transfer_counter is not None:
        self.hicache_layer_transfer_counter.set_consumer(consumer_index)
|
||||
|
||||
def get_worker_info(self):
|
||||
return (
|
||||
self.max_total_num_tokens,
|
||||
|
||||
@@ -88,6 +88,15 @@ class TpModelWorkerClient:
|
||||
if self.device == "cpu":
|
||||
self.scheduler_stream.synchronize = lambda: None # No-op for CPU
|
||||
|
||||
self.hicache_layer_transfer_counter = None
|
||||
|
||||
def register_hicache_layer_transfer_counter(self, counter):
    """Attach the hicache LayerDoneCounter used to synchronize layer-wise
    KV-cache loading with forward execution (overlap-mode worker client)."""
    self.hicache_layer_transfer_counter = counter
|
||||
|
||||
def set_hicache_consumer(self, consumer_index):
    """Point the registered transfer counter at the slot the current batch
    should consume; no-op when no counter is registered (hicache disabled)."""
    if self.hicache_layer_transfer_counter is not None:
        self.hicache_layer_transfer_counter.set_consumer(consumer_index)
|
||||
|
||||
def get_worker_info(self):
    """Delegate to the wrapped worker (this client mirrors its interface)."""
    return self.worker.get_worker_info()
|
||||
|
||||
@@ -146,6 +155,8 @@ class TpModelWorkerClient:
|
||||
input_ids = model_worker_batch.input_ids
|
||||
resolve_future_token_ids(input_ids, self.future_token_ids_map)
|
||||
|
||||
# update the consumer index of hicache to the running batch
|
||||
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
|
||||
# Run forward
|
||||
logits_output, next_token_ids, can_run_cuda_graph = (
|
||||
self.worker.forward_batch_generation(
|
||||
|
||||
@@ -307,7 +307,9 @@ class HiRadixCache(RadixCache):
|
||||
return last_node, prefix_indices
|
||||
|
||||
def ready_to_load_cache(self):
    """Arm the cache-controller load thread and return the producer slot
    index the upcoming transfer will use.

    next_producer() is a pure peek: the load thread advances to exactly this
    slot via update_producer() when it wakes on load_cache_event, so the
    returned index can be stored on the batch as its hicache consumer index.
    """
    producer_index = self.cache_controller.layer_done_counter.next_producer()
    self.load_cache_event.set()
    return producer_index
|
||||
|
||||
def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
|
||||
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||
@@ -372,6 +374,7 @@ class HiRadixCache(RadixCache):
|
||||
new_node.lock_ref = child.lock_ref
|
||||
new_node.key = child.key[:split_len]
|
||||
new_node.loading = child.loading
|
||||
new_node.hit_count = child.hit_count
|
||||
|
||||
# split value and host value if exists
|
||||
if child.evicted:
|
||||
|
||||
Reference in New Issue
Block a user