diff --git a/benchmark/hicache/bench_multiturn.py b/benchmark/hicache/bench_multiturn.py index a2a88b634..5b8d706a3 100644 --- a/benchmark/hicache/bench_multiturn.py +++ b/benchmark/hicache/bench_multiturn.py @@ -239,7 +239,7 @@ class WorkloadGenerator: tokenizer=self.tokenizer, dataset_path=args.dataset_path, ) - self.candidate_inputs = [i[0] for i in self.candidate_inputs] + self.candidate_inputs = [i.prompt for i in self.candidate_inputs] init_requests = [ (i, gen_payload(self.candidate_inputs[i], args.output_length)) diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 9f7f48cda..0fd102b6b 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -30,22 +30,37 @@ logger = logging.getLogger(__name__) class LayerDoneCounter: def __init__(self, num_layers): - self.counter = num_layers - self.condition = threading.Condition() + self.num_layers = num_layers + # extra producer and consumer counters for overlap mode + self.num_counters = 3 + self.counters = [num_layers] * self.num_counters + self.conditions = [threading.Condition() for _ in range(self.num_counters)] + self.producer_index = 0 + self.consumer_index = 0 + + def next_producer(self): + return (self.producer_index + 1) % self.num_counters + + def update_producer(self): + self.producer_index = self.next_producer() + return self.producer_index + + def set_consumer(self, index): + self.consumer_index = index def increment(self): - with self.condition: - self.counter += 1 - self.condition.notify_all() + with self.conditions[self.producer_index]: + self.counters[self.producer_index] += 1 + self.conditions[self.producer_index].notify_all() def wait_until(self, threshold): - with self.condition: - while self.counter <= threshold: - self.condition.wait() + with self.conditions[self.consumer_index]: + while self.counters[self.consumer_index] <= threshold: + self.conditions[self.consumer_index].wait() def reset(self): - with self.condition: - self.counter = 0 + with self.conditions[self.producer_index]: + self.counters[self.producer_index] = 0 class CacheOperation: @@ -296,7 +311,6 @@ class HiCacheController: while not self.stop_event.is_set(): try: operation = self.load_queue.get(block=True, timeout=1) - # time.sleep(18e-6 * len(operation.host_indices)) operation.data = self.mem_pool_host.get_flat_data( operation.host_indices ) @@ -320,6 +334,7 @@ class HiCacheController: if not self.load_cache_event.is_set(): continue self.load_cache_event.clear() + self.layer_done_counter.update_producer() batch_operation = None while self.load_queue.qsize() > 0: @@ -331,6 +346,7 @@ class HiCacheController: if batch_operation is None: continue + # start layer-wise KV cache transfer from CPU to GPU self.layer_done_counter.reset() for i in range(self.mem_pool_host.layer_num): if self.page_size == 1: @@ -466,6 +482,7 @@ class HiCacheController: except Exception as e: logger.error(e) + # todo (zhiqiang): double buffering to be deprecated def write_thread_func_buffer(self): aux_thread = threading.Thread(target=self.write_aux_func, daemon=True) aux_thread.start() diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 670293a5f..28e1e33b8 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -659,14 +659,6 @@ class Req: self.prefix_indices, self.last_node = tree_cache.match_prefix( rid=self.rid, key=self.adjust_max_prefix_ids() ) - elif enable_hierarchical_cache: - # in case last_node is evicted during scheduling, we need to update the prefix_indices - while self.last_node.evicted: - self.prefix_indices = self.prefix_indices[ - : -len(self.last_node.host_value) - ] - self.last_node = self.last_node.parent - self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) def adjust_max_prefix_ids(self): @@ -909,6 +901,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Whether to return hidden states return_hidden_states: bool = False + # hicache pointer for synchronizing data loading from CPU to GPU + hicache_consumer_index: int = 0 + @classmethod def init_new( cls, @@ -1735,6 +1730,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): token_type_ids=self.token_type_ids, spec_algorithm=self.spec_algorithm, spec_info=self.spec_info, + hicache_consumer_index=self.hicache_consumer_index, capture_hidden_mode=( CaptureHiddenMode.FULL if self.return_hidden_states @@ -1839,6 +1835,7 @@ class ModelWorkerBatch: # If set, the output of the batch contains the hidden states of the run. capture_hidden_mode: CaptureHiddenMode = None spec_num_draft_tokens: Optional[int] = None + hicache_consumer_index: int = 0 # Overlap event launch_done: Optional[threading.Event] = None diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 5555747cd..57627e55c 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -565,6 +565,10 @@ class Scheduler( hicache_size=server_args.hicache_size, hicache_write_policy=server_args.hicache_write_policy, ) + self.tp_worker.register_hicache_layer_transfer_counter( + self.tree_cache.cache_controller.layer_done_counter + ) + else: self.tree_cache = RadixCache( req_to_token_pool=self.req_to_token_pool, @@ -1514,8 +1518,13 @@ class Scheduler( self.running_batch.batch_is_full = True break + # bypass prefix_computed if enable_hierarchical_cache req.init_next_round_input( - None if prefix_computed else self.tree_cache, + ( + None + if (prefix_computed and not self.enable_hierarchical_cache) + else self.tree_cache + ), self.enable_hierarchical_cache, ) @@ -1548,9 +1557,6 @@ class Scheduler( x for x in self.waiting_queue if x not in set(can_run_list) ] - if self.enable_hierarchical_cache: - self.tree_cache.ready_to_load_cache() - if adder.new_chunked_req is not None: assert self.chunked_req is None self.chunked_req = adder.new_chunked_req @@ -1574,6 +1580,10 @@ class Scheduler( self.server_args.enable_custom_logit_processor, chunked_req=self.chunked_req, ) + if self.enable_hierarchical_cache: + # todo (zhiqiang): disable cuda graph execution if hicache loading triggered + new_batch.hicache_consumer_index = self.tree_cache.ready_to_load_cache() + new_batch.prepare_for_extend() # Mixed-style chunked prefill @@ -1649,6 +1659,11 @@ class Scheduler( if self.is_generation: if self.spec_algorithm.is_none(): model_worker_batch = batch.get_model_worker_batch() + + # update the consumer index of hicache to the running batch + self.tp_worker.set_hicache_consumer( + model_worker_batch.hicache_consumer_index + ) if self.pp_group.is_last_rank: logits_output, next_token_ids, can_run_cuda_graph = ( self.tp_worker.forward_batch_generation(model_worker_batch) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 786a34a1e..88bbde1b6 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -147,6 +147,15 @@ class TpModelWorker: # A reference make this class has the same member as TpModelWorkerClient self.worker = self + self.hicache_layer_transfer_counter = None + + def register_hicache_layer_transfer_counter(self, counter): + self.hicache_layer_transfer_counter = counter + + def set_hicache_consumer(self, consumer_index): + if self.hicache_layer_transfer_counter is not None: + self.hicache_layer_transfer_counter.set_consumer(consumer_index) + def get_worker_info(self): return ( self.max_total_num_tokens, diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 783d864ea..45f220db6 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -88,6 +88,15 @@ class TpModelWorkerClient: if self.device == "cpu": self.scheduler_stream.synchronize = lambda: None # No-op for CPU + self.hicache_layer_transfer_counter = None + + def register_hicache_layer_transfer_counter(self, counter): + self.hicache_layer_transfer_counter = counter + + def set_hicache_consumer(self, consumer_index): + if self.hicache_layer_transfer_counter is not None: + self.hicache_layer_transfer_counter.set_consumer(consumer_index) + def get_worker_info(self): return self.worker.get_worker_info() @@ -146,6 +155,8 @@ class TpModelWorkerClient: input_ids = model_worker_batch.input_ids resolve_future_token_ids(input_ids, self.future_token_ids_map) + # update the consumer index of hicache to the running batch + self.set_hicache_consumer(model_worker_batch.hicache_consumer_index) # Run forward logits_output, next_token_ids, can_run_cuda_graph = ( self.worker.forward_batch_generation( diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 03a4417cf..4d6c0ae11 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -307,7 +307,9 @@ class HiRadixCache(RadixCache): return last_node, prefix_indices def ready_to_load_cache(self): + producer_index = self.cache_controller.layer_done_counter.next_producer() self.load_cache_event.set() + return producer_index def match_prefix(self, key: List[int], include_evicted=False, **kwargs): empty_value = torch.empty((0,), dtype=torch.int64, device=self.device) @@ -372,6 +374,7 @@ class HiRadixCache(RadixCache): new_node.lock_ref = child.lock_ref new_node.key = child.key[:split_len] new_node.loading = child.loading + new_node.hit_count = child.hit_count # split value and host value if exists if child.evicted: