diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 6846022f9..f9d45b2f7 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -18,7 +18,7 @@ import math import threading import time from queue import Empty, Full, PriorityQueue, Queue -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple import torch @@ -43,39 +43,53 @@ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool logger = logging.getLogger(__name__) +class LayerLoadingEvent: + def __init__(self, num_layers: int): + self._num_layers = num_layers + self.load_events = [torch.cuda.Event() for _ in range(num_layers)] + self.start_event = torch.cuda.Event() # start event on controller stream + + def complete(self, layer_index: int): + assert 0 <= layer_index < self._num_layers + self.load_events[layer_index].record() + + def wait(self, layer_index: int): + torch.cuda.current_stream().wait_event(self.load_events[layer_index]) + + @property + def finish_event(self): + return self.load_events[-1] + + class LayerDoneCounter: - def __init__(self, num_layers): + def __init__(self, num_layers: int): self.num_layers = num_layers # extra producer and consumer counters for overlap mode self.num_counters = 3 - self.counters = [num_layers] * self.num_counters - self.conditions = [threading.Condition() for _ in range(self.num_counters)] - self.producer_index = 0 - self.consumer_index = 0 - - def next_producer(self): - return (self.producer_index + 1) % self.num_counters + self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)] + self.producer_index = -1 + self.consumer_index = -1 def update_producer(self): - self.producer_index = self.next_producer() + self.producer_index = (self.producer_index + 1) % self.num_counters + assert self.events[ + self.producer_index + ].finish_event.query(), ( + "Producer finish event should be ready before being reused." 
+ ) return self.producer_index - def set_consumer(self, index): + def set_consumer(self, index: int): self.consumer_index = index - def increment(self): - with self.conditions[self.producer_index]: - self.counters[self.producer_index] += 1 - self.conditions[self.producer_index].notify_all() - - def wait_until(self, threshold): - with self.conditions[self.consumer_index]: - while self.counters[self.consumer_index] <= threshold: - self.conditions[self.consumer_index].wait() + def wait_until(self, threshold: int): + if self.consumer_index < 0: + return + self.events[self.consumer_index].wait(threshold) def reset(self): - with self.conditions[self.producer_index]: - self.counters[self.producer_index] = 0 + self.producer_index = -1 + self.consumer_index = -1 class CacheOperation: @@ -99,38 +113,32 @@ class CacheOperation: # default priority is the order of creation self.priority = priority if priority is not None else self.id - def merge(self, other: "CacheOperation") -> None: - # multiple operations can be merged into a single operation for batch processing - self.host_indices = torch.cat([self.host_indices, other.host_indices]) - self.device_indices = torch.cat([self.device_indices, other.device_indices]) - self.priority = min(self.priority, other.priority) - self.node_ids.extend(other.node_ids) + @staticmethod + def merge_ops(ops: List[CacheOperation]) -> CacheOperation: + assert len(ops) > 0 + if len(ops) == 1: + return ops[0] - def split(self, factor) -> List["CacheOperation"]: - # split an operation into smaller operations to reduce the size of intermediate buffers - if factor <= 1: - return [self] + host_indices = torch.cat([op.host_indices for op in ops]) + device_indices = torch.cat([op.device_indices for op in ops]) + node_ids = [] + priority = min(op.priority for op in ops) + for op in ops: + node_ids.extend(op.node_ids) + merged_op = CacheOperation(host_indices, device_indices, -1, priority) + merged_op.node_ids = node_ids + return merged_op - chunk_size = math.ceil(len(self.host_indices) / factor) - split_ops = [] - for i in range(0, len(self.host_indices), chunk_size): - split_ops.append( - CacheOperation( - host_indices=self.host_indices[i : i + chunk_size], - device_indices=self.device_indices[i : i + chunk_size], - node_id=0, - ) - ) - # Inherit the node_ids on the final chunk - if split_ops: - split_ops[-1].node_ids = self.node_ids - - return split_ops - - def __lt__(self, other: "CacheOperation"): + def __lt__(self, other: CacheOperation): return self.priority < other.priority +class HiCacheAck(NamedTuple): + start_event: torch.cuda.Event + finish_event: torch.cuda.Event + node_ids: List[int] + + class TransferBuffer: """ Overlapping buffer preparation and transfer operations to improve throughput. 
@@ -236,7 +244,7 @@ class HiCacheController: mem_pool_host: HostKVCache, page_size: int, tp_group: torch.distributed.ProcessGroup, - load_cache_event: threading.Event = None, + load_cache_event: threading.Event, write_policy: str = "write_through_selective", io_backend: str = "", storage_backend: Optional[str] = None, @@ -340,8 +348,9 @@ class HiCacheController: self.page_set_func = self._3fs_zero_copy_page_set self.batch_exists_func = self._3fs_zero_copy_batch_exists - self.load_cache_event = load_cache_event - self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num) + self.device = self.mem_pool_device.device + self.layer_num = self.mem_pool_device.layer_num + self.layer_done_counter = LayerDoneCounter(self.layer_num) self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter) if write_policy not in [ @@ -351,11 +360,11 @@ class HiCacheController: ]: raise ValueError(f"Invalid write policy: {write_policy}") - self.write_queue = PriorityQueue() - self.load_queue = PriorityQueue() - - self.ack_write_queue = Queue() - self.ack_load_queue = Queue() + # self.write_queue = PriorityQueue[CacheOperation]() + self.load_queue: List[CacheOperation] = [] + self.write_queue: List[CacheOperation] = [] + self.ack_load_queue: List[HiCacheAck] = [] + self.ack_write_queue: List[HiCacheAck] = [] self.stop_event = threading.Event() self.write_buffer = TransferBuffer(self.stop_event) @@ -366,16 +375,6 @@ class HiCacheController: self.write_stream = torch.cuda.Stream() self.load_stream = torch.cuda.Stream() - self.write_thread = threading.Thread( - target=self.write_thread_func_direct, daemon=True - ) - self.load_thread = threading.Thread( - target=self.load_thread_func_layer_by_layer, daemon=True - ) - - self.write_thread.start() - self.load_thread.start() - if self.enable_storage: self.prefetch_thread = threading.Thread( target=self.prefetch_thread_func, daemon=True @@ -432,15 +431,13 @@ class HiCacheController: def reset(self): self.stop_event.set() - self.write_thread.join() - self.load_thread.join() - self.write_queue.queue.clear() - self.load_queue.queue.clear() + self.write_queue.clear() + self.load_queue.clear() self.write_buffer.clear() self.load_buffer.clear() - self.ack_write_queue.queue.clear() - self.ack_load_queue.queue.clear() + self.ack_write_queue.clear() + self.ack_load_queue.clear() if self.enable_storage: self.prefetch_thread.join() self.backup_thread.join() @@ -449,15 +446,7 @@ class HiCacheController: self.prefetch_revoke_queue.queue.clear() self.ack_backup_queue.queue.clear() - self.write_thread = threading.Thread( - target=self.write_thread_func_direct, daemon=True - ) - self.load_thread = threading.Thread( - target=self.load_thread_func_layer_by_layer, daemon=True - ) self.stop_event.clear() - self.write_thread.start() - self.load_thread.start() if self.enable_storage: self.prefetch_thread = threading.Thread( @@ -473,7 +462,7 @@ class HiCacheController: self, device_indices: torch.Tensor, priority: Optional[int] = None, - node_id: int = 0, + node_id: int = -1, ) -> Optional[torch.Tensor]: """ Back up KV caches from device memory to host memory. 
@@ -482,17 +471,46 @@ class HiCacheController: if host_indices is None: return None self.mem_pool_host.protect_write(host_indices) - torch.cuda.current_stream().synchronize() - self.write_queue.put( + self.write_queue.append( CacheOperation(host_indices, device_indices, node_id, priority) ) + self.start_writing() return host_indices + def start_writing(self) -> None: + if len(self.write_queue) == 0: + return + + op = CacheOperation.merge_ops(self.write_queue) + host_indices, device_indices = self.move_indices(op) + self.write_queue.clear() + + start_event = torch.cuda.Event() + finish_event = torch.cuda.Event() + + start_event.record() + with torch.cuda.stream(self.write_stream): + start_event.wait(self.write_stream) + self.mem_pool_host.backup_from_device_all_layer( + self.mem_pool_device, host_indices, device_indices, self.io_backend + ) + self.mem_pool_host.complete_io(op.host_indices) + finish_event.record() + # NOTE: record the host and device indices on the write stream so + # that their memory is guaranteed to stay alive until the + # asynchronous backup on the write stream has finished executing. + if host_indices.is_cuda: + host_indices.record_stream(self.write_stream) + if device_indices.is_cuda: + device_indices.record_stream(self.write_stream) + + self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids)) + def load( self, host_indices: torch.Tensor, priority: Optional[int] = None, - node_id: int = 0, + node_id: int = -1, ) -> Optional[torch.Tensor]: """ Load KV caches from host memory to device memory. @@ -501,17 +519,18 @@ class HiCacheController: if device_indices is None: return None self.mem_pool_host.protect_load(host_indices) - # to ensure the device indices are ready before accessed by another CUDA stream - torch.cuda.current_stream().synchronize() - self.load_queue.put( + self.load_queue.append( CacheOperation(host_indices, device_indices, node_id, priority) ) return device_indices - def move_indices(self, host_indices, device_indices): + def move_indices(self, op: CacheOperation): + host_indices, device_indices = op.host_indices, op.device_indices # move indices to GPU if using kernels, to host if using direct indexing if self.io_backend == "kernel": - return host_indices.to(self.mem_pool_device.device), device_indices + if not host_indices.is_cuda: + host_indices = host_indices.to(self.device, non_blocking=True) + return host_indices, device_indices elif self.io_backend == "direct": device_indices = device_indices.cpu() host_indices, idx = host_indices.sort() @@ -519,58 +538,20 @@ class HiCacheController: else: raise ValueError(f"Unsupported io backend") - def write_thread_func_direct(self): - """ - Directly write through KV caches to host memory without buffering.
- """ - torch.cuda.set_stream(self.write_stream) - while not self.stop_event.is_set(): - try: - operation = self.write_queue.get(block=True, timeout=1) - host_indices, device_indices = self.move_indices( - operation.host_indices, operation.device_indices - ) - self.mem_pool_host.backup_from_device_all_layer( - self.mem_pool_device, host_indices, device_indices, self.io_backend - ) - self.write_stream.synchronize() - self.mem_pool_host.complete_io(operation.host_indices) - for node_id in operation.node_ids: - if node_id != 0: - self.ack_write_queue.put(node_id) - except Empty: - continue - except Exception as e: - logger.error(e) + def start_loading(self) -> int: + if len(self.load_queue) == 0: + return -1 - def load_thread_func_layer_by_layer(self): - """ - Load KV caches from host memory to device memory layer by layer. - """ - torch.cuda.set_stream(self.load_stream) - while not self.stop_event.is_set(): - self.load_cache_event.wait(timeout=1) - if not self.load_cache_event.is_set(): - continue - self.load_cache_event.clear() - self.layer_done_counter.update_producer() + producer_id = self.layer_done_counter.update_producer() + op = CacheOperation.merge_ops(self.load_queue) + host_indices, device_indices = self.move_indices(op) + self.load_queue.clear() + producer_event = self.layer_done_counter.events[producer_id] + producer_event.start_event.record() - batch_operation = None - while self.load_queue.qsize() > 0: - op = self.load_queue.get(block=True) - if batch_operation is None: - batch_operation = op - else: - batch_operation.merge(op) - if batch_operation is None: - continue - - # start layer-wise KV cache transfer from CPU to GPU - self.layer_done_counter.reset() - host_indices, device_indices = self.move_indices( - batch_operation.host_indices, batch_operation.device_indices - ) - for i in range(self.mem_pool_host.layer_num): + with torch.cuda.stream(self.load_stream): + producer_event.start_event.wait(self.load_stream) + for i in range(self.layer_num): self.mem_pool_host.load_to_device_per_layer( self.mem_pool_device, host_indices, @@ -578,13 +559,24 @@ class HiCacheController: i, self.io_backend, ) - self.load_stream.synchronize() - self.layer_done_counter.increment() + producer_event.complete(i) + self.mem_pool_host.complete_io(op.host_indices) + # NOTE: We must save the host indices and device indices here, + # this is because we need to guarantee that these tensors are + # still alive when the load stream is executing. 
+ if host_indices.is_cuda: + host_indices.record_stream(self.load_stream) + if device_indices.is_cuda: + device_indices.record_stream(self.load_stream) - self.mem_pool_host.complete_io(batch_operation.host_indices) - for node_id in batch_operation.node_ids: - if node_id != 0: - self.ack_load_queue.put(node_id) + self.ack_load_queue.append( + HiCacheAck( + start_event=producer_event.start_event, + finish_event=producer_event.finish_event, + node_ids=op.node_ids, + ) + ) + return producer_id def evict_device( self, device_indices: torch.Tensor, host_indices: torch.Tensor diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index df5ade906..f519224df 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -911,7 +911,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): is_prefill_only: bool = False # hicache pointer for synchronizing data loading from CPU to GPU - hicache_consumer_index: int = 0 + hicache_consumer_index: int = -1 @classmethod def init_new( @@ -1897,7 +1897,7 @@ class ModelWorkerBatch: spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None # If set, the output of the batch contains the hidden states of the run. capture_hidden_mode: CaptureHiddenMode = None - hicache_consumer_index: int = 0 + hicache_consumer_index: int = -1 # Overlap event launch_done: Optional[threading.Event] = None diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index a65d91e8f..9e3af2eaa 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1807,10 +1807,6 @@ class Scheduler( if self.spec_algorithm.is_none(): model_worker_batch = batch.get_model_worker_batch() - # update the consumer index of hicache to the running batch - self.tp_worker.set_hicache_consumer( - model_worker_batch.hicache_consumer_index - ) if self.pp_group.is_last_rank: logits_output, next_token_ids, can_run_cuda_graph = ( self.tp_worker.forward_batch_generation(model_worker_batch) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index fbc12e5b0..017f9a1f8 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -12,10 +12,11 @@ # limitations under the License. 
# ============================================================================== """A tensor parallel worker.""" +from __future__ import annotations import logging import threading -from typing import Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Tuple, Union import torch @@ -45,6 +46,9 @@ from sglang.srt.patch_torch import monkey_patch_torch_reductions from sglang.srt.server_args import ServerArgs from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed +if TYPE_CHECKING: + from sglang.srt.managers.cache_controller import LayerDoneCounter + logger = logging.getLogger(__name__) @@ -167,10 +171,10 @@ class TpModelWorker: self.hicache_layer_transfer_counter = None - def register_hicache_layer_transfer_counter(self, counter): + def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter): self.hicache_layer_transfer_counter = counter - def set_hicache_consumer(self, consumer_index): + def set_hicache_consumer(self, consumer_index: int): if self.hicache_layer_transfer_counter is not None: self.hicache_layer_transfer_counter.set_consumer(consumer_index) @@ -230,6 +234,9 @@ class TpModelWorker: ) -> Tuple[ Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool ]: + # update the consumer index of hicache to the running batch + self.set_hicache_consumer(model_worker_batch.hicache_consumer_index) + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) pp_proxy_tensors = None diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 674a94195..e72d4fb6e 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -12,13 +12,14 @@ # limitations under the License. 
# ============================================================================== """A tensor parallel worker.""" +from __future__ import annotations import dataclasses import logging import signal import threading from queue import Queue -from typing import Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple import psutil import torch @@ -38,6 +39,9 @@ from sglang.srt.server_args import ServerArgs from sglang.srt.utils import DynamicGradMode, get_compiler_backend from sglang.utils import get_exception_traceback +if TYPE_CHECKING: + from sglang.srt.managers.cache_controller import LayerDoneCounter + logger = logging.getLogger(__name__) @@ -79,7 +83,7 @@ class TpModelWorkerClient: ) # Launch threads - self.input_queue = Queue() + self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]() self.output_queue = Queue() self.forward_stream = torch.get_device_module(self.device).Stream() self.forward_thread = threading.Thread( @@ -93,13 +97,9 @@ class TpModelWorkerClient: self.hicache_layer_transfer_counter = None - def register_hicache_layer_transfer_counter(self, counter): + def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter): self.hicache_layer_transfer_counter = counter - def set_hicache_consumer(self, consumer_index): - if self.hicache_layer_transfer_counter is not None: - self.hicache_layer_transfer_counter.set_consumer(consumer_index) - def get_worker_info(self): return self.worker.get_worker_info() @@ -147,7 +147,7 @@ class TpModelWorkerClient: @DynamicGradMode() def forward_thread_func_(self): batch_pt = 0 - batch_lists = [None] * 2 + batch_lists: List = [None] * 2 while True: model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get() @@ -169,8 +169,6 @@ class TpModelWorkerClient: input_ids = model_worker_batch.input_ids resolve_future_token_ids(input_ids, self.future_token_ids_map) - # update the consumer index of hicache to the running batch - self.set_hicache_consumer(model_worker_batch.hicache_consumer_index) # Run forward logits_output, next_token_ids, can_run_cuda_graph = ( self.worker.forward_batch_generation( diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 5883c1f15..3b00e4619 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -201,41 +201,57 @@ class HiRadixCache(RadixCache): if write_back: # blocking till all write back complete while len(self.ongoing_write_through) > 0: - ack_id = self.cache_controller.ack_write_queue.get() - del self.ongoing_write_through[ack_id] + for _, finish_event, ack_list in self.cache_controller.ack_write_queue: + finish_event.synchronize() + for ack_id in ack_list: + del self.ongoing_write_through[ack_id] + self.cache_controller.ack_write_queue.clear() + assert len(self.ongoing_write_through) == 0 return - queue_size = torch.tensor( - self.cache_controller.ack_write_queue.qsize(), dtype=torch.int - ) + + # NOTE: all ranks have the same ongoing_write_through, so the sync can be skipped when it is empty + if len(self.ongoing_write_through) == 0: + return + + finish_count = 0 + for _, finish_event, ack_list in self.cache_controller.ack_write_queue: + if not finish_event.query(): + break + finish_count += 1 + queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu") if self.tp_world_size > 1: - # synchrnoize TP workers to make the same update to radix cache + # synchronize TP workers to make the same update to radix cache torch.distributed.all_reduce( queue_size,
op=torch.distributed.ReduceOp.MIN, group=self.tp_group, ) - for _ in range(queue_size.item()): - ack_id = self.cache_controller.ack_write_queue.get() - backuped_node = self.ongoing_write_through[ack_id] - self.dec_lock_ref(backuped_node) - del self.ongoing_write_through[ack_id] - if self.enable_storage: - self.write_backup_storage(backuped_node) + + finish_count = int(queue_size.item()) + while finish_count > 0: + _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0) + finish_event.synchronize() + for ack_id in ack_list: + backuped_node = self.ongoing_write_through.pop(ack_id) + self.dec_lock_ref(backuped_node) + if self.enable_storage: + self.write_backup_storage(backuped_node) + finish_count -= 1 def loading_check(self): - while not self.cache_controller.ack_load_queue.empty(): - try: - ack_id = self.cache_controller.ack_load_queue.get_nowait() - start_node, end_node = self.ongoing_load_back[ack_id] - self.dec_lock_ref(end_node) - while end_node != start_node: - assert end_node.loading - end_node.loading = False - end_node = end_node.parent - # clear the reference - del self.ongoing_load_back[ack_id] - except Exception: + finish_count = 0 + for _, finish_event, ack_list in self.cache_controller.ack_load_queue: + if not finish_event.query(): + # the KV cache loading is still ongoing break + finish_count += 1 + # no need to sync across TP workers as batch forwarding is synced + for ack_id in ack_list: + end_node = self.ongoing_load_back.pop(ack_id) + self.dec_lock_ref(end_node) + + # pop the fully processed acks off the queue + del self.cache_controller.ack_load_queue[:finish_count] def evictable_size(self): return self.evictable_size_ @@ -360,12 +376,11 @@ class HiRadixCache(RadixCache): # no sufficient GPU memory to load back KV caches return None - self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node) + self.ongoing_load_back[last_hit_node.id] = last_hit_node offset = 0 for node in nodes_to_load: node.value = device_indices[offset : offset + len(node.host_value)] offset += len(node.host_value) - node.loading = True self.evictable_size_ += len(device_indices) self.inc_lock_ref(last_hit_node) @@ -394,10 +409,12 @@ class HiRadixCache(RadixCache): last_node, ) - def ready_to_load_host_cache(self): - producer_index = self.cache_controller.layer_done_counter.next_producer() - self.load_cache_event.set() - return producer_index + def ready_to_load_host_cache(self) -> int: + """ + Notify the cache controller to start the KV cache loading. + Return the consumer index for the ScheduleBatch to track. + """ + return self.cache_controller.start_loading() def check_hicache_events(self): self.writing_check() @@ -702,7 +719,6 @@ class HiRadixCache(RadixCache): new_node.parent = child.parent new_node.lock_ref = child.lock_ref new_node.key = child.key[:split_len] - new_node.loading = child.loading new_node.hit_count = child.hit_count # split value and host value if exists diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index fab917a81..175440a3f 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. """ +from __future__ import annotations + from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter """ @@ -27,7 +29,7 @@ KVCache actually holds the physical kv cache.
import abc import logging from contextlib import nullcontext -from typing import Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -38,6 +40,9 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2 +if TYPE_CHECKING: + from sglang.srt.managers.cache_controller import LayerDoneCounter + logger = logging.getLogger(__name__) GB = 1024 * 1024 * 1024 @@ -175,7 +180,7 @@ class KVCache(abc.ABC): ) -> None: raise NotImplementedError() - def register_layer_transfer_counter(self, layer_transfer_counter): + def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter): self.layer_transfer_counter = layer_transfer_counter def get_cpu_copy(self, indices): diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index 15b5efe5a..dc27eaa03 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -3,6 +3,7 @@ import logging import threading from enum import IntEnum from functools import wraps +from typing import Optional import psutil import torch @@ -169,7 +170,7 @@ class HostKVCache(abc.ABC): return len(self.free_slots) @synchronized() - def alloc(self, need_size: int) -> torch.Tensor: + def alloc(self, need_size: int) -> Optional[torch.Tensor]: assert ( need_size % self.page_size == 0 ), "The requested size should be a multiple of the page size." diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index b0cf0bb9c..d8208e143 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -53,8 +53,6 @@ class TreeNode: self.last_access_time = time.monotonic() self.hit_count = 0 - # indicating the node is loading KV cache from host - self.loading = False # indicating the node is locked to protect from eviction # incremented when the node is referenced by a storage operation self.host_ref_counter = 0 diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index 0624e84e1..686fc6ab0 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -60,8 +60,6 @@ class TreeNode: self.last_access_time = time.monotonic() self.hit_count = 0 - # indicating the node is loading KV cache from host - self.loading = False # store the host indices of KV cache self.host_value = None
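
For reference, below is a minimal, self-contained sketch of the per-layer CUDA-event pattern that LayerLoadingEvent, start_loading, and wait_until implement in this diff. It is not SGLang code: the shapes, variable names, and the sum() stand-in for per-layer compute are illustrative, and it assumes a CUDA device is available.

import torch

num_layers, tokens, dim = 4, 1024, 512
# pinned host memory enables truly asynchronous H2D copies
host_kv = [torch.randn(tokens, dim, pin_memory=True) for _ in range(num_layers)]
device_kv = [torch.empty(tokens, dim, device="cuda") for _ in range(num_layers)]

load_stream = torch.cuda.Stream()
layer_events = [torch.cuda.Event() for _ in range(num_layers)]

# Producer (cf. start_loading): copy each layer on a side stream and record
# that layer's event, as LayerLoadingEvent.complete(layer_index) does.
with torch.cuda.stream(load_stream):
    for i in range(num_layers):
        device_kv[i].copy_(host_kv[i], non_blocking=True)
        layer_events[i].record()

# Consumer (cf. wait_until): before computing on layer i, make the current
# stream wait on that layer's event; compute on early layers then overlaps
# with the transfer of later layers, with no CPU-side blocking.
for i in range(num_layers):
    torch.cuda.current_stream().wait_event(layer_events[i])
    _ = device_kv[i].sum()  # stand-in for attention over layer i
torch.cuda.synchronize()

This on-stream handshake is what allows the PR to delete the dedicated writer/loader threads and the condition-variable counter: ordering is enforced by the GPU stream scheduler rather than by CPU-side waits.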