[Refactor] Remove Hicache Load & Write threads (#10127)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
@@ -18,7 +18,7 @@ import math
 import threading
 import time
 from queue import Empty, Full, PriorityQueue, Queue
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple

 import torch

@@ -43,39 +43,53 @@ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
 logger = logging.getLogger(__name__)


+class LayerLoadingEvent:
+    def __init__(self, num_layers: int):
+        self._num_layers = num_layers
+        self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
+        self.start_event = torch.cuda.Event()  # start event on controller stream
+
+    def complete(self, layer_index: int):
+        assert 0 <= layer_index < self._num_layers
+        self.load_events[layer_index].record()
+
+    def wait(self, layer_index: int):
+        torch.cuda.current_stream().wait_event(self.load_events[layer_index])
+
+    @property
+    def finish_event(self):
+        return self.load_events[-1]
+
+
 class LayerDoneCounter:
-    def __init__(self, num_layers):
+    def __init__(self, num_layers: int):
         self.num_layers = num_layers
         # extra producer and consumer counters for overlap mode
         self.num_counters = 3
-        self.counters = [num_layers] * self.num_counters
-        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
-        self.producer_index = 0
-        self.consumer_index = 0
+        self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)]
+        self.producer_index = -1
+        self.consumer_index = -1

-    def next_producer(self):
-        return (self.producer_index + 1) % self.num_counters
-
     def update_producer(self):
-        self.producer_index = self.next_producer()
+        self.producer_index = (self.producer_index + 1) % self.num_counters
+        assert self.events[
+            self.producer_index
+        ].finish_event.query(), (
+            "Producer finish event should be ready before being reused."
+        )
         return self.producer_index

-    def set_consumer(self, index):
+    def set_consumer(self, index: int):
         self.consumer_index = index

-    def increment(self):
-        with self.conditions[self.producer_index]:
-            self.counters[self.producer_index] += 1
-            self.conditions[self.producer_index].notify_all()
-
-    def wait_until(self, threshold):
-        with self.conditions[self.consumer_index]:
-            while self.counters[self.consumer_index] <= threshold:
-                self.conditions[self.consumer_index].wait()
+    def wait_until(self, threshold: int):
+        if self.consumer_index < 0:
+            return
+        self.events[self.consumer_index].wait(threshold)

     def reset(self):
-        with self.conditions[self.producer_index]:
-            self.counters[self.producer_index] = 0
+        self.producer_index = -1
+        self.consumer_index = -1


 class CacheOperation:
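The new `LayerLoadingEvent` signals per-layer completion on the GPU timeline instead of waking a Python `threading.Condition`. A minimal sketch of the primitive it builds on (illustrative only, not part of the diff; assumes a CUDA device is available):

```python
import torch

# One stream performs the host-to-device copy, another consumes the result.
load_stream = torch.cuda.Stream()
layer_ready = torch.cuda.Event()  # what LayerLoadingEvent.load_events[i] holds

buf = torch.zeros(1 << 20, device="cuda")
with torch.cuda.stream(load_stream):
    buf.fill_(1.0)        # stands in for load_to_device_per_layer(..., i, ...)
    layer_ready.record()  # producer side: LayerLoadingEvent.complete(i)

# Consumer side: LayerLoadingEvent.wait(i) blocks the current CUDA stream,
# not the Python thread, so no dedicated load thread is needed.
torch.cuda.current_stream().wait_event(layer_ready)
print(buf.sum().item())   # ordered after the fill on the GPU timeline
```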
@@ -99,38 +113,32 @@ class CacheOperation:
         # default priority is the order of creation
         self.priority = priority if priority is not None else self.id

-    def merge(self, other: "CacheOperation") -> None:
-        # multiple operations can be merged into a single operation for batch processing
-        self.host_indices = torch.cat([self.host_indices, other.host_indices])
-        self.device_indices = torch.cat([self.device_indices, other.device_indices])
-        self.priority = min(self.priority, other.priority)
-        self.node_ids.extend(other.node_ids)
-
-    def split(self, factor) -> List["CacheOperation"]:
-        # split an operation into smaller operations to reduce the size of intermediate buffers
-        if factor <= 1:
-            return [self]
-
-        chunk_size = math.ceil(len(self.host_indices) / factor)
-        split_ops = []
-        for i in range(0, len(self.host_indices), chunk_size):
-            split_ops.append(
-                CacheOperation(
-                    host_indices=self.host_indices[i : i + chunk_size],
-                    device_indices=self.device_indices[i : i + chunk_size],
-                    node_id=0,
-                )
-            )
-        # Inherit the node_ids on the final chunk
-        if split_ops:
-            split_ops[-1].node_ids = self.node_ids
-
-        return split_ops
-
-    def __lt__(self, other: "CacheOperation"):
+    @staticmethod
+    def merge_ops(ops: List[CacheOperation]) -> CacheOperation:
+        assert len(ops) > 0
+        if len(ops) == 1:
+            return ops[0]
+
+        host_indices = torch.cat([op.host_indices for op in ops])
+        device_indices = torch.cat([op.device_indices for op in ops])
+        node_ids = []
+        priority = min(op.priority for op in ops)
+        for op in ops:
+            node_ids.extend(op.node_ids)
+        merged_op = CacheOperation(host_indices, device_indices, -1, priority)
+        merged_op.node_ids = node_ids
+        return merged_op
+
+    def __lt__(self, other: CacheOperation):
         return self.priority < other.priority


+class HiCacheAck(NamedTuple):
+    start_event: torch.cuda.Event
+    finish_event: torch.cuda.Event
+    node_ids: List[int]
+
+
 class TransferBuffer:
     """
     Overlapping buffer preparation and transfer operations to improve throughput.
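`CacheOperation.merge_ops` replaces the old pairwise `merge`: everything queued since the last flush is concatenated once and all node ids are carried forward, so a single stream launch covers the whole batch. A stand-in sketch of the same idea (`SimpleOp` is a hypothetical CPU-only class, not the `CacheOperation` above):

```python
from dataclasses import dataclass, field
from typing import List

import torch


@dataclass
class SimpleOp:
    host_indices: torch.Tensor
    device_indices: torch.Tensor
    node_ids: List[int] = field(default_factory=list)
    priority: int = 0


def merge_ops(ops: List[SimpleOp]) -> SimpleOp:
    # Concatenate all pending index tensors once and keep the best priority.
    assert len(ops) > 0
    if len(ops) == 1:
        return ops[0]
    merged = SimpleOp(
        host_indices=torch.cat([op.host_indices for op in ops]),
        device_indices=torch.cat([op.device_indices for op in ops]),
        priority=min(op.priority for op in ops),
    )
    for op in ops:
        merged.node_ids.extend(op.node_ids)
    return merged


a = SimpleOp(torch.arange(4), torch.arange(4), node_ids=[1], priority=2)
b = SimpleOp(torch.arange(4, 8), torch.arange(4, 8), node_ids=[2], priority=1)
m = merge_ops([a, b])
print(m.host_indices.tolist(), m.node_ids, m.priority)  # [0..7] [1, 2] 1
```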
@@ -236,7 +244,7 @@ class HiCacheController:
         mem_pool_host: HostKVCache,
         page_size: int,
         tp_group: torch.distributed.ProcessGroup,
-        load_cache_event: threading.Event = None,
+        load_cache_event: threading.Event,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
         storage_backend: Optional[str] = None,
@@ -340,8 +348,9 @@ class HiCacheController:
             self.page_set_func = self._3fs_zero_copy_page_set
             self.batch_exists_func = self._3fs_zero_copy_batch_exists

-        self.load_cache_event = load_cache_event
-        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
+        self.device = self.mem_pool_device.device
+        self.layer_num = self.mem_pool_device.layer_num
+        self.layer_done_counter = LayerDoneCounter(self.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)

         if write_policy not in [
@@ -351,11 +360,11 @@ class HiCacheController:
         ]:
             raise ValueError(f"Invalid write policy: {write_policy}")

-        self.write_queue = PriorityQueue()
-        self.load_queue = PriorityQueue()
-
-        self.ack_write_queue = Queue()
-        self.ack_load_queue = Queue()
+        # self.write_queue = PriorityQueue[CacheOperation]()
+        self.load_queue: List[CacheOperation] = []
+        self.write_queue: List[CacheOperation] = []
+        self.ack_load_queue: List[HiCacheAck] = []
+        self.ack_write_queue: List[HiCacheAck] = []

         self.stop_event = threading.Event()
         self.write_buffer = TransferBuffer(self.stop_event)
@@ -366,16 +375,6 @@ class HiCacheController:
         self.write_stream = torch.cuda.Stream()
         self.load_stream = torch.cuda.Stream()

-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
-
-        self.write_thread.start()
-        self.load_thread.start()
-
         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
                 target=self.prefetch_thread_func, daemon=True
@@ -432,15 +431,13 @@ class HiCacheController:

     def reset(self):
         self.stop_event.set()
-        self.write_thread.join()
-        self.load_thread.join()

-        self.write_queue.queue.clear()
-        self.load_queue.queue.clear()
+        self.write_queue.clear()
+        self.load_queue.clear()
         self.write_buffer.clear()
         self.load_buffer.clear()
-        self.ack_write_queue.queue.clear()
-        self.ack_load_queue.queue.clear()
+        self.ack_write_queue.clear()
+        self.ack_load_queue.clear()
         if self.enable_storage:
             self.prefetch_thread.join()
             self.backup_thread.join()
@@ -449,15 +446,7 @@ class HiCacheController:
             self.prefetch_revoke_queue.queue.clear()
            self.ack_backup_queue.queue.clear()

-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
         self.stop_event.clear()
-        self.write_thread.start()
-        self.load_thread.start()

         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
@@ -473,7 +462,7 @@ class HiCacheController:
         self,
         device_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int = 0,
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Back up KV caches from device memory to host memory.
@@ -482,17 +471,46 @@ class HiCacheController:
         if host_indices is None:
             return None
         self.mem_pool_host.protect_write(host_indices)
-        torch.cuda.current_stream().synchronize()
-        self.write_queue.put(
+        self.write_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
+        self.start_writing()
         return host_indices

+    def start_writing(self) -> None:
+        if len(self.write_queue) == 0:
+            return
+
+        op = CacheOperation.merge_ops(self.write_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.write_queue.clear()
+
+        start_event = torch.cuda.Event()
+        finish_event = torch.cuda.Event()
+
+        start_event.record()
+        with torch.cuda.stream(self.write_stream):
+            start_event.wait(self.write_stream)
+            self.mem_pool_host.backup_from_device_all_layer(
+                self.mem_pool_device, host_indices, device_indices, self.io_backend
+            )
+            self.mem_pool_host.complete_io(op.host_indices)
+            finish_event.record()
+            # NOTE: We must save the host indices and device indices here,
+            # this is because we need to guarantee that these tensors are
+            # still alive when the write stream is executing.
+            if host_indices.is_cuda:
+                host_indices.record_stream(self.write_stream)
+            if device_indices.is_cuda:
+                device_indices.record_stream(self.write_stream)
+
+        self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
+
     def load(
         self,
         host_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int = 0,
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Load KV caches from host memory to device memory.
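`start_writing` queues the device-to-host copy on `write_stream` and returns without blocking; the `record_stream` calls tell the caching allocator not to reuse the index tensors until the side stream has consumed them, and the recorded events are what `writing_check` later polls. A hedged sketch of that pattern (illustrative, assumes a CUDA device):

```python
import torch

write_stream = torch.cuda.Stream()
start_event, finish_event = torch.cuda.Event(), torch.cuda.Event()

indices = torch.randperm(1024, device="cuda")
start_event.record()  # ordered after `indices` is produced on the default stream
with torch.cuda.stream(write_stream):
    start_event.wait(write_stream)       # side stream starts only after start_event
    copied = indices * 2                 # stands in for backup_from_device_all_layer
    finish_event.record()
    indices.record_stream(write_stream)  # keep `indices` alive for the side stream

# Later (e.g. when draining ack_write_queue) the caller can poll or block:
finish_event.synchronize()
print(copied[:4])
```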
@@ -501,17 +519,18 @@ class HiCacheController:
         if device_indices is None:
             return None
         self.mem_pool_host.protect_load(host_indices)
-        # to ensure the device indices are ready before accessed by another CUDA stream
-        torch.cuda.current_stream().synchronize()
-        self.load_queue.put(
+        self.load_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
         return device_indices

-    def move_indices(self, host_indices, device_indices):
+    def move_indices(self, op: CacheOperation):
+        host_indices, device_indices = op.host_indices, op.device_indices
         # move indices to GPU if using kernels, to host if using direct indexing
         if self.io_backend == "kernel":
-            return host_indices.to(self.mem_pool_device.device), device_indices
+            if not host_indices.is_cuda:
+                host_indices = host_indices.to(self.device, non_blocking=True)
+            return host_indices, device_indices
         elif self.io_backend == "direct":
             device_indices = device_indices.cpu()
             host_indices, idx = host_indices.sort()
@@ -519,58 +538,20 @@ class HiCacheController:
         else:
             raise ValueError(f"Unsupported io backend")

-    def write_thread_func_direct(self):
-        """
-        Directly write through KV caches to host memory without buffering.
-        """
-        torch.cuda.set_stream(self.write_stream)
-        while not self.stop_event.is_set():
-            try:
-                operation = self.write_queue.get(block=True, timeout=1)
-                host_indices, device_indices = self.move_indices(
-                    operation.host_indices, operation.device_indices
-                )
-                self.mem_pool_host.backup_from_device_all_layer(
-                    self.mem_pool_device, host_indices, device_indices, self.io_backend
-                )
-                self.write_stream.synchronize()
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_write_queue.put(node_id)
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
-    def load_thread_func_layer_by_layer(self):
-        """
-        Load KV caches from host memory to device memory layer by layer.
-        """
-        torch.cuda.set_stream(self.load_stream)
-        while not self.stop_event.is_set():
-            self.load_cache_event.wait(timeout=1)
-            if not self.load_cache_event.is_set():
-                continue
-            self.load_cache_event.clear()
-            self.layer_done_counter.update_producer()
-
-            batch_operation = None
-            while self.load_queue.qsize() > 0:
-                op = self.load_queue.get(block=True)
-                if batch_operation is None:
-                    batch_operation = op
-                else:
-                    batch_operation.merge(op)
-            if batch_operation is None:
-                continue
-
-            # start layer-wise KV cache transfer from CPU to GPU
-            self.layer_done_counter.reset()
-            host_indices, device_indices = self.move_indices(
-                batch_operation.host_indices, batch_operation.device_indices
-            )
-            for i in range(self.mem_pool_host.layer_num):
+    def start_loading(self) -> int:
+        if len(self.load_queue) == 0:
+            return -1
+
+        producer_id = self.layer_done_counter.update_producer()
+        op = CacheOperation.merge_ops(self.load_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.load_queue.clear()
+        producer_event = self.layer_done_counter.events[producer_id]
+        producer_event.start_event.record()
+
+        with torch.cuda.stream(self.load_stream):
+            producer_event.start_event.wait(self.load_stream)
+            for i in range(self.layer_num):
                 self.mem_pool_host.load_to_device_per_layer(
                     self.mem_pool_device,
                     host_indices,
@@ -578,13 +559,24 @@ class HiCacheController:
                     i,
                     self.io_backend,
                 )
-                self.load_stream.synchronize()
-                self.layer_done_counter.increment()
+                producer_event.complete(i)
+            self.mem_pool_host.complete_io(op.host_indices)
+            # NOTE: We must save the host indices and device indices here,
+            # this is because we need to guarantee that these tensors are
+            # still alive when the load stream is executing.
+            if host_indices.is_cuda:
+                host_indices.record_stream(self.load_stream)
+            if device_indices.is_cuda:
+                device_indices.record_stream(self.load_stream)

-            self.mem_pool_host.complete_io(batch_operation.host_indices)
-            for node_id in batch_operation.node_ids:
-                if node_id != 0:
-                    self.ack_load_queue.put(node_id)
+        self.ack_load_queue.append(
+            HiCacheAck(
+                start_event=producer_event.start_event,
+                finish_event=producer_event.finish_event,
+                node_ids=op.node_ids,
+            )
+        )
+        return producer_id

     def evict_device(
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
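Because `start_loading` records one event per layer on `load_stream`, the forward pass can start computing layer 0 while later layers are still being copied; the consumer only waits on the event of the layer it is about to touch (via `wait_until`). A toy sketch of that overlap (illustrative, assumes a CUDA device):

```python
import torch

load_stream = torch.cuda.Stream()
num_layers = 4
layers = [torch.zeros(1 << 18, device="cuda") for _ in range(num_layers)]
events = [torch.cuda.Event() for _ in range(num_layers)]  # like LayerLoadingEvent.load_events

with torch.cuda.stream(load_stream):
    for i in range(num_layers):
        layers[i].fill_(i + 1.0)  # stands in for load_to_device_per_layer(i)
        events[i].record()        # producer_event.complete(i)

# Forward pass on the compute stream: wait only for the layer needed next, so
# compute on layer i overlaps with the copies of layers i+1..N-1.
for i in range(num_layers):
    torch.cuda.current_stream().wait_event(events[i])  # wait_until(i)
    _ = layers[i].sum()           # stands in for the attention of layer i
torch.cuda.synchronize()
print("all layers consumed")
```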
@@ -911,7 +911,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     is_prefill_only: bool = False

     # hicache pointer for synchronizing data loading from CPU to GPU
-    hicache_consumer_index: int = 0
+    hicache_consumer_index: int = -1

     @classmethod
     def init_new(
@@ -1897,7 +1897,7 @@ class ModelWorkerBatch:
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
-    hicache_consumer_index: int = 0
+    hicache_consumer_index: int = -1

     # Overlap event
     launch_done: Optional[threading.Event] = None
@@ -1807,10 +1807,6 @@ class Scheduler(
         if self.spec_algorithm.is_none():
             model_worker_batch = batch.get_model_worker_batch()

-            # update the consumer index of hicache to the running batch
-            self.tp_worker.set_hicache_consumer(
-                model_worker_batch.hicache_consumer_index
-            )
             if self.pp_group.is_last_rank:
                 logits_output, next_token_ids, can_run_cuda_graph = (
                     self.tp_worker.forward_batch_generation(model_worker_batch)
@@ -12,10 +12,11 @@
 # limitations under the License.
 # ==============================================================================
 """A tensor parallel worker."""
+from __future__ import annotations

 import logging
 import threading
-from typing import Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union

 import torch

@@ -45,6 +46,9 @@ from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed

+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)

@@ -167,10 +171,10 @@ class TpModelWorker:

         self.hicache_layer_transfer_counter = None

-    def register_hicache_layer_transfer_counter(self, counter):
+    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
         self.hicache_layer_transfer_counter = counter

-    def set_hicache_consumer(self, consumer_index):
+    def set_hicache_consumer(self, consumer_index: int):
         if self.hicache_layer_transfer_counter is not None:
             self.hicache_layer_transfer_counter.set_consumer(consumer_index)

@@ -230,6 +234,9 @@ class TpModelWorker:
     ) -> Tuple[
         Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
     ]:
+        # update the consumer index of hicache to the running batch
+        self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
+
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)

         pp_proxy_tensors = None
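After this change the scheduler no longer calls `set_hicache_consumer` itself; the consumer slot returned by `ready_to_load_host_cache()` travels with the batch as `hicache_consumer_index`, and the worker applies it at the top of `forward_batch_generation`. A stub-level sketch of the handoff (`StubCounter` and `StubBatch` are hypothetical stand-ins, not repository classes):

```python
from dataclasses import dataclass


class StubCounter:
    """Stand-in for LayerDoneCounter: tracks only the producer/consumer slots."""

    def __init__(self, num_counters: int = 3):
        self.num_counters = num_counters
        self.producer_index = -1
        self.consumer_index = -1

    def update_producer(self) -> int:
        self.producer_index = (self.producer_index + 1) % self.num_counters
        return self.producer_index

    def set_consumer(self, index: int) -> None:
        self.consumer_index = index


@dataclass
class StubBatch:
    hicache_consumer_index: int = -1  # same default the dataclasses now use


counter = StubCounter()

# Cache controller side: ready_to_load_host_cache() -> start_loading() returns a slot.
batch = StubBatch(hicache_consumer_index=counter.update_producer())

# Worker side: forward_batch_generation() points the counter at that slot
# before running the forward pass.
counter.set_consumer(batch.hicache_consumer_index)
assert counter.consumer_index == 0
```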
@@ -12,13 +12,14 @@
 # limitations under the License.
 # ==============================================================================
 """A tensor parallel worker."""
+from __future__ import annotations

 import dataclasses
 import logging
 import signal
 import threading
 from queue import Queue
-from typing import Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple

 import psutil
 import torch
@@ -38,6 +39,9 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import DynamicGradMode, get_compiler_backend
 from sglang.utils import get_exception_traceback

+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)

@@ -79,7 +83,7 @@ class TpModelWorkerClient:
         )

         # Launch threads
-        self.input_queue = Queue()
+        self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
         self.output_queue = Queue()
         self.forward_stream = torch.get_device_module(self.device).Stream()
         self.forward_thread = threading.Thread(
@@ -93,13 +97,9 @@ class TpModelWorkerClient:

         self.hicache_layer_transfer_counter = None

-    def register_hicache_layer_transfer_counter(self, counter):
+    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
         self.hicache_layer_transfer_counter = counter

-    def set_hicache_consumer(self, consumer_index):
-        if self.hicache_layer_transfer_counter is not None:
-            self.hicache_layer_transfer_counter.set_consumer(consumer_index)
-
     def get_worker_info(self):
         return self.worker.get_worker_info()

@@ -147,7 +147,7 @@ class TpModelWorkerClient:
     @DynamicGradMode()
     def forward_thread_func_(self):
         batch_pt = 0
-        batch_lists = [None] * 2
+        batch_lists: List = [None] * 2

         while True:
             model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
@@ -169,8 +169,6 @@ class TpModelWorkerClient:
             input_ids = model_worker_batch.input_ids
             resolve_future_token_ids(input_ids, self.future_token_ids_map)

-            # update the consumer index of hicache to the running batch
-            self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
             # Run forward
             logits_output, next_token_ids, can_run_cuda_graph = (
                 self.worker.forward_batch_generation(
@@ -201,41 +201,57 @@ class HiRadixCache(RadixCache):
         if write_back:
             # blocking till all write back complete
             while len(self.ongoing_write_through) > 0:
-                ack_id = self.cache_controller.ack_write_queue.get()
-                del self.ongoing_write_through[ack_id]
+                for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
+                    finish_event.synchronize()
+                    for ack_id in ack_list:
+                        del self.ongoing_write_through[ack_id]
+                self.cache_controller.ack_write_queue.clear()
+                assert len(self.ongoing_write_through) == 0
             return
-        queue_size = torch.tensor(
-            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
-        )
+
+        # NOTE: all ranks has the same ongoing_write_through, can skip sync if empty
+        if len(self.ongoing_write_through) == 0:
+            return
+
+        finish_count = 0
+        for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
+            if not finish_event.query():
+                break
+            finish_count += 1
+        queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu")
         if self.tp_world_size > 1:
-            # synchrnoize TP workers to make the same update to radix cache
+            # synchronize TP workers to make the same update to radix cache
             torch.distributed.all_reduce(
                 queue_size,
                 op=torch.distributed.ReduceOp.MIN,
                 group=self.tp_group,
             )
-        for _ in range(queue_size.item()):
-            ack_id = self.cache_controller.ack_write_queue.get()
-            backuped_node = self.ongoing_write_through[ack_id]
-            self.dec_lock_ref(backuped_node)
-            del self.ongoing_write_through[ack_id]
-            if self.enable_storage:
-                self.write_backup_storage(backuped_node)
+
+        finish_count = int(queue_size.item())
+        while finish_count > 0:
+            _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
+            finish_event.synchronize()
+            for ack_id in ack_list:
+                backuped_node = self.ongoing_write_through.pop(ack_id)
+                self.dec_lock_ref(backuped_node)
+                if self.enable_storage:
+                    self.write_backup_storage(backuped_node)
+            finish_count -= 1

     def loading_check(self):
-        while not self.cache_controller.ack_load_queue.empty():
-            try:
-                ack_id = self.cache_controller.ack_load_queue.get_nowait()
-                start_node, end_node = self.ongoing_load_back[ack_id]
-                self.dec_lock_ref(end_node)
-                while end_node != start_node:
-                    assert end_node.loading
-                    end_node.loading = False
-                    end_node = end_node.parent
-                # clear the reference
-                del self.ongoing_load_back[ack_id]
-            except Exception:
+        finish_count = 0
+        for _, finish_event, ack_list in self.cache_controller.ack_load_queue:
+            if not finish_event.query():
+                # the KV cache loading is still ongoing
                 break
+            finish_count += 1
+            # no need to sync across TP workers as batch forwarding is synced
+            for ack_id in ack_list:
+                end_node = self.ongoing_load_back.pop(ack_id)
+                self.dec_lock_ref(end_node)
+
+        # ACK until all events are processed
+        del self.cache_controller.ack_load_queue[:finish_count]

     def evictable_size(self):
         return self.evictable_size_
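`writing_check` now polls the ack list instead of reading a thread-fed queue: it counts the leading entries whose `finish_event` has completed, all-reduces that count with `ReduceOp.MIN` across TP ranks, and only then pops and synchronizes that many entries, so every rank applies the same radix-cache update. A CPU-only sketch of the counting and popping logic with stub events (illustrative; the real code uses `torch.cuda.Event` and `torch.distributed.all_reduce`):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class StubEvent:
    done: bool = False

    def query(self) -> bool:        # mirrors torch.cuda.Event.query()
        return self.done

    def synchronize(self) -> None:  # mirrors torch.cuda.Event.synchronize()
        self.done = True


@dataclass
class StubAck:  # shaped like HiCacheAck(start_event, finish_event, node_ids)
    finish_event: StubEvent
    node_ids: List[int] = field(default_factory=list)


ack_write_queue = [
    StubAck(StubEvent(done=True), [1, 2]),
    StubAck(StubEvent(done=True), [3]),
    StubAck(StubEvent(done=False), [4]),  # still running on the write stream
]

# Count the leading acks whose copy has finished.
finish_count = 0
for ack in ack_write_queue:
    if not ack.finish_event.query():
        break
    finish_count += 1

# In the real code this count is all-reduced with ReduceOp.MIN across TP ranks
# so every rank pops the same number of entries.
acked_nodes = []
while finish_count > 0:
    ack = ack_write_queue.pop(0)
    ack.finish_event.synchronize()
    acked_nodes.extend(ack.node_ids)
    finish_count -= 1

print(acked_nodes)  # [1, 2, 3]; node 4 stays queued until its event completes
```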
@@ -360,12 +376,11 @@ class HiRadixCache(RadixCache):
             # no sufficient GPU memory to load back KV caches
             return None

-        self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node)
+        self.ongoing_load_back[last_hit_node.id] = last_hit_node
         offset = 0
         for node in nodes_to_load:
             node.value = device_indices[offset : offset + len(node.host_value)]
             offset += len(node.host_value)
-            node.loading = True
         self.evictable_size_ += len(device_indices)
         self.inc_lock_ref(last_hit_node)

@@ -394,10 +409,12 @@ class HiRadixCache(RadixCache):
             last_node,
         )

-    def ready_to_load_host_cache(self):
-        producer_index = self.cache_controller.layer_done_counter.next_producer()
-        self.load_cache_event.set()
-        return producer_index
+    def ready_to_load_host_cache(self) -> int:
+        """
+        Notify the cache controller to start the KV cache loading.
+        Return the consumer index for the schedule batch manager to track.
+        """
+        return self.cache_controller.start_loading()

     def check_hicache_events(self):
         self.writing_check()
@@ -702,7 +719,6 @@ class HiRadixCache(RadixCache):
         new_node.parent = child.parent
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
-        new_node.loading = child.loading
         new_node.hit_count = child.hit_count

         # split value and host value if exists
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+from __future__ import annotations
+
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

 """
@@ -27,7 +29,7 @@ KVCache actually holds the physical kv cache.
 import abc
 import logging
 from contextlib import nullcontext
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -38,6 +40,9 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2

+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)

 GB = 1024 * 1024 * 1024
@@ -175,7 +180,7 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()

-    def register_layer_transfer_counter(self, layer_transfer_counter):
+    def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter):
         self.layer_transfer_counter = layer_transfer_counter

     def get_cpu_copy(self, indices):
@@ -3,6 +3,7 @@ import logging
 import threading
 from enum import IntEnum
 from functools import wraps
+from typing import Optional

 import psutil
 import torch
@@ -169,7 +170,7 @@ class HostKVCache(abc.ABC):
         return len(self.free_slots)

     @synchronized()
-    def alloc(self, need_size: int) -> torch.Tensor:
+    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
         assert (
             need_size % self.page_size == 0
         ), "The requested size should be a multiple of the page size."
@@ -53,8 +53,6 @@ class TreeNode:
         self.last_access_time = time.monotonic()

         self.hit_count = 0
-        # indicating the node is loading KV cache from host
-        self.loading = False
         # indicating the node is locked to protect from eviction
         # incremented when the node is referenced by a storage operation
         self.host_ref_counter = 0
@@ -60,8 +60,6 @@ class TreeNode:
         self.last_access_time = time.monotonic()

         self.hit_count = 0
-        # indicating the node is loading KV cache from host
-        self.loading = False
         # store the host indices of KV cache
         self.host_value = None
