diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 703c84369..f8857fab4 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -248,6 +248,8 @@ class HiCacheController: if device_indices is None: return None self.mem_pool_host.protect_load(host_indices) + # to ensure the device indices are ready before accessed by another CUDA stream + torch.cuda.current_stream().synchronize() self.load_queue.put( CacheOperation(host_indices, device_indices, node_id, priority) ) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index c91905f5f..d5ce3bc71 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -434,6 +434,7 @@ class Scheduler(SchedulerOutputProcessorMixin): req_to_token_pool=self.req_to_token_pool, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, tp_cache_group=self.tp_worker.get_tp_cpu_group(), + page_size=self.page_size, ) else: self.tree_cache = RadixCache( diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 6b4825994..754a88c07 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -25,11 +25,17 @@ class HiRadixCache(RadixCache): req_to_token_pool: ReqToTokenPool, token_to_kv_pool_allocator: TokenToKVPoolAllocator, tp_cache_group: torch.distributed.ProcessGroup, + page_size: int, ): + if page_size != 1: + raise ValueError( + "Page size larger than 1 is not yet supported in HiRadixCache." + ) self.token_to_kv_pool_host = MHATokenToKVPoolHost( token_to_kv_pool_allocator.get_kvcache() ) self.tp_group = tp_cache_group + self.page_size = page_size self.load_cache_event = threading.Event() self.cache_controller = HiCacheController( @@ -45,7 +51,9 @@ class HiRadixCache(RadixCache): # todo: dynamically adjust the threshold self.write_through_threshold = 1 self.load_back_threshold = 10 - super().__init__(req_to_token_pool, token_to_kv_pool_allocator, disable=False) + super().__init__( + req_to_token_pool, token_to_kv_pool_allocator, self.page_size, disable=False + ) def reset(self): TreeNode.counter = 0