Hierarchical Caching Refactoring and Fixing TP issue (#4082)
@@ -30,6 +30,26 @@ from sglang.srt.mem_cache.memory_pool import (
 logger = logging.getLogger(__name__)


+class LayerDoneCounter:
+    def __init__(self, num_layers):
+        self.counter = num_layers
+        self.condition = threading.Condition()
+
+    def increment(self):
+        with self.condition:
+            self.counter += 1
+            self.condition.notify_all()
+
+    def wait_until(self, threshold):
+        with self.condition:
+            while self.counter <= threshold:
+                self.condition.wait()
+
+    def reset(self):
+        with self.condition:
+            self.counter = 0
+
+
 class CacheOperation:

     counter = 0
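The counter is initialized to num_layers so that wait_until() never blocks while no load is in flight; reset() re-arms it to 0 at the start of a batched load, and each increment() wakes every waiter once one more layer is resident. A minimal usage sketch of this handshake (illustrative only, not part of the commit; assumes the LayerDoneCounter class above):

    import threading
    import time

    counter = LayerDoneCounter(num_layers=4)

    def loader():
        counter.reset()          # a batched load begins; consumers now block
        for layer in range(4):
            time.sleep(0.01)     # stand-in for one host-to-device layer copy
            counter.increment()  # layer is resident; wake all waiters

    threading.Thread(target=loader, daemon=True).start()
    for i in range(4):
        counter.wait_until(i)    # returns once counter > i, i.e. layer i is done
        print(f"attention for layer {i} may run")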
@@ -132,6 +152,7 @@ class HiCacheController:
         self,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         mem_pool_host: MHATokenToKVPoolHost,
+        load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
@@ -139,6 +160,10 @@ class HiCacheController:
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy

+        self.load_cache_event = load_cache_event
+        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
+        self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
+
         if write_policy not in [
             "write_through",
             "write_through_selective",
@@ -165,7 +190,7 @@ class HiCacheController:
             target=self.write_thread_func_buffer, daemon=True
         )
         self.load_thread = threading.Thread(
-            target=self.load_thread_func_buffer, daemon=True
+            target=self.load_thread_func_layer_by_layer, daemon=True
         )
         self.write_thread.start()
         self.load_thread.start()
@@ -186,7 +211,7 @@ class HiCacheController:
             target=self.write_thread_func_buffer, daemon=True
         )
         self.load_thread = threading.Thread(
-            target=self.load_thread_func_buffer, daemon=True
+            target=self.load_thread_func_layer_by_layer, daemon=True
         )
         self.stop_event.clear()
         self.write_thread.start()
@@ -273,6 +298,42 @@ class HiCacheController:
             except Exception as e:
                 logger.error(e)

+    def load_thread_func_layer_by_layer(self):
+        """
+        Load KV caches from host memory to device memory layer by layer.
+        """
+        with torch.cuda.stream(self.load_stream):
+            while not self.stop_event.is_set():
+                self.load_cache_event.wait(timeout=1)
+                if not self.load_cache_event.is_set():
+                    continue
+                self.load_cache_event.clear()
+
+                batch_operation = None
+                while self.load_queue.qsize() > 0:
+                    op = self.load_queue.get(block=True)
+                    if batch_operation is None:
+                        batch_operation = op
+                    else:
+                        batch_operation.merge(op)
+                if batch_operation is None:
+                    continue
+
+                self.layer_done_counter.reset()
+                for i in range(self.mem_pool_host.layer_num):
+                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                        batch_operation.host_indices, i
+                    )
+                    self.mem_pool_device.transfer_per_layer(
+                        batch_operation.device_indices, flat_data, i
+                    )
+                    self.layer_done_counter.increment()
+
+                self.mem_pool_host.complete_io(batch_operation.host_indices)
+                for node_id in batch_operation.node_ids:
+                    if node_id != 0:
+                        self.ack_load_queue.put(node_id)
+
     def write_aux_func(self, no_wait=False):
         """
         Auxiliary function to prepare the buffer for write operations.
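The loader thread wakes on load_cache_event, which the scheduler sets once per scheduling round (see read_to_load_cache below), drains every queued operation, and merges them so each layer crosses the host-device link once per round. The drain-and-merge pattern in isolation (generic sketch; Op is a hypothetical stand-in for CacheOperation):

    import queue

    class Op:  # hypothetical stand-in for CacheOperation
        def __init__(self, indices):
            self.indices = list(indices)

        def merge(self, other):
            self.indices.extend(other.indices)  # fuse into one transfer

    load_queue = queue.Queue()
    for ids in ([1, 2], [7, 8], [9]):
        load_queue.put(Op(ids))

    batch = None
    while load_queue.qsize() > 0:
        op = load_queue.get(block=True)
        if batch is None:
            batch = op
        else:
            batch.merge(op)
    print(batch.indices)  # [1, 2, 7, 8, 9] -> one fused load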
@@ -315,6 +315,7 @@ class Req:
         # The relative logprob_start_len in an extend batch
         self.extend_logprob_start_len = 0
         self.last_node = None
+        self.last_node_global = None

         # Whether or not if it is chunked. It increments whenever
         # it is chunked, and decrement whenever chunked request is
@@ -389,13 +390,24 @@ class Req:
         # Whether request reached finished condition
         return self.finished_reason is not None

-    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
+    def init_next_round_input(
+        self,
+        tree_cache: Optional[BasePrefixCache] = None,
+        enable_hierarchical_cache=False,
+    ):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
             # tree cache is None if the prefix is not computed with tree cache.
-            self.prefix_indices, self.last_node = tree_cache.match_prefix(
-                rid=self.rid, key=self.adjust_max_prefix_ids()
-            )
+            if enable_hierarchical_cache:
+                self.prefix_indices, self.last_node, self.last_node_global = (
+                    tree_cache.match_prefix(
+                        key=self.adjust_max_prefix_ids(), include_evicted=True
+                    )
+                )
+            else:
+                self.prefix_indices, self.last_node = tree_cache.match_prefix(
+                    rid=self.rid, key=self.adjust_max_prefix_ids()
+                )
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

     def adjust_max_prefix_ids(self):
@@ -73,9 +73,15 @@ class CacheAgnosticPolicy(Enum):
 class SchedulePolicy:
     Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]

-    def __init__(self, policy: str, tree_cache: BasePrefixCache):
+    def __init__(
+        self,
+        policy: str,
+        tree_cache: BasePrefixCache,
+        enable_hierarchical_cache: bool = False,
+    ):
         self.policy = self._validate_and_adjust_policy(policy, tree_cache)
         self.tree_cache = tree_cache
+        self.enable_hierarchical_cache = enable_hierarchical_cache

         # It is used to find the matching prefix for in-batch prefix caching.
         self.waiting_queue_radix_tree = RadixCache(
@@ -149,9 +155,14 @@ class SchedulePolicy:
             prefix_ids = r.adjust_max_prefix_ids()

             # NOTE: the prefix_indices must always be aligned with last_node
-            r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
-                rid=r.rid, key=prefix_ids
-            )
+            if self.enable_hierarchical_cache:
+                r.prefix_indices, r.last_node, r.last_node_global = (
+                    self.tree_cache.match_prefix(key=prefix_ids, include_evicted=True)
+                )
+            else:
+                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
+                    rid=r.rid, key=prefix_ids
+                )

         # NOTE(sang): This logic is for in-batch prefix caching;
         # If there are more than 1 request that have small matching prefix from
@@ -428,7 +439,9 @@ class PrefillAdder:

         return self.budget_state()

-    def add_one_req(self, req: Req, has_chunked_req: bool):
+    def add_one_req(
+        self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
+    ):
         if req.sampling_params.ignore_eos and self.tree_cache.disable:
             return self.add_one_req_ignore_eos(req, has_chunked_req)

@@ -448,6 +461,18 @@ class PrefillAdder:
         if total_tokens > self.rem_total_tokens:
             return AddReqResult.NO_TOKEN

+        if (
+            enable_hierarchical_cache
+            and req.last_node_global is not None
+            and req.last_node_global.evicted
+        ):
+            req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
+                req.last_node_global, req.prefix_indices
+            )
+            req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+            input_tokens = req.extend_input_len
+            prefix_len = len(req.prefix_indices)
+
         if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
             # Non-chunked prefill
             self.can_run_list.append(req)
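When the deepest match is host-only (req.last_node_global.evicted), init_load_back may restore only part of the evicted prefix onto the device, so the prefix and extend lengths are recomputed before the admission decision. Illustrative numbers: with len(fill_ids) == 100 and a matched prefix of 80 tokens of which 60 can be loaded back, extend_input_len becomes 100 - 60 = 40, and the chunking check below runs against those 40 tokens rather than the original 20.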
@@ -265,12 +265,10 @@ class Scheduler:
             f"context_len={self.model_config.context_len}"
         )

         # Init memory pool and cache
         self.init_memory_pool_and_cache()

         # Init running status
         self.waiting_queue: List[Req] = []
-        self.staging_reqs = {}
         # The running decoding batch for continuous batching
         self.running_batch: Optional[ScheduleBatch] = None
         # The current forward batch
@@ -308,7 +306,9 @@ class Scheduler:
         self.grammar_backend = None

         # Init schedule policy and new token estimation
-        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
+        self.policy = SchedulePolicy(
+            self.schedule_policy, self.tree_cache, self.enable_hierarchical_cache
+        )
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
@@ -431,6 +431,7 @@ class Scheduler:
             self.tree_cache = HiRadixCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                tp_cache_group=self.tp_worker.get_tp_cpu_group(),
             )
         else:
             self.tree_cache = RadixCache(
@@ -1005,6 +1006,11 @@ class Scheduler:
             self.batch_is_full = True
             return None

+        if self.enable_hierarchical_cache:
+            # check for completion of hierarchical cache activities to release memory
+            self.tree_cache.writing_check()
+            self.tree_cache.loading_check()
+
         # Get priority queue
         prefix_computed = self.policy.calc_priority(self.waiting_queue)

@@ -1048,32 +1054,14 @@ class Scheduler:
                 self.batch_is_full = True
                 break

-            req.init_next_round_input(None if prefix_computed else self.tree_cache)
+            req.init_next_round_input(
+                None if prefix_computed else self.tree_cache,
+                self.enable_hierarchical_cache,
+            )

-            if self.enable_hierarchical_cache and req.last_node is not None:
-                if req.last_node.evicted:
-                    # loading KV cache for the request
-                    req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
-                        req.last_node,
-                        req.prefix_indices,
-                        adder.rem_total_tokens,
-                    )
-                    if req.last_node.loading:
-                        # to prevent frequent cache invalidation
-                        if req.rid in self.staging_reqs:
-                            self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
-                        self.tree_cache.inc_lock_ref(req.last_node)
-                        self.staging_reqs[req.rid] = req.last_node
-                        continue
-                elif req.last_node.loading:
-                    if not self.tree_cache.loading_complete(req.last_node):
-                        continue
-
-                if req.rid in self.staging_reqs:
-                    self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
-                    del self.staging_reqs[req.rid]
-
-            res = adder.add_one_req(req, self.chunked_req)
+            res = adder.add_one_req(
+                req, self.chunked_req, self.enable_hierarchical_cache
+            )
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
                     if self.enable_hierarchical_cache:
@@ -1094,6 +1082,9 @@ class Scheduler:
                 x for x in self.waiting_queue if x not in set(can_run_list)
             ]

+        if self.enable_hierarchical_cache:
+            self.tree_cache.read_to_load_cache()
+
         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
             self.chunked_req = adder.new_chunked_req

@@ -1,11 +1,13 @@
 import heapq
 import logging
+import threading
 import time
 from typing import List, Optional

 import torch

 from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPoolHost,
     ReqToTokenPool,
@@ -22,12 +24,18 @@ class HiRadixCache(RadixCache):
         self,
         req_to_token_pool: ReqToTokenPool,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        tp_cache_group: torch.distributed.ProcessGroup,
     ):
         self.token_to_kv_pool_host = MHATokenToKVPoolHost(
             token_to_kv_pool_allocator.get_kvcache()
         )
+        self.tp_group = tp_cache_group
+
+        self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
-            token_to_kv_pool_allocator, self.token_to_kv_pool_host
+            token_to_kv_pool_allocator,
+            self.token_to_kv_pool_host,
+            load_cache_event=self.load_cache_event,
         )

         # record the nodes with ongoing write through
@@ -55,14 +63,12 @@ class HiRadixCache(RadixCache):
     def write_backup(self, node: TreeNode):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
-            priority=-self.get_height(node),
             node_id=node.id,
         )
         if host_indices is None:
             self.evict_host(len(node.value))
             host_indices = self.cache_controller.write(
                 device_indices=node.value,
-                priority=-self.get_height(node),
                 node_id=node.id,
             )
         if host_indices is not None:
@@ -83,14 +89,20 @@ class HiRadixCache(RadixCache):
             node.hit_count = 0

     def writing_check(self):
-        while not self.cache_controller.ack_write_queue.empty():
-            try:
-                ack_id = self.cache_controller.ack_write_queue.get_nowait()
-                self.dec_lock_ref(self.ongoing_write_through[ack_id])
-                # clear the reference
-                del self.ongoing_write_through[ack_id]
-            except Exception:
-                break
+        queue_size = torch.tensor(
+            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
+        )
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchronize TP workers to make the same update to radix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            ack_id = self.cache_controller.ack_write_queue.get()
+            self.dec_lock_ref(self.ongoing_write_through[ack_id])
+            del self.ongoing_write_through[ack_id]

     def loading_check(self):
         while not self.cache_controller.ack_load_queue.empty():
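This is the TP fix from the commit title: each rank drains its ack queue independently, so at any instant different ranks may have observed different numbers of completed write-backs; reducing the queue sizes with MIN and popping exactly that many entries makes every rank apply the same lock releases, keeping the replicated radix trees in lockstep. A self-contained sketch of the reduction (assumed two-process gloo setup; the port is arbitrary):

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def worker(rank, world_size):
        dist.init_process_group(
            "gloo",
            init_method="tcp://127.0.0.1:29501",
            rank=rank,
            world_size=world_size,
        )
        # rank 0 has seen 3 acks so far, rank 1 has seen 5; both must pop 3
        acks = torch.tensor(3 if rank == 0 else 5, dtype=torch.int)
        dist.all_reduce(acks, op=dist.ReduceOp.MIN)
        print(f"rank {rank} pops {acks.item()} acks")  # both print 3

    if __name__ == "__main__":
        mp.spawn(worker, args=(2,), nprocs=2)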
@@ -108,8 +120,6 @@ class HiRadixCache(RadixCache):
             break

     def evictable_size(self):
-        self.writing_check()
-        self.loading_check()
         return self.evictable_size_

     def evict(self, num_tokens: int, evict_callback=None):
@@ -242,10 +252,6 @@ class HiRadixCache(RadixCache):

         return device_indices

-    def loading_complete(self, node: TreeNode):
-        self.loading_check()
-        return node.loading == False
-
     def init_load_back(
         self,
         last_node: TreeNode,
@@ -272,6 +278,28 @@ class HiRadixCache(RadixCache):

         return last_node, prefix_indices

+    def read_to_load_cache(self):
+        self.load_cache_event.set()
+
+    def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
+        if self.disable:
+            return [], self.root_node
+
+        value, last_node = self._match_prefix_helper(self.root_node, key)
+        if value:
+            value = torch.concat(value)
+        else:
+            value = torch.tensor([], dtype=torch.int32)
+
+        last_node_global = last_node
+        while last_node.evicted:
+            last_node = last_node.parent
+
+        if include_evicted:
+            return value, last_node, last_node_global
+        else:
+            return value, last_node
+
     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.time()
         value = []
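match_prefix now reports two frontiers: last_node_global is the deepest matching node anywhere in the tree (its KV may live only on host), while last_node is walked back up to the deepest node still resident on device. A toy illustration of the walk (hypothetical Node stand-in):

    class Node:
        def __init__(self, name, evicted, parent=None):
            self.name, self.evicted, self.parent = name, evicted, parent

    root = Node("root", False)
    a = Node("a", False, root)  # KV still on device
    b = Node("b", True, a)      # KV evicted to host
    c = Node("c", True, b)      # KV evicted to host

    last_node_global = c        # deepest match overall
    last_node = c
    while last_node.evicted:    # climb to the deepest on-device node
        last_node = last_node.parent
    print(last_node.name, last_node_global.name)  # a c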
@@ -206,6 +206,8 @@ class MHATokenToKVPool(KVCache):
         self.layer_num = layer_num
         self._create_buffers()

+        self.layer_transfer_counter = None
+
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
             f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
@@ -267,12 +269,28 @@ class MHATokenToKVPool(KVCache):
             self.k_buffer[i][indices] = k_data[i]
             self.v_buffer[i][indices] = v_data[i]

+    def register_layer_transfer_counter(self, layer_transfer_counter):
+        self.layer_transfer_counter = layer_transfer_counter
+
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        k_data, v_data = flat_data[0], flat_data[1]
+        self.k_buffer[layer_id][indices] = k_data
+        self.v_buffer[layer_id][indices] = v_data
+
     def get_key_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id].view(self.dtype)
         return self.k_buffer[layer_id]

     def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.v_buffer[layer_id].view(self.dtype)
         return self.v_buffer[layer_id]
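Since wait_until(threshold) sleeps while counter <= threshold, get_key_buffer(layer_id) unblocks exactly when layers 0..layer_id have landed on device; outside of a load the counter rests at layer_num, so these checks never block. An illustrative timeline for layer_num = 3:

    # reset()      -> counter = 0   wait_until(0) blocks
    # increment()  -> counter = 1   wait_until(0) returns; layer 0 computes
    # increment()  -> counter = 2   wait_until(1) returns; layer 1 computes
    # increment()  -> counter = 3   wait_until(2) returns; layer 2 computes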
@@ -530,6 +548,9 @@ class MHATokenToKVPoolHost:
     def get_flat_data(self, indices):
         return self.kv_buffer[:, :, indices]

+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[:, layer_id, indices]
+
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, :, indices] = flat_data
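The indexing above implies a host layout with K/V and layer as the leading axes, which is what makes the per-layer slice cheap to extract. Inferred shape sketch (the axes after the token index are assumptions):

    # kv_buffer: (2, layer_num, size, ...)  where axis 0 is K=0 / V=1
    # kv_buffer[:, :, indices]         -> K and V, all layers, selected tokens
    # kv_buffer[:, layer_id, indices]  -> K and V for a single layer, matching
    #                                     transfer_per_layer's flat_data[0]/[1]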