Hierarchical Caching for SGLang (#2693)
Co-authored-by: Wenxuan Tan <wenxuan.tan@wisc.edu> Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -5,9 +5,7 @@ Copyright 2023-2025 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
@@ -15,10 +13,12 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import math
|
||||
import threading
|
||||
from queue import PriorityQueue, Queue
|
||||
from typing import Optional
|
||||
from queue import Empty, Full, PriorityQueue, Queue
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -55,6 +55,27 @@ class CacheOperation:
|
||||
self.priority = min(self.priority, other.priority)
|
||||
self.node_ids.extend(other.node_ids)
|
||||
|
||||
def split(self, factor) -> List["CacheOperation"]:
    """Split this operation into smaller operations over contiguous index chunks.

    Args:
        factor: Desired number of pieces. Values ``<= 1`` disable splitting.

    Returns:
        ``[self]`` when no split is performed (``factor <= 1`` or there are
        no indices to split); otherwise a list of new operations whose
        ``host_indices``/``device_indices`` are consecutive slices of this
        operation's indices, each at most ``ceil(len / factor)`` long.
    """
    # split an operation into smaller operations to reduce the size of intermediate buffers
    total = len(self.host_indices)
    # An empty operation used to reach range(0, 0, 0) below and raise
    # ValueError (step must not be zero); hand it back unsplit instead so
    # its node_ids are still carried.
    if factor <= 1 or total == 0:
        return [self]

    chunk_size = math.ceil(total / factor)
    split_ops = [
        CacheOperation(
            host_indices=self.host_indices[start : start + chunk_size],
            device_indices=self.device_indices[start : start + chunk_size],
            # placeholder id; only the final chunk carries the real node ids
            node_id=0,
        )
        for start in range(0, total, chunk_size)
    ]
    # Inherit the node_ids on the final chunk (total > 0 guarantees one exists)
    split_ops[-1].node_ids = self.node_ids

    return split_ops
|
||||
|
||||
def __lt__(self, other: "CacheOperation"):
    """Compare two operations by their ``priority`` attribute (smaller sorts first)."""
    mine, theirs = self.priority, other.priority
    return mine < theirs
|
||||
|
||||
@@ -64,7 +85,10 @@ class TransferBuffer:
|
||||
Overlapping buffer preparation and transfer operations to improve throughput.
|
||||
"""
|
||||
|
||||
def __init__(self, buffer_count: int = 3, max_buffer_size: int = 1000) -> None:
|
||||
def __init__(
|
||||
self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1000
|
||||
) -> None:
|
||||
self.stop_event = stop_event
|
||||
self.buffers = Queue(maxsize=buffer_count)
|
||||
# todo: adjust the buffer size based on throughput profile of the system
|
||||
self.max_buffer_size = max_buffer_size
|
||||
@@ -75,15 +99,29 @@ class TransferBuffer:
|
||||
def empty(self) -> bool:
    """Report whether the underlying ``buffers`` queue currently holds no items."""
    is_empty = self.buffers.empty()
    return is_empty
|
||||
|
||||
def put(self, item, block=True) -> None:
|
||||
self.buffers.put(item, block=block)
|
||||
def put(self, item, block=True, timeout=1) -> None:
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
self.buffers.put(item, block=block, timeout=timeout)
|
||||
break
|
||||
except Full:
|
||||
if not block:
|
||||
break
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def get(self, block=True) -> Optional[CacheOperation]:
|
||||
def get(self, block=True, timeout=1) -> Optional[CacheOperation]:
|
||||
try:
|
||||
return self.buffers.get(block=block)
|
||||
return self.buffers.get(block=block, timeout=timeout)
|
||||
except Empty:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def clear(self):
    """Discard every queued item by clearing the queue's backing container."""
    backing = self.buffers.queue
    backing.clear()
|
||||
|
||||
|
||||
class HiCacheController:
|
||||
|
||||
@@ -111,8 +149,11 @@ class HiCacheController:
|
||||
self.ack_write_queue = Queue()
|
||||
self.ack_load_queue = Queue()
|
||||
|
||||
self.write_buffer = TransferBuffer()
|
||||
self.load_buffer = TransferBuffer()
|
||||
self.stop_event = threading.Event()
|
||||
self.write_buffer = TransferBuffer(self.stop_event)
|
||||
self.load_buffer = TransferBuffer(
|
||||
self.stop_event, buffer_count=10, max_buffer_size=100
|
||||
)
|
||||
|
||||
self.write_stream = torch.cuda.Stream()
|
||||
self.load_stream = torch.cuda.Stream()
|
||||
@@ -126,6 +167,28 @@ class HiCacheController:
|
||||
self.write_thread.start()
|
||||
self.load_thread.start()
|
||||
|
||||
def reset(self):
    """Stop the worker threads, drop all queued work, and start fresh workers.

    The stop event is raised so the running write/load threads leave their
    loops, both are joined, every operation queue, transfer buffer and ack
    queue is emptied, and new daemon threads are created and started once
    the stop event has been lowered again.
    """
    # Ask the current workers to exit and wait for both of them.
    self.stop_event.set()
    for worker in (self.write_thread, self.load_thread):
        worker.join()

    # Pending operations first, then staged buffers, then acknowledgements.
    for pending in (self.write_queue, self.load_queue):
        pending.queue.clear()
    for staged in (self.write_buffer, self.load_buffer):
        staged.clear()
    for acks in (self.ack_write_queue, self.ack_load_queue):
        acks.queue.clear()

    # Thread objects cannot be restarted, so build replacements.
    new_writer = threading.Thread(
        target=self.write_thread_func_buffer, daemon=True
    )
    new_loader = threading.Thread(
        target=self.load_thread_func_buffer, daemon=True
    )
    self.write_thread = new_writer
    self.load_thread = new_loader

    # Lower the stop flag before starting so the new loops actually run.
    self.stop_event.clear()
    self.write_thread.start()
    self.load_thread.start()
|
||||
|
||||
def write(
|
||||
self,
|
||||
device_indices: torch.Tensor,
|
||||
@@ -138,10 +201,10 @@ class HiCacheController:
|
||||
host_indices = self.mem_pool_host.alloc(len(device_indices))
|
||||
if host_indices is None:
|
||||
return None
|
||||
self.mem_pool_host.protect_write(host_indices)
|
||||
self.write_queue.put(
|
||||
CacheOperation(host_indices, device_indices, node_id, priority)
|
||||
)
|
||||
self.mem_pool_host.protect_write(host_indices)
|
||||
return host_indices
|
||||
|
||||
def load(
|
||||
@@ -156,10 +219,10 @@ class HiCacheController:
|
||||
device_indices = self.mem_pool_device.alloc(len(host_indices))
|
||||
if device_indices is None:
|
||||
return None
|
||||
self.mem_pool_host.protect_load(host_indices)
|
||||
self.load_queue.put(
|
||||
CacheOperation(host_indices, device_indices, node_id, priority)
|
||||
)
|
||||
self.mem_pool_host.protect_load(host_indices)
|
||||
return device_indices
|
||||
|
||||
def write_thread_func_direct(self):
|
||||
@@ -167,16 +230,19 @@ class HiCacheController:
|
||||
Directly write through KV caches to host memory without buffering.
|
||||
"""
|
||||
with torch.cuda.stream(self.write_stream):
|
||||
while True:
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.write_queue.get(block=True)
|
||||
operation = self.write_queue.get(block=True, timeout=1)
|
||||
operation.data = self.mem_pool_device.get_flat_data(
|
||||
operation.device_indices
|
||||
)
|
||||
self.mem_pool_host.transfer(operation.host_indices, operation.data)
|
||||
self.mem_pool_host.complete_io(operation.host_indices)
|
||||
for node_id in operation.node_ids:
|
||||
self.ack_write_queue.put(node_id)
|
||||
if node_id != 0:
|
||||
self.ack_write_queue.put(node_id)
|
||||
except Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
@@ -185,9 +251,10 @@ class HiCacheController:
|
||||
Directly load KV caches from host memory to device memory without buffering.
|
||||
"""
|
||||
with torch.cuda.stream(self.load_stream):
|
||||
while True:
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.load_queue.get(block=True)
|
||||
operation = self.load_queue.get(block=True, timeout=1)
|
||||
# time.sleep(18e-6 * len(operation.host_indices))
|
||||
operation.data = self.mem_pool_host.get_flat_data(
|
||||
operation.host_indices
|
||||
)
|
||||
@@ -196,7 +263,10 @@ class HiCacheController:
|
||||
)
|
||||
self.mem_pool_host.complete_io(operation.host_indices)
|
||||
for node_id in operation.node_ids:
|
||||
self.ack_load_queue.put(node_id)
|
||||
if node_id != 0:
|
||||
self.ack_load_queue.put(node_id)
|
||||
except Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
@@ -204,39 +274,98 @@ class HiCacheController:
|
||||
"""
|
||||
Auxiliary function to prepare the buffer for write operations.
|
||||
"""
|
||||
|
||||
def _to_op(op_):
|
||||
assert op_.device_indices.is_cuda, "Device indices should be on GPU"
|
||||
op_.data = self.mem_pool_device.get_flat_data(op_.device_indices).to(
|
||||
self.mem_pool_host.device
|
||||
)
|
||||
self.write_buffer.put(op_)
|
||||
return op_
|
||||
|
||||
buffer = None
|
||||
while True:
|
||||
try:
|
||||
operation = self.write_queue.get(block=True)
|
||||
if buffer is None:
|
||||
buffer = operation
|
||||
else:
|
||||
buffer.merge(operation)
|
||||
if (
|
||||
no_wait
|
||||
or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
|
||||
or self.write_queue.empty()
|
||||
or self.write_buffer.empty()
|
||||
):
|
||||
assert (
|
||||
buffer.device_indices.is_cuda
|
||||
), "Device indices should be on GPU"
|
||||
buffer.data = self.mem_pool_device.get_flat_data(
|
||||
buffer.device_indices
|
||||
).contiguous()
|
||||
self.write_buffer.put(buffer, block=True)
|
||||
buffer = None
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
with torch.cuda.stream(self.write_stream):
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.write_queue.get(block=True, timeout=1)
|
||||
factor = (
|
||||
len(operation.device_indices)
|
||||
// self.write_buffer.max_buffer_size
|
||||
)
|
||||
|
||||
if factor >= 1:
|
||||
if buffer is not None:
|
||||
_to_op(buffer)
|
||||
buffer = None
|
||||
|
||||
if factor < 2:
|
||||
_to_op(operation)
|
||||
else:
|
||||
split_ops = operation.split(factor)
|
||||
for op_ in split_ops:
|
||||
_to_op(op_)
|
||||
continue
|
||||
|
||||
if buffer is None:
|
||||
buffer = operation
|
||||
else:
|
||||
buffer.merge(operation)
|
||||
if (
|
||||
no_wait
|
||||
or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
|
||||
or self.write_queue.empty()
|
||||
or self.write_buffer.empty()
|
||||
):
|
||||
_to_op(buffer)
|
||||
buffer = None
|
||||
except Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def load_aux_func(self):
|
||||
"""
|
||||
Auxiliary function to prepare the buffer for load operations.
|
||||
"""
|
||||
|
||||
def _pin_op(op_, put=True):
|
||||
op_.data = (
|
||||
self.mem_pool_host.get_flat_data(op_.host_indices)
|
||||
.contiguous()
|
||||
.pin_memory()
|
||||
)
|
||||
if put:
|
||||
self.load_buffer.put(op_)
|
||||
return op_
|
||||
|
||||
buffer = None
|
||||
while True:
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
operation = self.load_queue.get(block=True)
|
||||
operation = self.load_queue.get(block=True, timeout=1)
|
||||
factor = len(operation.host_indices) // self.load_buffer.max_buffer_size
|
||||
|
||||
if factor >= 1:
|
||||
if buffer is not None:
|
||||
_pin_op(buffer)
|
||||
buffer = None
|
||||
|
||||
if factor < 2:
|
||||
_pin_op(operation)
|
||||
else:
|
||||
split_ops = operation.split(factor)
|
||||
split_args = [(op_, True) for op_ in split_ops[:-1]]
|
||||
split_args.append((split_ops[-1], False))
|
||||
# Spawn threads to pin each op concurrently
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
pinned_ops = list(
|
||||
executor.map(
|
||||
lambda x: _pin_op(x[0], put=x[1]), split_args
|
||||
)
|
||||
)
|
||||
# preserve the order of last op to ensure correct ack
|
||||
self.load_buffer.put(pinned_ops[-1])
|
||||
continue
|
||||
|
||||
if buffer is None:
|
||||
buffer = operation
|
||||
else:
|
||||
@@ -246,41 +375,43 @@ class HiCacheController:
|
||||
or self.load_queue.empty()
|
||||
or self.load_buffer.empty()
|
||||
):
|
||||
buffer.data = (
|
||||
self.mem_pool_host.get_flat_data(buffer.host_indices)
|
||||
.contiguous()
|
||||
.pin_memory()
|
||||
)
|
||||
self.load_buffer.put(buffer, block=True)
|
||||
_pin_op(buffer)
|
||||
buffer = None
|
||||
except Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def write_thread_func_buffer(self):
|
||||
aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
|
||||
aux_thread.start()
|
||||
with torch.cuda.stream(self.write_stream):
|
||||
while True:
|
||||
operation = self.write_buffer.get()
|
||||
if operation is None:
|
||||
continue
|
||||
self.mem_pool_host.transfer(operation.host_indices, operation.data)
|
||||
self.mem_pool_host.complete_io(operation.host_indices)
|
||||
for node_id in operation.node_ids:
|
||||
|
||||
while not self.stop_event.is_set():
|
||||
operation = self.write_buffer.get()
|
||||
if operation is None:
|
||||
continue
|
||||
self.mem_pool_host.assign_flat_data(operation.host_indices, operation.data)
|
||||
self.mem_pool_host.complete_io(operation.host_indices)
|
||||
for node_id in operation.node_ids:
|
||||
if node_id != 0:
|
||||
self.ack_write_queue.put(node_id)
|
||||
aux_thread.join()
|
||||
|
||||
def load_thread_func_buffer(self):
|
||||
aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
|
||||
aux_thread.start()
|
||||
|
||||
with torch.cuda.stream(self.load_stream):
|
||||
while True:
|
||||
while not self.stop_event.is_set():
|
||||
operation = self.load_buffer.get()
|
||||
if operation is None:
|
||||
continue
|
||||
self.mem_pool_device.transfer(operation.device_indices, operation.data)
|
||||
self.mem_pool_host.complete_io(operation.host_indices)
|
||||
for node_id in operation.node_ids:
|
||||
self.ack_load_queue.put(node_id)
|
||||
if node_id != 0:
|
||||
self.ack_load_queue.put(node_id)
|
||||
aux_thread.join()
|
||||
|
||||
def evict_device(
|
||||
self, device_indices: torch.Tensor, host_indices: torch.Tensor
|
||||
|
||||
@@ -82,6 +82,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||
from sglang.srt.managers.utils import validate_input_length
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
@@ -300,16 +301,24 @@ class Scheduler:
|
||||
token_to_kv_pool=self.token_to_kv_pool,
|
||||
)
|
||||
else:
|
||||
self.tree_cache = RadixCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool=self.token_to_kv_pool,
|
||||
disable=server_args.disable_radix_cache,
|
||||
self.tree_cache = (
|
||||
HiRadixCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool=self.token_to_kv_pool,
|
||||
)
|
||||
if self.enable_hierarchical_cache
|
||||
else RadixCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool=self.token_to_kv_pool,
|
||||
disable=server_args.disable_radix_cache,
|
||||
)
|
||||
)
|
||||
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
||||
self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
|
||||
|
||||
# Init running status
|
||||
self.waiting_queue: List[Req] = []
|
||||
self.staging_reqs = {}
|
||||
# The running decoding batch for continuous batching
|
||||
self.running_batch: Optional[ScheduleBatch] = None
|
||||
# The current forward batch
|
||||
@@ -953,6 +962,30 @@ class Scheduler:
|
||||
break
|
||||
|
||||
req.init_next_round_input(None if prefix_computed else self.tree_cache)
|
||||
|
||||
if self.enable_hierarchical_cache and req.last_node is not None:
|
||||
if req.last_node.evicted:
|
||||
# loading KV cache for the request
|
||||
req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
|
||||
req.last_node,
|
||||
req.prefix_indices,
|
||||
adder.rem_total_tokens,
|
||||
)
|
||||
if req.last_node.loading:
|
||||
# to prevent frequent cache invalidation
|
||||
if req.rid in self.staging_reqs:
|
||||
self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
|
||||
self.tree_cache.inc_lock_ref(req.last_node)
|
||||
self.staging_reqs[req.rid] = req.last_node
|
||||
continue
|
||||
elif req.last_node.loading:
|
||||
if not self.tree_cache.loading_complete(req.last_node):
|
||||
continue
|
||||
|
||||
if req.rid in self.staging_reqs:
|
||||
self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
|
||||
del self.staging_reqs[req.rid]
|
||||
|
||||
res = adder.add_one_req(req)
|
||||
if res != AddReqResult.CONTINUE:
|
||||
if res == AddReqResult.NO_TOKEN:
|
||||
|
||||
394
python/sglang/srt/mem_cache/hiradix_cache.py
Normal file
394
python/sglang/srt/mem_cache/hiradix_cache.py
Normal file
@@ -0,0 +1,394 @@
|
||||
import heapq
|
||||
import logging
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.cache_controller import HiCacheController
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
BaseTokenToKVPool,
|
||||
MLATokenToKVPoolHost,
|
||||
ReqToTokenPool,
|
||||
)
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HiRadixCache(RadixCache):
    """Radix-tree prefix cache with a host-memory tier behind the device KV pool.

    Each tree node keeps its device indices in ``value`` and, once backed up,
    its host indices in ``host_value``. Copies between the two tiers are
    queued on a ``HiCacheController`` and acknowledged asynchronously through
    the controller's ack queues, which ``writing_check`` / ``loading_check``
    drain.
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool: BaseTokenToKVPool,
    ):
        """Create the host pool and the transfer controller, then build the tree.

        Args:
            req_to_token_pool: Forwarded unchanged to ``RadixCache.__init__``.
            token_to_kv_pool: Device KV pool; wrapped by the host pool and
                handed to the cache controller.
        """
        self.token_to_kv_pool_host = MLATokenToKVPoolHost(token_to_kv_pool)
        self.cache_controller = HiCacheController(
            token_to_kv_pool, self.token_to_kv_pool_host
        )

        # record the nodes with ongoing write through
        self.ongoing_write_through = {}
        # record the node segments with ongoing load back
        self.ongoing_load_back = {}
        # todo: dynamically adjust the threshold
        self.write_through_threshold = 1
        self.load_back_threshold = 10
        # NOTE(review): the base initializer runs last, presumably because it
        # calls reset(), which needs the controller created above -- confirm.
        super().__init__(req_to_token_pool, token_to_kv_pool, disable=False)

    def reset(self):
        """Restart the transfer controller, clear the host pool, and reset the tree."""
        TreeNode.counter = 0
        self.cache_controller.reset()
        self.token_to_kv_pool_host.clear()
        super().reset()

    def get_height(self, node: TreeNode):
        """Return the number of parent links between ``node`` and the root."""
        height = 0
        while node != self.root_node:
            node = node.parent
            height += 1
        return height

    def write_backup(self, node: TreeNode):
        """Queue a device-to-host copy of ``node.value``.

        If the controller cannot allocate host space, ``evict_host`` is asked
        to free ``len(node.value)`` host tokens and the write is retried once.
        On success the node records its host indices, is tracked in
        ``ongoing_write_through`` and has its lock reference incremented
        (released again in ``writing_check``).

        Returns:
            The number of host indices allocated, or ``None`` if the write
            could not be queued even after the retry.
        """
        host_indices = self.cache_controller.write(
            device_indices=node.value,
            priority=-self.get_height(node),
            node_id=node.id,
        )
        if host_indices is None:
            self.evict_host(len(node.value))
            host_indices = self.cache_controller.write(
                device_indices=node.value,
                priority=-self.get_height(node),
                node_id=node.id,
            )
        if host_indices is None:
            return None

        node.host_value = host_indices
        self.ongoing_write_through[node.id] = node
        self.inc_lock_ref(node)
        return len(host_indices)

    def inc_hit_count(self, node: TreeNode):
        """Count a hit on ``node`` and back it up once it is hit often enough.

        Only active under the ``"write_through_selective"`` policy. A node
        without a host copy is written to host when its hit count exceeds
        ``write_through_threshold``; the counter is then reset.
        """
        if self.cache_controller.write_policy != "write_through_selective":
            return
        node.hit_count += 1
        if node.host_value is None and node.hit_count > self.write_through_threshold:
            self.write_backup(node)
            node.hit_count = 0

    def writing_check(self):
        """Drain write acknowledgements and unlock the nodes whose backup finished."""
        while not self.cache_controller.ack_write_queue.empty():
            try:
                ack_id = self.cache_controller.ack_write_queue.get_nowait()
                self.dec_lock_ref(self.ongoing_write_through[ack_id])
                # clear the reference
                del self.ongoing_write_through[ack_id]
            except Exception:
                # Any failure (including an unknown ack id) stops draining.
                break

    def loading_check(self):
        """Drain load acknowledgements and mark the loaded node chains as ready.

        For every acknowledged id the lock taken on the last node is
        released, and ``loading`` is cleared on each node from the last node
        up to (but not including) the recorded start node.
        """
        while not self.cache_controller.ack_load_queue.empty():
            try:
                ack_id = self.cache_controller.ack_load_queue.get_nowait()
                start_node, end_node = self.ongoing_load_back[ack_id]
                self.dec_lock_ref(end_node)
                while end_node != start_node:
                    assert end_node.loading
                    end_node.loading = False
                    end_node = end_node.parent
                # clear the reference
                del self.ongoing_load_back[ack_id]
            except Exception:
                # Any failure (including an unknown ack id) stops draining.
                break

    def evictable_size(self):
        """Return ``evictable_size_`` after processing pending write/load acks."""
        self.writing_check()
        self.loading_check()
        return self.evictable_size_

    def evict(self, num_tokens: int, evict_callback=None):
        """Evict device-resident leaves until about ``num_tokens`` tokens are handled.

        Locked nodes are skipped. Nodes that already have a host copy are
        released from the device; nodes without one are handled according to
        the controller's write policy. Under ``"write_back"`` the method
        blocks until every queued backup has been acknowledged.

        Args:
            num_tokens: Target number of tokens to evict.
            evict_callback: Accepted for interface compatibility; unused here.

        Raises:
            NotImplementedError: A node without a host copy is met under a
                policy other than ``"write_back"`` / ``"write_through_selective"``.
        """
        leaves = self._collect_leaves_device()
        heapq.heapify(leaves)

        num_evicted = 0
        # NOTE(review): nothing in this method ever appends to pending_nodes,
        # so the membership test below never matches -- confirm intent.
        pending_nodes = []
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)

            if x.lock_ref > 0:
                continue

            if x.host_value is None:
                if self.cache_controller.write_policy == "write_back":
                    num_written = self.write_backup(x)
                    if num_written is None:
                        # write_backup() reports failure with None; adding
                        # that to the counter used to raise TypeError. The
                        # node stays on device, so move on to the next leaf.
                        continue
                    num_evicted += num_written
                elif self.cache_controller.write_policy == "write_through_selective":
                    num_evicted += self._evict_write_through_selective(x)
                else:
                    assert (
                        self.cache_controller.write_policy != "write_through"
                    ), "write_through should be inclusive"
                    raise NotImplementedError
            else:
                num_evicted += self._evict_write_through(x)

            for child in x.parent.children.values():
                if child in pending_nodes:
                    continue
                if not child.evicted:
                    break
            else:
                # all children are evicted or no children
                heapq.heappush(leaves, x.parent)

        if self.cache_controller.write_policy == "write_back":
            # blocking till all write back complete
            while len(self.ongoing_write_through) > 0:
                self.writing_check()
                time.sleep(0.1)

    def _evict_write_through(self, node: TreeNode):
        """Release the device copy of a node that already has a host copy.

        Returns:
            The number of tokens the controller reports as evicted.
        """
        # evict a node already written to host
        num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
        assert num_evicted > 0
        self.evictable_size_ -= num_evicted
        node.value = None
        return num_evicted

    def _evict_write_through_selective(self, node: TreeNode):
        """Free the device memory of a node with no host copy and delete the leaf.

        Returns:
            ``len(node.value)``, the number of device tokens freed.
        """
        # evict a node not initiated write to host
        self.cache_controller.mem_pool_device.free(node.value)
        num_evicted = len(node.value)
        self._delete_leaf(node)
        return num_evicted

    def evict_host(self, num_tokens: int):
        """Drop host copies of device-evicted leaves until ``num_tokens`` are freed.

        Only nodes whose device copy is already gone are considered; each is
        removed from its parent, and a parent that becomes childless and is
        itself evicted becomes a new candidate.

        Args:
            num_tokens: Number of host tokens to free.
        """
        leaves = self._collect_leaves()
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)
            if x == self.root_node:
                break
            # only evict the host value of evicted nodes
            if not x.evicted:
                continue
            assert x.lock_ref == 0 and x.host_value is not None

            # The free must not live inside an assert (asserts are removed
            # under -O), and its result has to be counted: previously
            # num_evicted stayed 0, so every eligible leaf was evicted no
            # matter how few tokens were requested.
            freed = self.cache_controller.evict_host(x.host_value)
            assert freed > 0
            num_evicted += freed

            # Unlink x from its parent; delete only on an actual match.
            for k, v in x.parent.children.items():
                if v == x:
                    del x.parent.children[k]
                    break

            if len(x.parent.children) == 0 and x.parent.evicted:
                heapq.heappush(leaves, x.parent)

    def load_back(
        self, node: TreeNode, mem_quota: Optional[int] = None
    ) -> Optional[torch.Tensor]:
        """Queue a host-to-device load for ``node`` and its evicted ancestors.

        The chain of evicted nodes ending at ``node`` is loaded as a whole or
        not at all. Loading is skipped when the chain is shorter than
        ``load_back_threshold`` or, if ``mem_quota`` is given, longer than
        ``mem_quota`` plus the value returned by locking the nearest
        non-evicted ancestor. If the device pool is full, ``evict`` is called
        and the load retried once.

        Args:
            node: The deepest evicted node to bring back.
            mem_quota: Optional upper bound on tokens that may be loaded.

        Returns:
            The device indices allocated for the whole chain, or ``None`` if
            nothing was loaded.
        """
        # todo: more loading policies

        last_hit_node = node
        nodes_to_load = []
        while node.evicted:
            assert (
                node.backuped
            ), "No backup available on evicted nodes, should not happen"
            nodes_to_load.insert(0, node)
            node = node.parent
        # first non-evicted node above the chain
        ancestor_node = node

        # protect the ancestor nodes from eviction
        delta = self.inc_lock_ref(ancestor_node)

        # load it all or not at all
        host_indices = torch.cat([n.host_value for n in nodes_to_load])
        too_small = len(host_indices) < self.load_back_threshold
        over_quota = mem_quota is not None and len(host_indices) > mem_quota + delta
        if too_small or over_quota:
            # skip loading back if the total size is too small or exceeding the memory quota
            self.dec_lock_ref(ancestor_node)
            return None

        device_indices = self.cache_controller.load(
            host_indices=host_indices, node_id=last_hit_node.id
        )
        if device_indices is None:
            self.evict(len(host_indices))
            device_indices = self.cache_controller.load(
                host_indices=host_indices, node_id=last_hit_node.id
            )
        self.dec_lock_ref(ancestor_node)
        if device_indices is None:
            # no sufficient GPU memory to load back KV caches
            return None

        self.ongoing_load_back[last_hit_node.id] = (ancestor_node, last_hit_node)
        offset = 0
        for loaded in nodes_to_load:
            loaded.value = device_indices[offset : offset + len(loaded.host_value)]
            offset += len(loaded.host_value)
            loaded.loading = True
        self.evictable_size_ += len(device_indices)
        # held until loading_check() sees the acknowledgement
        self.inc_lock_ref(last_hit_node)

        return device_indices

    def loading_complete(self, node: TreeNode):
        """Process pending load acks and report whether ``node`` has finished loading."""
        self.loading_check()
        return not node.loading

    def init_load_back(
        self,
        last_node: TreeNode,
        prefix_indices: torch.Tensor,
        mem_quota: Optional[int] = None,
    ):
        """Try to load an evicted matched node back and extend the prefix.

        Args:
            last_node: Last node matched for a request; may be evicted.
            prefix_indices: Device indices matched so far.
            mem_quota: Optional limit forwarded to ``load_back``.

        Returns:
            ``(last_node, prefix_indices)`` where ``last_node`` is the deepest
            non-evicted node on the path and ``prefix_indices`` has the newly
            loaded indices appended when a load was queued.
        """
        assert (
            len(prefix_indices) == 0 or prefix_indices.is_cuda
        ), "indices of device kV caches should be on GPU"
        if last_node.evicted:
            loading_values = self.load_back(last_node, mem_quota)
            if loading_values is not None:
                prefix_indices = (
                    loading_values
                    if len(prefix_indices) == 0
                    else torch.cat([prefix_indices, loading_values])
                )
                logger.debug(
                    f"loading back {len(loading_values)} tokens for node {last_node.id}"
                )

            # If the load was skipped, fall back to the nearest resident ancestor.
            while last_node.evicted:
                last_node = last_node.parent

        return last_node, prefix_indices

    def _match_prefix_helper(
        self, node: TreeNode, key: List, value, last_node: List[TreeNode]
    ):
        """Walk the tree along ``key``, collecting device values of matched nodes.

        Args:
            node: Node to continue matching from.
            key: Remaining key to match.
            value: Output list; device values of non-evicted matches are appended.
            last_node: One-element list used as an out-parameter; slot 0 is
                set to the last matched node, even if that node is evicted.
        """
        node.last_access_time = time.time()
        if len(key) == 0:
            return

        if key[0] in node.children:
            child = node.children[key[0]]
            prefix_len = _key_match(child.key, key)
            if prefix_len < len(child.key):
                # partial match: split and stop at the new upper node
                new_node = self._split_node(child.key, child, prefix_len)
                self.inc_hit_count(new_node)
                if not new_node.evicted:
                    value.append(new_node.value)
                last_node[0] = new_node
            else:
                self.inc_hit_count(child)
                if not child.evicted:
                    value.append(child.value)
                last_node[0] = child
                self._match_prefix_helper(child, key[prefix_len:], value, last_node)

    def _split_node(self, key, child: TreeNode, split_len: int):
        """Split ``child`` at ``split_len`` and return the new upper node.

        The new node takes the first ``split_len`` key elements (and the
        matching part of the device/host values when present); ``child``
        keeps the remainder and becomes the new node's only child.
        """
        # child node split into new_node -> child
        new_node = TreeNode()
        new_node.children = {key[split_len]: child}
        new_node.parent = child.parent
        new_node.lock_ref = child.lock_ref
        new_node.key = child.key[:split_len]
        new_node.loading = child.loading

        # split value and host value if exists
        if child.evicted:
            new_node.value = None
        else:
            new_node.value = child.value[:split_len]
            child.value = child.value[split_len:]
        if child.host_value is not None:
            new_node.host_value = child.host_value[:split_len]
            child.host_value = child.host_value[split_len:]
        child.parent = new_node
        child.key = child.key[split_len:]
        new_node.parent.children[key[0]] = new_node
        return new_node

    def _insert_helper(self, node: TreeNode, key: List, value):
        """Insert ``key``/``value`` below ``node``.

        Matched nodes that are currently evicted take the corresponding slice
        of ``value`` as their new device value (and do not count towards the
        returned prefix length). A new leaf is created for any unmatched
        remainder and is backed up immediately under ``"write_through"``.

        Returns:
            The summed length of matched, non-evicted prefixes.
        """
        node.last_access_time = time.time()
        if len(key) == 0:
            return 0

        if key[0] in node.children:
            child = node.children[key[0]]
            prefix_len = _key_match(child.key, key)

            if prefix_len == len(child.key):
                if child.evicted:
                    # change the reference if the node is evicted
                    # this often happens in the case of KV cache recomputation
                    child.value = value[:prefix_len]
                    self.token_to_kv_pool_host.update_synced(child.host_value)
                    self.evictable_size_ += len(value[:prefix_len])
                    return self._insert_helper(
                        child, key[prefix_len:], value[prefix_len:]
                    )
                else:
                    self.inc_hit_count(child)
                    return prefix_len + self._insert_helper(
                        child, key[prefix_len:], value[prefix_len:]
                    )

            # partial match, split the node
            new_node = self._split_node(child.key, child, prefix_len)
            if new_node.evicted:
                new_node.value = value[:prefix_len]
                self.token_to_kv_pool_host.update_synced(new_node.host_value)
                self.evictable_size_ += len(new_node.value)
                return self._insert_helper(
                    new_node, key[prefix_len:], value[prefix_len:]
                )
            else:
                self.inc_hit_count(new_node)
                return prefix_len + self._insert_helper(
                    new_node, key[prefix_len:], value[prefix_len:]
                )

        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = value
            node.children[key[0]] = new_node
            self.evictable_size_ += len(value)

            if self.cache_controller.write_policy == "write_through":
                self.write_backup(new_node)
        return 0

    def _collect_leaves_device(self):
        """Return the non-evicted nodes whose children are all evicted (or absent).

        The root node and evicted nodes are never returned; evicted subtrees
        are not descended into.
        """

        def is_leaf(node):
            if node.evicted:
                return False
            if node == self.root_node:
                return False
            if len(node.children) == 0:
                return True
            for child in node.children.values():
                if not child.evicted:
                    return False
            return True

        ret_list = []
        stack = [self.root_node]
        while stack:
            cur_node = stack.pop()
            if is_leaf(cur_node):
                ret_list.append(cur_node)
            else:
                for cur_child in cur_node.children.values():
                    if not cur_child.evicted:
                        stack.append(cur_child)
        return ret_list
|
||||
@@ -442,7 +442,7 @@ class MLATokenToKVPoolHost:
|
||||
def __init__(
|
||||
self,
|
||||
device_pool: MHATokenToKVPool,
|
||||
host_to_device_ratio: float = 2.0,
|
||||
host_to_device_ratio: float = 4.0,
|
||||
pin_memory: bool = False, # no need to use pin memory with the double buffering
|
||||
device: str = "cpu",
|
||||
):
|
||||
@@ -502,6 +502,9 @@ class MLATokenToKVPoolHost:
|
||||
def get_flat_data(self, indices):
    """Return the entries of ``kv_buffer`` selected by ``indices`` on its third axis."""
    selected = self.kv_buffer[:, :, indices]
    return selected
|
||||
|
||||
def assign_flat_data(self, indices, flat_data):
    """Store ``flat_data`` into ``kv_buffer`` at ``indices`` along its third axis."""
    # Same selection as kv_buffer[:, :, indices], spelled as an explicit index tuple.
    target = (slice(None), slice(None), indices)
    self.kv_buffer[target] = flat_data
|
||||
|
||||
@debug_timing
|
||||
def transfer(self, indices, flat_data):
|
||||
# backup prepared data from device to host
|
||||
|
||||
@@ -1289,7 +1289,7 @@ def debug_timing(func):
|
||||
tic.record()
|
||||
result = func(*args, **kwargs)
|
||||
toc.record()
|
||||
torch.cuda.synchronize() # Ensure all CUDA operations are complete
|
||||
toc.synchronize() # Wait for the function to complete without synchronizing all ops on the GPU
|
||||
elapsed = tic.elapsed_time(toc)
|
||||
indices = kwargs.get("indices", args[1] if len(args) > 1 else None)
|
||||
num_tokens = len(indices) if indices is not None else 0
|
||||
|
||||
Reference in New Issue
Block a user