[Refactor] Remove Hicache Load & Write threads (#10127)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
DarkSharpness
2025-09-08 22:18:50 -07:00
committed by GitHub
parent cdc56ef6c1
commit 948b01a04c
10 changed files with 215 additions and 204 deletions

View File

@@ -18,7 +18,7 @@ import math
import threading import threading
import time import time
from queue import Empty, Full, PriorityQueue, Queue from queue import Empty, Full, PriorityQueue, Queue
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple
import torch import torch
@@ -43,39 +43,53 @@ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LayerLoadingEvent:
def __init__(self, num_layers: int):
self._num_layers = num_layers
self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
self.start_event = torch.cuda.Event() # start event on controller stream
def complete(self, layer_index: int):
assert 0 <= layer_index < self._num_layers
self.load_events[layer_index].record()
def wait(self, layer_index: int):
torch.cuda.current_stream().wait_event(self.load_events[layer_index])
@property
def finish_event(self):
return self.load_events[-1]
class LayerDoneCounter: class LayerDoneCounter:
def __init__(self, num_layers): def __init__(self, num_layers: int):
self.num_layers = num_layers self.num_layers = num_layers
# extra producer and consumer counters for overlap mode # extra producer and consumer counters for overlap mode
self.num_counters = 3 self.num_counters = 3
self.counters = [num_layers] * self.num_counters self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)]
self.conditions = [threading.Condition() for _ in range(self.num_counters)] self.producer_index = -1
self.producer_index = 0 self.consumer_index = -1
self.consumer_index = 0
def next_producer(self):
return (self.producer_index + 1) % self.num_counters
def update_producer(self): def update_producer(self):
self.producer_index = self.next_producer() self.producer_index = (self.producer_index + 1) % self.num_counters
assert self.events[
self.producer_index
].finish_event.query(), (
"Producer finish event should be ready before being reused."
)
return self.producer_index return self.producer_index
def set_consumer(self, index): def set_consumer(self, index: int):
self.consumer_index = index self.consumer_index = index
def increment(self): def wait_until(self, threshold: int):
with self.conditions[self.producer_index]: if self.consumer_index < 0:
self.counters[self.producer_index] += 1 return
self.conditions[self.producer_index].notify_all() self.events[self.consumer_index].wait(threshold)
def wait_until(self, threshold):
with self.conditions[self.consumer_index]:
while self.counters[self.consumer_index] <= threshold:
self.conditions[self.consumer_index].wait()
def reset(self): def reset(self):
with self.conditions[self.producer_index]: self.producer_index = -1
self.counters[self.producer_index] = 0 self.consumer_index = -1
class CacheOperation: class CacheOperation:
@@ -99,38 +113,32 @@ class CacheOperation:
# default priority is the order of creation # default priority is the order of creation
self.priority = priority if priority is not None else self.id self.priority = priority if priority is not None else self.id
def merge(self, other: "CacheOperation") -> None: @staticmethod
# multiple operations can be merged into a single operation for batch processing def merge_ops(ops: List[CacheOperation]) -> CacheOperation:
self.host_indices = torch.cat([self.host_indices, other.host_indices]) assert len(ops) > 0
self.device_indices = torch.cat([self.device_indices, other.device_indices]) if len(ops) == 1:
self.priority = min(self.priority, other.priority) return ops[0]
self.node_ids.extend(other.node_ids)
def split(self, factor) -> List["CacheOperation"]: host_indices = torch.cat([op.host_indices for op in ops])
# split an operation into smaller operations to reduce the size of intermediate buffers device_indices = torch.cat([op.device_indices for op in ops])
if factor <= 1: node_ids = []
return [self] priority = min(op.priority for op in ops)
for op in ops:
node_ids.extend(op.node_ids)
merged_op = CacheOperation(host_indices, device_indices, -1, priority)
merged_op.node_ids = node_ids
return merged_op
chunk_size = math.ceil(len(self.host_indices) / factor) def __lt__(self, other: CacheOperation):
split_ops = []
for i in range(0, len(self.host_indices), chunk_size):
split_ops.append(
CacheOperation(
host_indices=self.host_indices[i : i + chunk_size],
device_indices=self.device_indices[i : i + chunk_size],
node_id=0,
)
)
# Inherit the node_ids on the final chunk
if split_ops:
split_ops[-1].node_ids = self.node_ids
return split_ops
def __lt__(self, other: "CacheOperation"):
return self.priority < other.priority return self.priority < other.priority
class HiCacheAck(NamedTuple):
start_event: torch.cuda.Event
finish_event: torch.cuda.Event
node_ids: List[int]
class TransferBuffer: class TransferBuffer:
""" """
Overlapping buffer preparation and transfer operations to improve throughput. Overlapping buffer preparation and transfer operations to improve throughput.
@@ -236,7 +244,7 @@ class HiCacheController:
mem_pool_host: HostKVCache, mem_pool_host: HostKVCache,
page_size: int, page_size: int,
tp_group: torch.distributed.ProcessGroup, tp_group: torch.distributed.ProcessGroup,
load_cache_event: threading.Event = None, load_cache_event: threading.Event,
write_policy: str = "write_through_selective", write_policy: str = "write_through_selective",
io_backend: str = "", io_backend: str = "",
storage_backend: Optional[str] = None, storage_backend: Optional[str] = None,
@@ -340,8 +348,9 @@ class HiCacheController:
self.page_set_func = self._3fs_zero_copy_page_set self.page_set_func = self._3fs_zero_copy_page_set
self.batch_exists_func = self._3fs_zero_copy_batch_exists self.batch_exists_func = self._3fs_zero_copy_batch_exists
self.load_cache_event = load_cache_event self.device = self.mem_pool_device.device
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num) self.layer_num = self.mem_pool_device.layer_num
self.layer_done_counter = LayerDoneCounter(self.layer_num)
self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter) self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
if write_policy not in [ if write_policy not in [
@@ -351,11 +360,11 @@ class HiCacheController:
]: ]:
raise ValueError(f"Invalid write policy: {write_policy}") raise ValueError(f"Invalid write policy: {write_policy}")
self.write_queue = PriorityQueue() # self.write_queue = PriorityQueue[CacheOperation]()
self.load_queue = PriorityQueue() self.load_queue: List[CacheOperation] = []
self.write_queue: List[CacheOperation] = []
self.ack_write_queue = Queue() self.ack_load_queue: List[HiCacheAck] = []
self.ack_load_queue = Queue() self.ack_write_queue: List[HiCacheAck] = []
self.stop_event = threading.Event() self.stop_event = threading.Event()
self.write_buffer = TransferBuffer(self.stop_event) self.write_buffer = TransferBuffer(self.stop_event)
@@ -366,16 +375,6 @@ class HiCacheController:
self.write_stream = torch.cuda.Stream() self.write_stream = torch.cuda.Stream()
self.load_stream = torch.cuda.Stream() self.load_stream = torch.cuda.Stream()
self.write_thread = threading.Thread(
target=self.write_thread_func_direct, daemon=True
)
self.load_thread = threading.Thread(
target=self.load_thread_func_layer_by_layer, daemon=True
)
self.write_thread.start()
self.load_thread.start()
if self.enable_storage: if self.enable_storage:
self.prefetch_thread = threading.Thread( self.prefetch_thread = threading.Thread(
target=self.prefetch_thread_func, daemon=True target=self.prefetch_thread_func, daemon=True
@@ -432,15 +431,13 @@ class HiCacheController:
def reset(self): def reset(self):
self.stop_event.set() self.stop_event.set()
self.write_thread.join()
self.load_thread.join()
self.write_queue.queue.clear() self.write_queue.clear()
self.load_queue.queue.clear() self.load_queue.clear()
self.write_buffer.clear() self.write_buffer.clear()
self.load_buffer.clear() self.load_buffer.clear()
self.ack_write_queue.queue.clear() self.ack_write_queue.clear()
self.ack_load_queue.queue.clear() self.ack_load_queue.clear()
if self.enable_storage: if self.enable_storage:
self.prefetch_thread.join() self.prefetch_thread.join()
self.backup_thread.join() self.backup_thread.join()
@@ -449,15 +446,7 @@ class HiCacheController:
self.prefetch_revoke_queue.queue.clear() self.prefetch_revoke_queue.queue.clear()
self.ack_backup_queue.queue.clear() self.ack_backup_queue.queue.clear()
self.write_thread = threading.Thread(
target=self.write_thread_func_direct, daemon=True
)
self.load_thread = threading.Thread(
target=self.load_thread_func_layer_by_layer, daemon=True
)
self.stop_event.clear() self.stop_event.clear()
self.write_thread.start()
self.load_thread.start()
if self.enable_storage: if self.enable_storage:
self.prefetch_thread = threading.Thread( self.prefetch_thread = threading.Thread(
@@ -473,7 +462,7 @@ class HiCacheController:
self, self,
device_indices: torch.Tensor, device_indices: torch.Tensor,
priority: Optional[int] = None, priority: Optional[int] = None,
node_id: int = 0, node_id: int = -1,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
""" """
Back up KV caches from device memory to host memory. Back up KV caches from device memory to host memory.
@@ -482,17 +471,46 @@ class HiCacheController:
if host_indices is None: if host_indices is None:
return None return None
self.mem_pool_host.protect_write(host_indices) self.mem_pool_host.protect_write(host_indices)
torch.cuda.current_stream().synchronize() self.write_queue.append(
self.write_queue.put(
CacheOperation(host_indices, device_indices, node_id, priority) CacheOperation(host_indices, device_indices, node_id, priority)
) )
self.start_writing()
return host_indices return host_indices
def start_writing(self) -> None:
if len(self.write_queue) == 0:
return
op = CacheOperation.merge_ops(self.write_queue)
host_indices, device_indices = self.move_indices(op)
self.write_queue.clear()
start_event = torch.cuda.Event()
finish_event = torch.cuda.Event()
start_event.record()
with torch.cuda.stream(self.write_stream):
start_event.wait(self.write_stream)
self.mem_pool_host.backup_from_device_all_layer(
self.mem_pool_device, host_indices, device_indices, self.io_backend
)
self.mem_pool_host.complete_io(op.host_indices)
finish_event.record()
# NOTE: We must save the host indices and device indices here,
# this is because we need to guarantee that these tensors are
# still alive when the write stream is executing.
if host_indices.is_cuda:
host_indices.record_stream(self.write_stream)
if device_indices.is_cuda:
device_indices.record_stream(self.write_stream)
self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
def load( def load(
self, self,
host_indices: torch.Tensor, host_indices: torch.Tensor,
priority: Optional[int] = None, priority: Optional[int] = None,
node_id: int = 0, node_id: int = -1,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
""" """
Load KV caches from host memory to device memory. Load KV caches from host memory to device memory.
@@ -501,17 +519,18 @@ class HiCacheController:
if device_indices is None: if device_indices is None:
return None return None
self.mem_pool_host.protect_load(host_indices) self.mem_pool_host.protect_load(host_indices)
# to ensure the device indices are ready before accessed by another CUDA stream self.load_queue.append(
torch.cuda.current_stream().synchronize()
self.load_queue.put(
CacheOperation(host_indices, device_indices, node_id, priority) CacheOperation(host_indices, device_indices, node_id, priority)
) )
return device_indices return device_indices
def move_indices(self, host_indices, device_indices): def move_indices(self, op: CacheOperation):
host_indices, device_indices = op.host_indices, op.device_indices
# move indices to GPU if using kernels, to host if using direct indexing # move indices to GPU if using kernels, to host if using direct indexing
if self.io_backend == "kernel": if self.io_backend == "kernel":
return host_indices.to(self.mem_pool_device.device), device_indices if not host_indices.is_cuda:
host_indices = host_indices.to(self.device, non_blocking=True)
return host_indices, device_indices
elif self.io_backend == "direct": elif self.io_backend == "direct":
device_indices = device_indices.cpu() device_indices = device_indices.cpu()
host_indices, idx = host_indices.sort() host_indices, idx = host_indices.sort()
@@ -519,58 +538,20 @@ class HiCacheController:
else: else:
raise ValueError(f"Unsupported io backend") raise ValueError(f"Unsupported io backend")
def write_thread_func_direct(self): def start_loading(self) -> int:
""" if len(self.load_queue) == 0:
Directly write through KV caches to host memory without buffering. return -1
"""
torch.cuda.set_stream(self.write_stream)
while not self.stop_event.is_set():
try:
operation = self.write_queue.get(block=True, timeout=1)
host_indices, device_indices = self.move_indices(
operation.host_indices, operation.device_indices
)
self.mem_pool_host.backup_from_device_all_layer(
self.mem_pool_device, host_indices, device_indices, self.io_backend
)
self.write_stream.synchronize()
self.mem_pool_host.complete_io(operation.host_indices)
for node_id in operation.node_ids:
if node_id != 0:
self.ack_write_queue.put(node_id)
except Empty:
continue
except Exception as e:
logger.error(e)
def load_thread_func_layer_by_layer(self): producer_id = self.layer_done_counter.update_producer()
""" op = CacheOperation.merge_ops(self.load_queue)
Load KV caches from host memory to device memory layer by layer. host_indices, device_indices = self.move_indices(op)
""" self.load_queue.clear()
torch.cuda.set_stream(self.load_stream) producer_event = self.layer_done_counter.events[producer_id]
while not self.stop_event.is_set(): producer_event.start_event.record()
self.load_cache_event.wait(timeout=1)
if not self.load_cache_event.is_set():
continue
self.load_cache_event.clear()
self.layer_done_counter.update_producer()
batch_operation = None with torch.cuda.stream(self.load_stream):
while self.load_queue.qsize() > 0: producer_event.start_event.wait(self.load_stream)
op = self.load_queue.get(block=True) for i in range(self.layer_num):
if batch_operation is None:
batch_operation = op
else:
batch_operation.merge(op)
if batch_operation is None:
continue
# start layer-wise KV cache transfer from CPU to GPU
self.layer_done_counter.reset()
host_indices, device_indices = self.move_indices(
batch_operation.host_indices, batch_operation.device_indices
)
for i in range(self.mem_pool_host.layer_num):
self.mem_pool_host.load_to_device_per_layer( self.mem_pool_host.load_to_device_per_layer(
self.mem_pool_device, self.mem_pool_device,
host_indices, host_indices,
@@ -578,13 +559,24 @@ class HiCacheController:
i, i,
self.io_backend, self.io_backend,
) )
self.load_stream.synchronize() producer_event.complete(i)
self.layer_done_counter.increment() self.mem_pool_host.complete_io(op.host_indices)
# NOTE: We must save the host indices and device indices here,
# this is because we need to guarantee that these tensors are
# still alive when the load stream is executing.
if host_indices.is_cuda:
host_indices.record_stream(self.load_stream)
if device_indices.is_cuda:
device_indices.record_stream(self.load_stream)
self.mem_pool_host.complete_io(batch_operation.host_indices) self.ack_load_queue.append(
for node_id in batch_operation.node_ids: HiCacheAck(
if node_id != 0: start_event=producer_event.start_event,
self.ack_load_queue.put(node_id) finish_event=producer_event.finish_event,
node_ids=op.node_ids,
)
)
return producer_id
def evict_device( def evict_device(
self, device_indices: torch.Tensor, host_indices: torch.Tensor self, device_indices: torch.Tensor, host_indices: torch.Tensor

View File

@@ -911,7 +911,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
is_prefill_only: bool = False is_prefill_only: bool = False
# hicache pointer for synchronizing data loading from CPU to GPU # hicache pointer for synchronizing data loading from CPU to GPU
hicache_consumer_index: int = 0 hicache_consumer_index: int = -1
@classmethod @classmethod
def init_new( def init_new(
@@ -1897,7 +1897,7 @@ class ModelWorkerBatch:
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
# If set, the output of the batch contains the hidden states of the run. # If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode: CaptureHiddenMode = None capture_hidden_mode: CaptureHiddenMode = None
hicache_consumer_index: int = 0 hicache_consumer_index: int = -1
# Overlap event # Overlap event
launch_done: Optional[threading.Event] = None launch_done: Optional[threading.Event] = None

View File

@@ -1807,10 +1807,6 @@ class Scheduler(
if self.spec_algorithm.is_none(): if self.spec_algorithm.is_none():
model_worker_batch = batch.get_model_worker_batch() model_worker_batch = batch.get_model_worker_batch()
# update the consumer index of hicache to the running batch
self.tp_worker.set_hicache_consumer(
model_worker_batch.hicache_consumer_index
)
if self.pp_group.is_last_rank: if self.pp_group.is_last_rank:
logits_output, next_token_ids, can_run_cuda_graph = ( logits_output, next_token_ids, can_run_cuda_graph = (
self.tp_worker.forward_batch_generation(model_worker_batch) self.tp_worker.forward_batch_generation(model_worker_batch)

View File

@@ -12,10 +12,11 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""A tensor parallel worker.""" """A tensor parallel worker."""
from __future__ import annotations
import logging import logging
import threading import threading
from typing import Optional, Tuple, Union from typing import TYPE_CHECKING, Optional, Tuple, Union
import torch import torch
@@ -45,6 +46,9 @@ from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
if TYPE_CHECKING:
from sglang.srt.managers.cache_controller import LayerDoneCounter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -167,10 +171,10 @@ class TpModelWorker:
self.hicache_layer_transfer_counter = None self.hicache_layer_transfer_counter = None
def register_hicache_layer_transfer_counter(self, counter): def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
self.hicache_layer_transfer_counter = counter self.hicache_layer_transfer_counter = counter
def set_hicache_consumer(self, consumer_index): def set_hicache_consumer(self, consumer_index: int):
if self.hicache_layer_transfer_counter is not None: if self.hicache_layer_transfer_counter is not None:
self.hicache_layer_transfer_counter.set_consumer(consumer_index) self.hicache_layer_transfer_counter.set_consumer(consumer_index)
@@ -230,6 +234,9 @@ class TpModelWorker:
) -> Tuple[ ) -> Tuple[
Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
]: ]:
# update the consumer index of hicache to the running batch
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
pp_proxy_tensors = None pp_proxy_tensors = None

View File

@@ -12,13 +12,14 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""A tensor parallel worker.""" """A tensor parallel worker."""
from __future__ import annotations
import dataclasses import dataclasses
import logging import logging
import signal import signal
import threading import threading
from queue import Queue from queue import Queue
from typing import Optional, Tuple from typing import TYPE_CHECKING, List, Optional, Tuple
import psutil import psutil
import torch import torch
@@ -38,6 +39,9 @@ from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import DynamicGradMode, get_compiler_backend from sglang.srt.utils import DynamicGradMode, get_compiler_backend
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
if TYPE_CHECKING:
from sglang.srt.managers.cache_controller import LayerDoneCounter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -79,7 +83,7 @@ class TpModelWorkerClient:
) )
# Launch threads # Launch threads
self.input_queue = Queue() self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
self.output_queue = Queue() self.output_queue = Queue()
self.forward_stream = torch.get_device_module(self.device).Stream() self.forward_stream = torch.get_device_module(self.device).Stream()
self.forward_thread = threading.Thread( self.forward_thread = threading.Thread(
@@ -93,13 +97,9 @@ class TpModelWorkerClient:
self.hicache_layer_transfer_counter = None self.hicache_layer_transfer_counter = None
def register_hicache_layer_transfer_counter(self, counter): def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
self.hicache_layer_transfer_counter = counter self.hicache_layer_transfer_counter = counter
def set_hicache_consumer(self, consumer_index):
if self.hicache_layer_transfer_counter is not None:
self.hicache_layer_transfer_counter.set_consumer(consumer_index)
def get_worker_info(self): def get_worker_info(self):
return self.worker.get_worker_info() return self.worker.get_worker_info()
@@ -147,7 +147,7 @@ class TpModelWorkerClient:
@DynamicGradMode() @DynamicGradMode()
def forward_thread_func_(self): def forward_thread_func_(self):
batch_pt = 0 batch_pt = 0
batch_lists = [None] * 2 batch_lists: List = [None] * 2
while True: while True:
model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get() model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
@@ -169,8 +169,6 @@ class TpModelWorkerClient:
input_ids = model_worker_batch.input_ids input_ids = model_worker_batch.input_ids
resolve_future_token_ids(input_ids, self.future_token_ids_map) resolve_future_token_ids(input_ids, self.future_token_ids_map)
# update the consumer index of hicache to the running batch
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
# Run forward # Run forward
logits_output, next_token_ids, can_run_cuda_graph = ( logits_output, next_token_ids, can_run_cuda_graph = (
self.worker.forward_batch_generation( self.worker.forward_batch_generation(

View File

@@ -201,41 +201,57 @@ class HiRadixCache(RadixCache):
if write_back: if write_back:
# blocking till all write back complete # blocking till all write back complete
while len(self.ongoing_write_through) > 0: while len(self.ongoing_write_through) > 0:
ack_id = self.cache_controller.ack_write_queue.get() for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
del self.ongoing_write_through[ack_id] finish_event.synchronize()
for ack_id in ack_list:
del self.ongoing_write_through[ack_id]
self.cache_controller.ack_write_queue.clear()
assert len(self.ongoing_write_through) == 0
return return
queue_size = torch.tensor(
self.cache_controller.ack_write_queue.qsize(), dtype=torch.int # NOTE: all ranks has the same ongoing_write_through, can skip sync if empty
) if len(self.ongoing_write_through) == 0:
return
finish_count = 0
for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
if not finish_event.query():
break
finish_count += 1
queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu")
if self.tp_world_size > 1: if self.tp_world_size > 1:
# synchrnoize TP workers to make the same update to radix cache # synchronize TP workers to make the same update to radix cache
torch.distributed.all_reduce( torch.distributed.all_reduce(
queue_size, queue_size,
op=torch.distributed.ReduceOp.MIN, op=torch.distributed.ReduceOp.MIN,
group=self.tp_group, group=self.tp_group,
) )
for _ in range(queue_size.item()):
ack_id = self.cache_controller.ack_write_queue.get() finish_count = int(queue_size.item())
backuped_node = self.ongoing_write_through[ack_id] while finish_count > 0:
self.dec_lock_ref(backuped_node) _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
del self.ongoing_write_through[ack_id] finish_event.synchronize()
if self.enable_storage: for ack_id in ack_list:
self.write_backup_storage(backuped_node) backuped_node = self.ongoing_write_through.pop(ack_id)
self.dec_lock_ref(backuped_node)
if self.enable_storage:
self.write_backup_storage(backuped_node)
finish_count -= 1
def loading_check(self): def loading_check(self):
while not self.cache_controller.ack_load_queue.empty(): finish_count = 0
try: for _, finish_event, ack_list in self.cache_controller.ack_load_queue:
ack_id = self.cache_controller.ack_load_queue.get_nowait() if not finish_event.query():
start_node, end_node = self.ongoing_load_back[ack_id] # the KV cache loading is still ongoing
self.dec_lock_ref(end_node)
while end_node != start_node:
assert end_node.loading
end_node.loading = False
end_node = end_node.parent
# clear the reference
del self.ongoing_load_back[ack_id]
except Exception:
break break
finish_count += 1
# no need to sync across TP workers as batch forwarding is synced
for ack_id in ack_list:
end_node = self.ongoing_load_back.pop(ack_id)
self.dec_lock_ref(end_node)
# ACK until all events are processed
del self.cache_controller.ack_load_queue[:finish_count]
def evictable_size(self): def evictable_size(self):
return self.evictable_size_ return self.evictable_size_
@@ -360,12 +376,11 @@ class HiRadixCache(RadixCache):
# no sufficient GPU memory to load back KV caches # no sufficient GPU memory to load back KV caches
return None return None
self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node) self.ongoing_load_back[last_hit_node.id] = last_hit_node
offset = 0 offset = 0
for node in nodes_to_load: for node in nodes_to_load:
node.value = device_indices[offset : offset + len(node.host_value)] node.value = device_indices[offset : offset + len(node.host_value)]
offset += len(node.host_value) offset += len(node.host_value)
node.loading = True
self.evictable_size_ += len(device_indices) self.evictable_size_ += len(device_indices)
self.inc_lock_ref(last_hit_node) self.inc_lock_ref(last_hit_node)
@@ -394,10 +409,12 @@ class HiRadixCache(RadixCache):
last_node, last_node,
) )
def ready_to_load_host_cache(self): def ready_to_load_host_cache(self) -> int:
producer_index = self.cache_controller.layer_done_counter.next_producer() """
self.load_cache_event.set() Notify the cache controller to start the KV cache loading.
return producer_index Return the consumer index for the schedule batch manager to track.
"""
return self.cache_controller.start_loading()
def check_hicache_events(self): def check_hicache_events(self):
self.writing_check() self.writing_check()
@@ -702,7 +719,6 @@ class HiRadixCache(RadixCache):
new_node.parent = child.parent new_node.parent = child.parent
new_node.lock_ref = child.lock_ref new_node.lock_ref = child.lock_ref
new_node.key = child.key[:split_len] new_node.key = child.key[:split_len]
new_node.loading = child.loading
new_node.hit_count = child.hit_count new_node.hit_count = child.hit_count
# split value and host value if exists # split value and host value if exists

View File

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
from __future__ import annotations
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
""" """
@@ -27,7 +29,7 @@ KVCache actually holds the physical kv cache.
import abc import abc
import logging import logging
from contextlib import nullcontext from contextlib import nullcontext
from typing import Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@@ -38,6 +40,9 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2 from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
if TYPE_CHECKING:
from sglang.srt.managers.cache_controller import LayerDoneCounter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
GB = 1024 * 1024 * 1024 GB = 1024 * 1024 * 1024
@@ -175,7 +180,7 @@ class KVCache(abc.ABC):
) -> None: ) -> None:
raise NotImplementedError() raise NotImplementedError()
def register_layer_transfer_counter(self, layer_transfer_counter): def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter):
self.layer_transfer_counter = layer_transfer_counter self.layer_transfer_counter = layer_transfer_counter
def get_cpu_copy(self, indices): def get_cpu_copy(self, indices):

View File

@@ -3,6 +3,7 @@ import logging
import threading import threading
from enum import IntEnum from enum import IntEnum
from functools import wraps from functools import wraps
from typing import Optional
import psutil import psutil
import torch import torch
@@ -169,7 +170,7 @@ class HostKVCache(abc.ABC):
return len(self.free_slots) return len(self.free_slots)
@synchronized() @synchronized()
def alloc(self, need_size: int) -> torch.Tensor: def alloc(self, need_size: int) -> Optional[torch.Tensor]:
assert ( assert (
need_size % self.page_size == 0 need_size % self.page_size == 0
), "The requested size should be a multiple of the page size." ), "The requested size should be a multiple of the page size."

View File

@@ -53,8 +53,6 @@ class TreeNode:
self.last_access_time = time.monotonic() self.last_access_time = time.monotonic()
self.hit_count = 0 self.hit_count = 0
# indicating the node is loading KV cache from host
self.loading = False
# indicating the node is locked to protect from eviction # indicating the node is locked to protect from eviction
# incremented when the node is referenced by a storage operation # incremented when the node is referenced by a storage operation
self.host_ref_counter = 0 self.host_ref_counter = 0

View File

@@ -60,8 +60,6 @@ class TreeNode:
self.last_access_time = time.monotonic() self.last_access_time = time.monotonic()
self.hit_count = 0 self.hit_count = 0
# indicating the node is loading KV cache from host
self.loading = False
# store the host indices of KV cache # store the host indices of KV cache
self.host_value = None self.host_value = None