From 10b544ae9b426c0b081cf06e5fcd1f24f82d7443 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Wed, 12 Mar 2025 11:22:35 -0700 Subject: [PATCH] Hierarchical Caching Refactoring and Fixing TP issue (#4082) --- .../sglang/srt/managers/cache_controller.py | 65 ++++++++++++++++++- python/sglang/srt/managers/schedule_batch.py | 20 ++++-- python/sglang/srt/managers/schedule_policy.py | 35 ++++++++-- python/sglang/srt/managers/scheduler.py | 47 ++++++-------- python/sglang/srt/mem_cache/hiradix_cache.py | 62 +++++++++++++----- python/sglang/srt/mem_cache/memory_pool.py | 21 ++++++ 6 files changed, 194 insertions(+), 56 deletions(-) diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 003836d81..703c84369 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -30,6 +30,26 @@ from sglang.srt.mem_cache.memory_pool import ( logger = logging.getLogger(__name__) +class LayerDoneCounter: + def __init__(self, num_layers): + self.counter = num_layers + self.condition = threading.Condition() + + def increment(self): + with self.condition: + self.counter += 1 + self.condition.notify_all() + + def wait_until(self, threshold): + with self.condition: + while self.counter <= threshold: + self.condition.wait() + + def reset(self): + with self.condition: + self.counter = 0 + + class CacheOperation: counter = 0 @@ -132,6 +152,7 @@ class HiCacheController: self, token_to_kv_pool_allocator: TokenToKVPoolAllocator, mem_pool_host: MHATokenToKVPoolHost, + load_cache_event: threading.Event = None, write_policy: str = "write_through_selective", ): self.mem_pool_device_allocator = token_to_kv_pool_allocator @@ -139,6 +160,10 @@ class HiCacheController: self.mem_pool_host = mem_pool_host self.write_policy = write_policy + self.load_cache_event = load_cache_event + self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num) + 
self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter) + if write_policy not in [ "write_through", "write_through_selective", @@ -165,7 +190,7 @@ class HiCacheController: target=self.write_thread_func_buffer, daemon=True ) self.load_thread = threading.Thread( - target=self.load_thread_func_buffer, daemon=True + target=self.load_thread_func_layer_by_layer, daemon=True ) self.write_thread.start() self.load_thread.start() @@ -186,7 +211,7 @@ class HiCacheController: target=self.write_thread_func_buffer, daemon=True ) self.load_thread = threading.Thread( - target=self.load_thread_func_buffer, daemon=True + target=self.load_thread_func_layer_by_layer, daemon=True ) self.stop_event.clear() self.write_thread.start() @@ -273,6 +298,42 @@ class HiCacheController: except Exception as e: logger.error(e) + def load_thread_func_layer_by_layer(self): + """ + Load KV caches from host memory to device memory layer by layer. + """ + with torch.cuda.stream(self.load_stream): + while not self.stop_event.is_set(): + self.load_cache_event.wait(timeout=1) + if not self.load_cache_event.is_set(): + continue + self.load_cache_event.clear() + + batch_operation = None + while self.load_queue.qsize() > 0: + op = self.load_queue.get(block=True) + if batch_operation is None: + batch_operation = op + else: + batch_operation.merge(op) + if batch_operation is None: + continue + + self.layer_done_counter.reset() + for i in range(self.mem_pool_host.layer_num): + flat_data = self.mem_pool_host.get_flat_data_by_layer( + batch_operation.host_indices, i + ) + self.mem_pool_device.transfer_per_layer( + batch_operation.device_indices, flat_data, i + ) + self.layer_done_counter.increment() + + self.mem_pool_host.complete_io(batch_operation.host_indices) + for node_id in batch_operation.node_ids: + if node_id != 0: + self.ack_load_queue.put(node_id) + def write_aux_func(self, no_wait=False): """ Auxiliary function to prepare the buffer for write operations. 
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 219b9d145..0ac870767 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -315,6 +315,7 @@ class Req: # The relative logprob_start_len in an extend batch self.extend_logprob_start_len = 0 self.last_node = None + self.last_node_global = None # Whether or not if it is chunked. It increments whenever # it is chunked, and decrement whenever chunked request is @@ -389,13 +390,24 @@ class Req: # Whether request reached finished condition return self.finished_reason is not None - def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None): + def init_next_round_input( + self, + tree_cache: Optional[BasePrefixCache] = None, + enable_hierarchical_cache=False, + ): self.fill_ids = self.origin_input_ids + self.output_ids if tree_cache is not None: # tree cache is None if the prefix is not computed with tree cache. 
- self.prefix_indices, self.last_node = tree_cache.match_prefix( - rid=self.rid, key=self.adjust_max_prefix_ids() - ) + if enable_hierarchical_cache: + self.prefix_indices, self.last_node, self.last_node_global = ( + tree_cache.match_prefix( + key=self.adjust_max_prefix_ids(), include_evicted=True + ) + ) + else: + self.prefix_indices, self.last_node = tree_cache.match_prefix( + rid=self.rid, key=self.adjust_max_prefix_ids() + ) self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) def adjust_max_prefix_ids(self): diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index de43c98f9..3f569088b 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -73,9 +73,15 @@ class CacheAgnosticPolicy(Enum): class SchedulePolicy: Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy] - def __init__(self, policy: str, tree_cache: BasePrefixCache): + def __init__( + self, + policy: str, + tree_cache: BasePrefixCache, + enable_hierarchical_cache: bool = False, + ): self.policy = self._validate_and_adjust_policy(policy, tree_cache) self.tree_cache = tree_cache + self.enable_hierarchical_cache = enable_hierarchical_cache # It is used to find the matching prefix for in-batch prefix caching. 
self.waiting_queue_radix_tree = RadixCache( @@ -149,9 +155,14 @@ class SchedulePolicy: prefix_ids = r.adjust_max_prefix_ids() # NOTE: the prefix_indices must always be aligned with last_node - r.prefix_indices, r.last_node = self.tree_cache.match_prefix( - rid=r.rid, key=prefix_ids - ) + if self.enable_hierarchical_cache: + r.prefix_indices, r.last_node, r.last_node_global = ( + self.tree_cache.match_prefix(key=prefix_ids, include_evicted=True) + ) + else: + r.prefix_indices, r.last_node = self.tree_cache.match_prefix( + rid=r.rid, key=prefix_ids + ) # NOTE(sang): This logic is for in-batch prefix caching; # If there are more than 1 request that have small matching prefix from @@ -428,7 +439,9 @@ class PrefillAdder: return self.budget_state() - def add_one_req(self, req: Req, has_chunked_req: bool): + def add_one_req( + self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False + ): if req.sampling_params.ignore_eos and self.tree_cache.disable: return self.add_one_req_ignore_eos(req, has_chunked_req) @@ -448,6 +461,18 @@ class PrefillAdder: if total_tokens > self.rem_total_tokens: return AddReqResult.NO_TOKEN + if ( + enable_hierarchical_cache + and req.last_node_global is not None + and req.last_node_global.evicted + ): + req.last_node, req.prefix_indices = self.tree_cache.init_load_back( + req.last_node_global, req.prefix_indices + ) + req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) + input_tokens = req.extend_input_len + prefix_len = len(req.prefix_indices) + if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens: # Non-chunked prefill self.can_run_list.append(req) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index a5c6a1dbd..af0bb825f 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -265,12 +265,10 @@ class Scheduler: f"context_len={self.model_config.context_len}" ) - # Init memory pool and cache 
self.init_memory_pool_and_cache() # Init running status self.waiting_queue: List[Req] = [] - self.staging_reqs = {} # The running decoding batch for continuous batching self.running_batch: Optional[ScheduleBatch] = None # The current forward batch @@ -308,7 +306,9 @@ class Scheduler: self.grammar_backend = None # Init schedule policy and new token estimation - self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache) + self.policy = SchedulePolicy( + self.schedule_policy, self.tree_cache, self.enable_hierarchical_cache + ) assert ( server_args.schedule_conservativeness >= 0 ), "Invalid schedule_conservativeness" @@ -431,6 +431,7 @@ class Scheduler: self.tree_cache = HiRadixCache( req_to_token_pool=self.req_to_token_pool, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + tp_cache_group=self.tp_worker.get_tp_cpu_group(), ) else: self.tree_cache = RadixCache( @@ -1005,6 +1006,11 @@ class Scheduler: self.batch_is_full = True return None + if self.enable_hierarchical_cache: + # check for completion of hierarchical cache activities to release memory + self.tree_cache.writing_check() + self.tree_cache.loading_check() + # Get priority queue prefix_computed = self.policy.calc_priority(self.waiting_queue) @@ -1048,32 +1054,14 @@ class Scheduler: self.batch_is_full = True break - req.init_next_round_input(None if prefix_computed else self.tree_cache) + req.init_next_round_input( + None if prefix_computed else self.tree_cache, + self.enable_hierarchical_cache, + ) - if self.enable_hierarchical_cache and req.last_node is not None: - if req.last_node.evicted: - # loading KV cache for the request - req.last_node, req.prefix_indices = self.tree_cache.init_load_back( - req.last_node, - req.prefix_indices, - adder.rem_total_tokens, - ) - if req.last_node.loading: - # to prevent frequent cache invalidation - if req.rid in self.staging_reqs: - self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid]) - self.tree_cache.inc_lock_ref(req.last_node) - 
self.staging_reqs[req.rid] = req.last_node - continue - elif req.last_node.loading: - if not self.tree_cache.loading_complete(req.last_node): - continue - - if req.rid in self.staging_reqs: - self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid]) - del self.staging_reqs[req.rid] - - res = adder.add_one_req(req, self.chunked_req) + res = adder.add_one_req( + req, self.chunked_req, self.enable_hierarchical_cache + ) if res != AddReqResult.CONTINUE: if res == AddReqResult.NO_TOKEN: if self.enable_hierarchical_cache: @@ -1094,6 +1082,9 @@ class Scheduler: x for x in self.waiting_queue if x not in set(can_run_list) ] + if self.enable_hierarchical_cache: + self.tree_cache.read_to_load_cache() + if adder.new_chunked_req is not None: assert self.chunked_req is None self.chunked_req = adder.new_chunked_req diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 28bab2869..f629bb751 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -1,11 +1,13 @@ import heapq import logging +import threading import time from typing import List, Optional import torch from sglang.srt.managers.cache_controller import HiCacheController +from sglang.srt.managers.schedule_batch import Req from sglang.srt.mem_cache.memory_pool import ( MHATokenToKVPoolHost, ReqToTokenPool, @@ -22,12 +24,18 @@ class HiRadixCache(RadixCache): self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool_allocator: TokenToKVPoolAllocator, + tp_cache_group: torch.distributed.ProcessGroup, ): self.token_to_kv_pool_host = MHATokenToKVPoolHost( token_to_kv_pool_allocator.get_kvcache() ) + self.tp_group = tp_cache_group + + self.load_cache_event = threading.Event() self.cache_controller = HiCacheController( - token_to_kv_pool_allocator, self.token_to_kv_pool_host + token_to_kv_pool_allocator, + self.token_to_kv_pool_host, + load_cache_event=self.load_cache_event, ) # record the nodes with ongoing write through 
@@ -55,14 +63,12 @@ class HiRadixCache(RadixCache): def write_backup(self, node: TreeNode): host_indices = self.cache_controller.write( device_indices=node.value, - priority=-self.get_height(node), node_id=node.id, ) if host_indices is None: self.evict_host(len(node.value)) host_indices = self.cache_controller.write( device_indices=node.value, - priority=-self.get_height(node), node_id=node.id, ) if host_indices is not None: @@ -83,14 +89,20 @@ node.hit_count = 0 def writing_check(self): - while not self.cache_controller.ack_write_queue.empty(): - try: - ack_id = self.cache_controller.ack_write_queue.get_nowait() - self.dec_lock_ref(self.ongoing_write_through[ack_id]) - # clear the reference - del self.ongoing_write_through[ack_id] - except Exception: - break + queue_size = torch.tensor( + self.cache_controller.ack_write_queue.qsize(), dtype=torch.int + ) + if torch.distributed.get_world_size(group=self.tp_group) > 1: + # synchronize TP workers to make the same update to radix cache + torch.distributed.all_reduce( + queue_size, + op=torch.distributed.ReduceOp.MIN, + group=self.tp_group, + ) + for _ in range(queue_size.item()): + ack_id = self.cache_controller.ack_write_queue.get() + self.dec_lock_ref(self.ongoing_write_through[ack_id]) + del self.ongoing_write_through[ack_id] def loading_check(self): while not self.cache_controller.ack_load_queue.empty(): @@ -108,8 +120,6 @@ break def evictable_size(self): - self.writing_check() - self.loading_check() return self.evictable_size_ def evict(self, num_tokens: int, evict_callback=None): @@ -242,10 +252,6 @@ return device_indices - def loading_complete(self, node: TreeNode): - self.loading_check() - return node.loading == False - def init_load_back( self, last_node: TreeNode, @@ -272,6 +278,28 @@ return last_node, prefix_indices + def read_to_load_cache(self): + self.load_cache_event.set() + + 
def match_prefix(self, key: List[int], include_evicted=False, **kwargs): + if self.disable: + return [], self.root_node + + value, last_node = self._match_prefix_helper(self.root_node, key) + if value: + value = torch.concat(value) + else: + value = torch.tensor([], dtype=torch.int32) + + last_node_global = last_node + while last_node.evicted: + last_node = last_node.parent + + if include_evicted: + return value, last_node, last_node_global + else: + return value, last_node + def _match_prefix_helper(self, node: TreeNode, key: List): node.last_access_time = time.time() value = [] diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 36a1bd8d6..4dfb72bca 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -206,6 +206,8 @@ class MHATokenToKVPool(KVCache): self.layer_num = layer_num self._create_buffers() + self.layer_transfer_counter = None + k_size, v_size = self.get_kv_size_bytes() logger.info( f"KV Cache is allocated. 
#tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB" @@ -267,12 +269,28 @@ class MHATokenToKVPool(KVCache): self.k_buffer[i][indices] = k_data[i] self.v_buffer[i][indices] = v_data[i] + def register_layer_transfer_counter(self, layer_transfer_counter): + self.layer_transfer_counter = layer_transfer_counter + + def transfer_per_layer(self, indices, flat_data, layer_id): + # transfer prepared data from host to device + flat_data = flat_data.to(device=self.device, non_blocking=False) + k_data, v_data = flat_data[0], flat_data[1] + self.k_buffer[layer_id][indices] = k_data + self.v_buffer[layer_id][indices] = v_data + def get_key_buffer(self, layer_id: int): + if self.layer_transfer_counter is not None: + self.layer_transfer_counter.wait_until(layer_id) + if self.store_dtype != self.dtype: return self.k_buffer[layer_id].view(self.dtype) return self.k_buffer[layer_id] def get_value_buffer(self, layer_id: int): + if self.layer_transfer_counter is not None: + self.layer_transfer_counter.wait_until(layer_id) + if self.store_dtype != self.dtype: return self.v_buffer[layer_id].view(self.dtype) return self.v_buffer[layer_id] @@ -530,6 +548,9 @@ class MHATokenToKVPoolHost: def get_flat_data(self, indices): return self.kv_buffer[:, :, indices] + def get_flat_data_by_layer(self, indices, layer_id): + return self.kv_buffer[:, layer_id, indices] + def assign_flat_data(self, indices, flat_data): self.kv_buffer[:, :, indices] = flat_data