Hierarchical Caching for SGLang (#2693)
Co-authored-by: Wenxuan Tan <wenxuan.tan@wisc.edu> Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -5,9 +5,7 @@ Copyright 2023-2025 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
@@ -15,10 +13,12 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import math
|
||||
import threading
|
||||
from queue import PriorityQueue, Queue
|
||||
from typing import Optional
|
||||
from queue import Empty, Full, PriorityQueue, Queue
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -55,6 +55,27 @@ class CacheOperation:
|
||||
self.priority = min(self.priority, other.priority)
|
||||
self.node_ids.extend(other.node_ids)
|
||||
|
||||
def split(self, factor) -> List["CacheOperation"]:
    """Split this operation into smaller operations.

    Splitting reduces the size of the intermediate buffer each piece needs.
    Host and device indices are sliced with identical boundaries so every
    piece stays aligned.

    Args:
        factor: Desired number of pieces; values <= 1 disable splitting.

    Returns:
        ``[self]`` when nothing has to be split (``factor <= 1`` or there
        are no indices to divide). Otherwise a list of new operations over
        consecutive slices, each created with ``node_id=0``; only the last
        one is given this operation's ``node_ids``.
    """
    if factor <= 1:
        return [self]

    total = len(self.host_indices)
    if total == 0:
        # chunk_size would be 0 here and range(0, 0, 0) raises ValueError.
        return [self]

    chunk_size = math.ceil(total / factor)
    split_ops = [
        CacheOperation(
            host_indices=self.host_indices[start : start + chunk_size],
            device_indices=self.device_indices[start : start + chunk_size],
            node_id=0,
        )
        for start in range(0, total, chunk_size)
    ]

    # Inherit the node_ids on the final chunk (total > 0 guarantees one exists)
    split_ops[-1].node_ids = self.node_ids
    return split_ops
|
||||
|
||||
def __lt__(self, other: "CacheOperation"):
    """Compare two operations by ``priority``; the smaller value is "less"."""
    mine, theirs = self.priority, other.priority
    return mine < theirs
|
||||
|
||||
@@ -64,7 +85,10 @@ class TransferBuffer:
|
||||
Overlapping buffer preparation and transfer operations to improve throughput.
|
||||
"""
|
||||
|
||||
def __init__(self, buffer_count: int = 3, max_buffer_size: int = 1000) -> None:
|
||||
def __init__(
|
||||
self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1000
|
||||
) -> None:
|
||||
self.stop_event = stop_event
|
||||
self.buffers = Queue(maxsize=buffer_count)
|
||||
# todo: adjust the buffer size based on throughput profile of the system
|
||||
self.max_buffer_size = max_buffer_size
|
||||
@@ -75,15 +99,29 @@ class TransferBuffer:
|
||||
def empty(self) -> bool:
    """Return True when the underlying buffer queue holds no items."""
    nothing_queued = self.buffers.empty()
    return nothing_queued
|
||||
|
||||
def put(self, item, block=True, timeout=1) -> None:
    """Enqueue ``item``, retrying until it fits or a stop is requested.

    Args:
        item: Object to append to the buffer queue.
        block: When True, wait for a free slot in slices of ``timeout``
            seconds so ``stop_event`` is re-checked between attempts.
            When False, make a single attempt and drop ``item`` if the
            queue is full.
        timeout: Seconds each blocking attempt may wait.

    Nothing is enqueued if ``stop_event`` is already set, and a pending
    item is abandoned once ``stop_event`` becomes set.
    """
    while not self.stop_event.is_set():
        try:
            self.buffers.put(item, block=block, timeout=timeout)
        except Full:
            if not block:
                # Non-blocking callers get exactly one attempt.
                return
        except Exception as e:
            # Repeating the identical call would fail the same way; retrying
            # here would spin and flood the log until stop_event is set.
            logger.error(e)
            return
        else:
            return
|
||||
|
||||
def get(self, block=True, timeout=1) -> Optional["CacheOperation"]:
    """Take the next item from the buffer queue.

    Args:
        block: Whether to wait for an item to become available.
        timeout: Seconds to wait when ``block`` is True.

    Returns:
        The dequeued item, or None when the queue stayed empty or the
        lookup failed (failures other than ``Empty`` are logged).
    """
    fetched = None
    try:
        fetched = self.buffers.get(block=block, timeout=timeout)
    except Empty:
        # Nothing arrived in time; the caller decides whether to poll again.
        pass
    except Exception as e:
        logger.error(e)
    return fetched
|
||||
|
||||
def clear(self):
    """Discard every item currently held in the buffer queue.

    The queue's internal deque is only safe to touch while holding the
    queue's own lock, so the wipe happens under ``mutex``.
    """
    pending = self.buffers
    with pending.mutex:
        pending.queue.clear()
        # Slots were freed: wake any producer blocked waiting for space.
        pending.not_full.notify_all()
|
||||
|
||||
|
||||
class HiCacheController:
|
||||
|
||||
@@ -111,8 +149,11 @@ class HiCacheController:
|
||||
self.ack_write_queue = Queue()
|
||||
self.ack_load_queue = Queue()
|
||||
|
||||
self.write_buffer = TransferBuffer()
|
||||
self.load_buffer = TransferBuffer()
|
||||
self.stop_event = threading.Event()
|
||||
self.write_buffer = TransferBuffer(self.stop_event)
|
||||
self.load_buffer = TransferBuffer(
|
||||
self.stop_event, buffer_count=10, max_buffer_size=100
|
||||
)
|
||||
|
||||
self.write_stream = torch.cuda.Stream()
|
||||
self.load_stream = torch.cuda.Stream()
|
||||
@@ -126,6 +167,28 @@ class HiCacheController:
|
||||
self.write_thread.start()
|
||||
self.load_thread.start()
|
||||
|
||||
def reset(self):
    """Stop both transfer threads, drop all pending work, and restart them."""
    # Signal the worker loops to exit and wait until both have finished.
    self.stop_event.set()
    for worker in (self.write_thread, self.load_thread):
        worker.join()

    # Throw away everything still queued, staged, or awaiting acknowledgement.
    for request_queue in (self.write_queue, self.load_queue):
        request_queue.queue.clear()
    for staging in (self.write_buffer, self.load_buffer):
        staging.clear()
    for ack_queue in (self.ack_write_queue, self.ack_load_queue):
        ack_queue.queue.clear()

    # A finished Thread cannot be started again, so build fresh ones.
    self.write_thread = threading.Thread(
        target=self.write_thread_func_buffer, daemon=True
    )
    self.load_thread = threading.Thread(
        target=self.load_thread_func_buffer, daemon=True
    )

    # Re-arm the stop flag before the new workers enter their loops.
    self.stop_event.clear()
    for worker in (self.write_thread, self.load_thread):
        worker.start()
|
||||
|
||||
def write(
|
||||
self,
|
||||
device_indices: torch.Tensor,
|
||||
@@ -138,10 +201,10 @@ class HiCacheController:
|
||||
host_indices = self.mem_pool_host.alloc(len(device_indices))
|
||||
if host_indices is None:
|
||||
return None
|
||||
self.mem_pool_host.protect_write(host_indices)
|
||||
self.write_queue.put(
|
||||
CacheOperation(host_indices, device_indices, node_id, priority)
|
||||
)
|
||||
self.mem_pool_host.protect_write(host_indices)
|
||||
return host_indices
|
||||
|
||||
def load(
|
||||
@@ -156,10 +219,10 @@ class HiCacheController:
|
||||
device_indices = self.mem_pool_device.alloc(len(host_indices))
|
||||
if device_indices is None:
|
||||
return None
|
||||
self.mem_pool_host.protect_load(host_indices)
|
||||
self.load_queue.put(
|
||||
CacheOperation(host_indices, device_indices, node_id, priority)
|
||||
)
|
||||
self.mem_pool_host.protect_load(host_indices)
|
||||
return device_indices
|
||||
|
||||
def write_thread_func_direct(self):
|
||||
@@ -167,16 +230,19 @@ class HiCacheController:
|
||||
Directly write through KV caches to host memory without buffering.
|
||||
"""
|
||||
with torch.cuda.stream(self.write_stream):
|
||||
while True:
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.write_queue.get(block=True)
|
||||
operation = self.write_queue.get(block=True, timeout=1)
|
||||
operation.data = self.mem_pool_device.get_flat_data(
|
||||
operation.device_indices
|
||||
)
|
||||
self.mem_pool_host.transfer(operation.host_indices, operation.data)
|
||||
self.mem_pool_host.complete_io(operation.host_indices)
|
||||
for node_id in operation.node_ids:
|
||||
self.ack_write_queue.put(node_id)
|
||||
if node_id != 0:
|
||||
self.ack_write_queue.put(node_id)
|
||||
except Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
@@ -185,9 +251,10 @@ class HiCacheController:
|
||||
Directly load KV caches from host memory to device memory without buffering.
|
||||
"""
|
||||
with torch.cuda.stream(self.load_stream):
|
||||
while True:
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.load_queue.get(block=True)
|
||||
operation = self.load_queue.get(block=True, timeout=1)
|
||||
# time.sleep(18e-6 * len(operation.host_indices))
|
||||
operation.data = self.mem_pool_host.get_flat_data(
|
||||
operation.host_indices
|
||||
)
|
||||
@@ -196,7 +263,10 @@ class HiCacheController:
|
||||
)
|
||||
self.mem_pool_host.complete_io(operation.host_indices)
|
||||
for node_id in operation.node_ids:
|
||||
self.ack_load_queue.put(node_id)
|
||||
if node_id != 0:
|
||||
self.ack_load_queue.put(node_id)
|
||||
except Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
@@ -204,39 +274,98 @@ class HiCacheController:
|
||||
"""
|
||||
Auxiliary function to prepare the buffer for write operations.
|
||||
"""
|
||||
|
||||
def _to_op(op_):
|
||||
assert op_.device_indices.is_cuda, "Device indices should be on GPU"
|
||||
op_.data = self.mem_pool_device.get_flat_data(op_.device_indices).to(
|
||||
self.mem_pool_host.device
|
||||
)
|
||||
self.write_buffer.put(op_)
|
||||
return op_
|
||||
|
||||
buffer = None
|
||||
while True:
|
||||
try:
|
||||
operation = self.write_queue.get(block=True)
|
||||
if buffer is None:
|
||||
buffer = operation
|
||||
else:
|
||||
buffer.merge(operation)
|
||||
if (
|
||||
no_wait
|
||||
or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
|
||||
or self.write_queue.empty()
|
||||
or self.write_buffer.empty()
|
||||
):
|
||||
assert (
|
||||
buffer.device_indices.is_cuda
|
||||
), "Device indices should be on GPU"
|
||||
buffer.data = self.mem_pool_device.get_flat_data(
|
||||
buffer.device_indices
|
||||
).contiguous()
|
||||
self.write_buffer.put(buffer, block=True)
|
||||
buffer = None
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
with torch.cuda.stream(self.write_stream):
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.write_queue.get(block=True, timeout=1)
|
||||
factor = (
|
||||
len(operation.device_indices)
|
||||
// self.write_buffer.max_buffer_size
|
||||
)
|
||||
|
||||
if factor >= 1:
|
||||
if buffer is not None:
|
||||
_to_op(buffer)
|
||||
buffer = None
|
||||
|
||||
if factor < 2:
|
||||
_to_op(operation)
|
||||
else:
|
||||
split_ops = operation.split(factor)
|
||||
for op_ in split_ops:
|
||||
_to_op(op_)
|
||||
continue
|
||||
|
||||
if buffer is None:
|
||||
buffer = operation
|
||||
else:
|
||||
buffer.merge(operation)
|
||||
if (
|
||||
no_wait
|
||||
or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
|
||||
or self.write_queue.empty()
|
||||
or self.write_buffer.empty()
|
||||
):
|
||||
_to_op(buffer)
|
||||
buffer = None
|
||||
except Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def load_aux_func(self):
|
||||
"""
|
||||
Auxiliary function to prepare the buffer for load operations.
|
||||
"""
|
||||
|
||||
def _pin_op(op_, put=True):
|
||||
op_.data = (
|
||||
self.mem_pool_host.get_flat_data(op_.host_indices)
|
||||
.contiguous()
|
||||
.pin_memory()
|
||||
)
|
||||
if put:
|
||||
self.load_buffer.put(op_)
|
||||
return op_
|
||||
|
||||
buffer = None
|
||||
while True:
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.load_queue.get(block=True)
|
||||
operation = self.load_queue.get(block=True, timeout=1)
|
||||
factor = len(operation.host_indices) // self.load_buffer.max_buffer_size
|
||||
|
||||
if factor >= 1:
|
||||
if buffer is not None:
|
||||
_pin_op(buffer)
|
||||
buffer = None
|
||||
|
||||
if factor < 2:
|
||||
_pin_op(operation)
|
||||
else:
|
||||
split_ops = operation.split(factor)
|
||||
split_args = [(op_, True) for op_ in split_ops[:-1]]
|
||||
split_args.append((split_ops[-1], False))
|
||||
# Spawn threads to pin each op concurrently
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
pinned_ops = list(
|
||||
executor.map(
|
||||
lambda x: _pin_op(x[0], put=x[1]), split_args
|
||||
)
|
||||
)
|
||||
# preserve the order of last op to ensure correct ack
|
||||
self.load_buffer.put(pinned_ops[-1])
|
||||
continue
|
||||
|
||||
if buffer is None:
|
||||
buffer = operation
|
||||
else:
|
||||
@@ -246,41 +375,43 @@ class HiCacheController:
|
||||
or self.load_queue.empty()
|
||||
or self.load_buffer.empty()
|
||||
):
|
||||
buffer.data = (
|
||||
self.mem_pool_host.get_flat_data(buffer.host_indices)
|
||||
.contiguous()
|
||||
.pin_memory()
|
||||
)
|
||||
self.load_buffer.put(buffer, block=True)
|
||||
_pin_op(buffer)
|
||||
buffer = None
|
||||
except Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def write_thread_func_buffer(self):
    """Move staged write operations into host memory until a stop is requested.

    Starts ``write_aux_func`` on a helper thread, then repeatedly takes an
    operation from ``write_buffer``, stores its data through
    ``mem_pool_host.assign_flat_data``, marks the host indices complete, and
    queues an acknowledgement for each non-zero node id. The helper thread is
    joined before returning.
    """
    helper = threading.Thread(target=self.write_aux_func, daemon=True)
    helper.start()

    while not self.stop_event.is_set():
        staged = self.write_buffer.get()
        if staged is not None:
            self.mem_pool_host.assign_flat_data(staged.host_indices, staged.data)
            self.mem_pool_host.complete_io(staged.host_indices)
            for node_id in staged.node_ids:
                # node id 0 is never acknowledged
                if node_id == 0:
                    continue
                self.ack_write_queue.put(node_id)
    helper.join()
|
||||
|
||||
def load_thread_func_buffer(self):
    """Move staged load operations into device memory until a stop is requested.

    Starts ``load_aux_func`` on a helper thread, then, under ``load_stream``,
    repeatedly takes an operation from ``load_buffer``, hands its data to
    ``mem_pool_device.transfer``, marks the host indices complete, and queues
    an acknowledgement for each non-zero node id. The helper thread is joined
    before returning.
    """
    helper = threading.Thread(target=self.load_aux_func, daemon=True)
    helper.start()

    with torch.cuda.stream(self.load_stream):
        while not self.stop_event.is_set():
            staged = self.load_buffer.get()
            if staged is not None:
                self.mem_pool_device.transfer(staged.device_indices, staged.data)
                self.mem_pool_host.complete_io(staged.host_indices)
                for node_id in staged.node_ids:
                    # node id 0 is never acknowledged
                    if node_id == 0:
                        continue
                    self.ack_load_queue.put(node_id)
    helper.join()
|
||||
|
||||
def evict_device(
|
||||
self, device_indices: torch.Tensor, host_indices: torch.Tensor
|
||||
|
||||
Reference in New Issue
Block a user