Hierarchical Caching for SGLang (#2693)

Co-authored-by: Wenxuan Tan <wenxuan.tan@wisc.edu>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Author: Zhiqiang Xie
Date: 2025-02-23 21:56:30 -08:00
Commit: 6c7a152c5a (parent: 4d2a88bdff)
7 changed files with 732 additions and 91 deletions


@@ -5,9 +5,7 @@ Copyright 2023-2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -15,10 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import concurrent.futures
import logging
import math
import threading
from queue import PriorityQueue, Queue
from typing import Optional
from queue import Empty, Full, PriorityQueue, Queue
from typing import List, Optional
import torch
@@ -55,6 +55,27 @@ class CacheOperation:
self.priority = min(self.priority, other.priority)
self.node_ids.extend(other.node_ids)
def split(self, factor) -> List["CacheOperation"]:
# split an operation into smaller operations to reduce the size of intermediate buffers
if factor <= 1:
return [self]
chunk_size = math.ceil(len(self.host_indices) / factor)
split_ops = []
for i in range(0, len(self.host_indices), chunk_size):
split_ops.append(
CacheOperation(
host_indices=self.host_indices[i : i + chunk_size],
device_indices=self.device_indices[i : i + chunk_size],
node_id=0,
)
)
# Inherit the node_ids on the final chunk
if split_ops:
split_ops[-1].node_ids = self.node_ids
return split_ops
def __lt__(self, other: "CacheOperation"):
return self.priority < other.priority
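The new split() keeps oversized transfers from monopolizing the staging buffers. Below is a minimal sketch of the chunking arithmetic only (plain counts instead of torch index tensors; the sizes are illustrative, not taken from the commit): a factor derived from max_buffer_size yields roughly buffer-sized chunks, and the node_ids are attached only to the last chunk so acknowledgement happens once the whole operation has finished.

```python
import math

def split_ranges(n_indices: int, factor: int):
    # Mirror of the factor <= 1 early-return and ceil-based chunking in split().
    if factor <= 1:
        return [(0, n_indices)]
    chunk = math.ceil(n_indices / factor)
    return [(i, min(i + chunk, n_indices)) for i in range(0, n_indices, chunk)]

print(split_ranges(2500, 2))  # [(0, 1250), (1250, 2500)] -- node_ids ride on the last chunk
print(split_ranges(900, 0))   # [(0, 900)] -- factor <= 1 leaves the operation intact
```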
@@ -64,7 +85,10 @@ class TransferBuffer:
Overlapping buffer preparation and transfer operations to improve throughput.
"""
def __init__(self, buffer_count: int = 3, max_buffer_size: int = 1000) -> None:
def __init__(
self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1000
) -> None:
self.stop_event = stop_event
self.buffers = Queue(maxsize=buffer_count)
# todo: adjust the buffer size based on throughput profile of the system
self.max_buffer_size = max_buffer_size
@@ -75,15 +99,29 @@ class TransferBuffer:
def empty(self) -> bool:
return self.buffers.empty()
def put(self, item, block=True) -> None:
self.buffers.put(item, block=block)
def put(self, item, block=True, timeout=1) -> None:
while not self.stop_event.is_set():
try:
self.buffers.put(item, block=block, timeout=timeout)
break
except Full:
if not block:
break
continue
except Exception as e:
logger.error(e)
def get(self, block=True) -> Optional[CacheOperation]:
def get(self, block=True, timeout=1) -> Optional[CacheOperation]:
try:
return self.buffers.get(block=block)
return self.buffers.get(block=block, timeout=timeout)
except Empty:
return None
except Exception as e:
logger.error(e)
def clear(self):
self.buffers.queue.clear()
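put() and get() now use bounded waits so the worker threads periodically re-check stop_event instead of blocking forever, which is what allows reset() to join them. A small usage sketch of that pattern (all names here are illustrative, not from the commit):

```python
import threading
import time
from queue import Full, Queue

stop_event = threading.Event()
buffers = Queue(maxsize=2)

def producer():
    item = 0
    while not stop_event.is_set():
        try:
            buffers.put(item, timeout=1)  # bounded wait, so stop_event is re-checked
            item += 1
        except Full:
            continue

t = threading.Thread(target=producer, daemon=True)
t.start()
time.sleep(0.1)
stop_event.set()
t.join()  # returns within one put() timeout instead of hanging
print("stopped with", buffers.qsize(), "items staged")
```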
class HiCacheController:
@@ -111,8 +149,11 @@ class HiCacheController:
self.ack_write_queue = Queue()
self.ack_load_queue = Queue()
self.write_buffer = TransferBuffer()
self.load_buffer = TransferBuffer()
self.stop_event = threading.Event()
self.write_buffer = TransferBuffer(self.stop_event)
self.load_buffer = TransferBuffer(
self.stop_event, buffer_count=10, max_buffer_size=100
)
self.write_stream = torch.cuda.Stream()
self.load_stream = torch.cuda.Stream()
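Write-back and load run on dedicated CUDA streams so the copies can overlap with compute on the default stream, and the load path additionally relies on pinned host buffers for asynchronous H2D transfers. A hedged sketch of the underlying PyTorch mechanics, guarded so it only runs when a GPU is present (tensor shapes are arbitrary):

```python
import torch

if torch.cuda.is_available():
    load_stream = torch.cuda.Stream()
    host = torch.randn(1024, 1024).pin_memory()   # page-locked host buffer
    device = torch.empty_like(host, device="cuda")
    with torch.cuda.stream(load_stream):
        device.copy_(host, non_blocking=True)     # async H2D copy on the side stream
    load_stream.synchronize()                     # wait before reading `device`
    print(torch.allclose(device.cpu(), host))     # True
```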
@@ -126,6 +167,28 @@ class HiCacheController:
self.write_thread.start()
self.load_thread.start()
def reset(self):
self.stop_event.set()
self.write_thread.join()
self.load_thread.join()
self.write_queue.queue.clear()
self.load_queue.queue.clear()
self.write_buffer.clear()
self.load_buffer.clear()
self.ack_write_queue.queue.clear()
self.ack_load_queue.queue.clear()
self.write_thread = threading.Thread(
target=self.write_thread_func_buffer, daemon=True
)
self.load_thread = threading.Thread(
target=self.load_thread_func_buffer, daemon=True
)
self.stop_event.clear()
self.write_thread.start()
self.load_thread.start()
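reset() has to construct fresh Thread objects because a threading.Thread can only be started once; the stop_event/join pair above drains the old workers first. A minimal demonstration of that constraint:

```python
import threading

def worker():
    pass

t = threading.Thread(target=worker, daemon=True)
t.start()
t.join()
try:
    t.start()                                     # re-starting the same object raises
except RuntimeError as e:
    print("cannot restart:", e)

t = threading.Thread(target=worker, daemon=True)  # a new Thread object is required
t.start()
t.join()
```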
def write(
self,
device_indices: torch.Tensor,
@@ -138,10 +201,10 @@ class HiCacheController:
host_indices = self.mem_pool_host.alloc(len(device_indices))
if host_indices is None:
return None
self.mem_pool_host.protect_write(host_indices)
self.write_queue.put(
CacheOperation(host_indices, device_indices, node_id, priority)
)
self.mem_pool_host.protect_write(host_indices)
return host_indices
def load(
@@ -156,10 +219,10 @@ class HiCacheController:
device_indices = self.mem_pool_device.alloc(len(host_indices))
if device_indices is None:
return None
self.mem_pool_host.protect_load(host_indices)
self.load_queue.put(
CacheOperation(host_indices, device_indices, node_id, priority)
)
self.mem_pool_host.protect_load(host_indices)
return device_indices
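Moving protect_write()/protect_load() ahead of the queue put closes a race: once the operation is enqueued, the worker thread may finish the copy and call complete_io() at any time, so the indices must already be in the protected state. A simplified illustration of the ordering, with a hypothetical two-field state dict standing in for the host memory pool:

```python
import threading

state = {"protected": False, "completed": False}
enqueued = threading.Event()

def worker():
    enqueued.wait()                   # stands in for write_queue.get()
    assert state["protected"], "worker saw an unprotected entry"
    state["completed"] = True         # stands in for complete_io()

t = threading.Thread(target=worker)
t.start()
state["protected"] = True             # protect first ...
enqueued.set()                        # ... then hand the operation to the worker
t.join()
print(state)                          # {'protected': True, 'completed': True}
```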
def write_thread_func_direct(self):
@@ -167,16 +230,19 @@ class HiCacheController:
Directly write through KV caches to host memory without buffering.
"""
with torch.cuda.stream(self.write_stream):
while True:
while not self.stop_event.is_set():
try:
operation = self.write_queue.get(block=True)
operation = self.write_queue.get(block=True, timeout=1)
operation.data = self.mem_pool_device.get_flat_data(
operation.device_indices
)
self.mem_pool_host.transfer(operation.host_indices, operation.data)
self.mem_pool_host.complete_io(operation.host_indices)
for node_id in operation.node_ids:
self.ack_write_queue.put(node_id)
if node_id != 0:
self.ack_write_queue.put(node_id)
except Empty:
continue
except Exception as e:
logger.error(e)
@@ -185,9 +251,10 @@ class HiCacheController:
Directly load KV caches from host memory to device memory without buffering.
"""
with torch.cuda.stream(self.load_stream):
while True:
while not self.stop_event.is_set():
try:
operation = self.load_queue.get(block=True)
operation = self.load_queue.get(block=True, timeout=1)
# time.sleep(18e-6 * len(operation.host_indices))
operation.data = self.mem_pool_host.get_flat_data(
operation.host_indices
)
@@ -196,7 +263,10 @@ class HiCacheController:
)
self.mem_pool_host.complete_io(operation.host_indices)
for node_id in operation.node_ids:
self.ack_load_queue.put(node_id)
if node_id != 0:
self.ack_load_queue.put(node_id)
except Empty:
continue
except Exception as e:
logger.error(e)
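Both direct paths now skip acknowledgements for node_id 0, the placeholder id given to the intermediate chunks produced by split(); only real node ids should reach the ack queues. Illustrative filtering only, the ids are made up:

```python
from queue import Queue

ack_queue = Queue()
node_ids = [0, 0, 42, 7]          # zeros come from intermediate split chunks
for node_id in node_ids:
    if node_id != 0:
        ack_queue.put(node_id)
print(list(ack_queue.queue))      # [42, 7]
```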
@@ -204,39 +274,98 @@ class HiCacheController:
"""
Auxiliary function to prepare the buffer for write operations.
"""
def _to_op(op_):
assert op_.device_indices.is_cuda, "Device indices should be on GPU"
op_.data = self.mem_pool_device.get_flat_data(op_.device_indices).to(
self.mem_pool_host.device
)
self.write_buffer.put(op_)
return op_
buffer = None
while True:
try:
operation = self.write_queue.get(block=True)
if buffer is None:
buffer = operation
else:
buffer.merge(operation)
if (
no_wait
or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
or self.write_queue.empty()
or self.write_buffer.empty()
):
assert (
buffer.device_indices.is_cuda
), "Device indices should be on GPU"
buffer.data = self.mem_pool_device.get_flat_data(
buffer.device_indices
).contiguous()
self.write_buffer.put(buffer, block=True)
buffer = None
except Exception as e:
logger.error(e)
with torch.cuda.stream(self.write_stream):
while not self.stop_event.is_set():
try:
operation = self.write_queue.get(block=True, timeout=1)
factor = (
len(operation.device_indices)
// self.write_buffer.max_buffer_size
)
if factor >= 1:
if buffer is not None:
_to_op(buffer)
buffer = None
if factor < 2:
_to_op(operation)
else:
split_ops = operation.split(factor)
for op_ in split_ops:
_to_op(op_)
continue
if buffer is None:
buffer = operation
else:
buffer.merge(operation)
if (
no_wait
or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
or self.write_queue.empty()
or self.write_buffer.empty()
):
_to_op(buffer)
buffer = None
except Empty:
continue
except Exception as e:
logger.error(e)
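The rewritten write_aux_func chooses between three paths based on factor = len(device_indices) // max_buffer_size: small operations keep merging into a pending buffer, operations of roughly one buffer in size go through as-is, and larger ones are split into approximately buffer-sized chunks. A sketch of that decision, with the max_buffer_size value assumed for illustration:

```python
def dispatch(n_indices: int, max_buffer_size: int = 1000) -> str:
    factor = n_indices // max_buffer_size
    if factor >= 1:
        return "send as-is" if factor < 2 else f"split into {factor} chunks"
    return "merge into pending buffer"

for n in (200, 1000, 1999, 2000, 5000):
    print(n, "->", dispatch(n))
# 200  -> merge into pending buffer
# 1000 -> send as-is
# 1999 -> send as-is
# 2000 -> split into 2 chunks
# 5000 -> split into 5 chunks
```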
def load_aux_func(self):
"""
Auxiliary function to prepare the buffer for load operations.
"""
def _pin_op(op_, put=True):
op_.data = (
self.mem_pool_host.get_flat_data(op_.host_indices)
.contiguous()
.pin_memory()
)
if put:
self.load_buffer.put(op_)
return op_
buffer = None
while True:
while not self.stop_event.is_set():
try:
operation = self.load_queue.get(block=True)
operation = self.load_queue.get(block=True, timeout=1)
factor = len(operation.host_indices) // self.load_buffer.max_buffer_size
if factor >= 1:
if buffer is not None:
_pin_op(buffer)
buffer = None
if factor < 2:
_pin_op(operation)
else:
split_ops = operation.split(factor)
split_args = [(op_, True) for op_ in split_ops[:-1]]
split_args.append((split_ops[-1], False))
# Spawn threads to pin each op concurrently
with concurrent.futures.ThreadPoolExecutor() as executor:
pinned_ops = list(
executor.map(
lambda x: _pin_op(x[0], put=x[1]), split_args
)
)
# preserve the order of last op to ensure correct ack
self.load_buffer.put(pinned_ops[-1])
continue
if buffer is None:
buffer = operation
else:
@@ -246,41 +375,43 @@ class HiCacheController:
or self.load_queue.empty()
or self.load_buffer.empty()
):
buffer.data = (
self.mem_pool_host.get_flat_data(buffer.host_indices)
.contiguous()
.pin_memory()
)
self.load_buffer.put(buffer, block=True)
_pin_op(buffer)
buffer = None
except Empty:
continue
except Exception as e:
logger.error(e)
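load_aux_func pins oversized operations chunk by chunk on a thread pool and only enqueues the final chunk (the one carrying node_ids) after all chunks are pinned, so acknowledgement ordering is preserved. A self-contained sketch of that idea, with the pin step mocked as a list copy instead of torch.Tensor.pin_memory():

```python
import concurrent.futures
from queue import Queue

load_buffer = Queue()

def pin_chunk(chunk, put=True):
    pinned = list(chunk)              # stand-in for .contiguous().pin_memory()
    if put:
        load_buffer.put(pinned)
    return pinned

chunks = [[1, 2], [3, 4], [5, 6]]     # the last chunk carries the node_ids
args = [(c, True) for c in chunks[:-1]] + [(chunks[-1], False)]
with concurrent.futures.ThreadPoolExecutor() as executor:
    pinned = list(executor.map(lambda a: pin_chunk(*a), args))
load_buffer.put(pinned[-1])           # the ack-carrying chunk goes in last
print(load_buffer.qsize())            # 3
```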
def write_thread_func_buffer(self):
aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
aux_thread.start()
with torch.cuda.stream(self.write_stream):
while True:
operation = self.write_buffer.get()
if operation is None:
continue
self.mem_pool_host.transfer(operation.host_indices, operation.data)
self.mem_pool_host.complete_io(operation.host_indices)
for node_id in operation.node_ids:
while not self.stop_event.is_set():
operation = self.write_buffer.get()
if operation is None:
continue
self.mem_pool_host.assign_flat_data(operation.host_indices, operation.data)
self.mem_pool_host.complete_io(operation.host_indices)
for node_id in operation.node_ids:
if node_id != 0:
self.ack_write_queue.put(node_id)
aux_thread.join()
def load_thread_func_buffer(self):
aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
aux_thread.start()
with torch.cuda.stream(self.load_stream):
while True:
while not self.stop_event.is_set():
operation = self.load_buffer.get()
if operation is None:
continue
self.mem_pool_device.transfer(operation.device_indices, operation.data)
self.mem_pool_host.complete_io(operation.host_indices)
for node_id in operation.node_ids:
self.ack_load_queue.put(node_id)
if node_id != 0:
self.ack_load_queue.put(node_id)
aux_thread.join()
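The buffered paths keep a producer/consumer split: an auxiliary thread stages data into the bounded TransferBuffer while the main worker drains it, so host-side staging overlaps with the actual transfer. A schematic of that pipeline with both stages reduced to sleeps (timings and names are made up):

```python
import threading
import time
from queue import Queue

buffer = Queue(maxsize=3)
DONE = object()

def aux():                            # stages (gathers and pins) data
    for i in range(5):
        time.sleep(0.01)              # stand-in for staging work
        buffer.put(f"chunk-{i}")
    buffer.put(DONE)

def main_worker():                    # drains the buffer and performs the copy
    while (op := buffer.get()) is not DONE:
        time.sleep(0.01)              # stand-in for the device transfer
        print("transferred", op)

t = threading.Thread(target=aux, daemon=True)
t.start()
main_worker()
t.join()
```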
def evict_device(
self, device_indices: torch.Tensor, host_indices: torch.Tensor