Hierarchical Caching Refactoring and Fixing TP issue (#4082)

This commit is contained in:
Zhiqiang Xie
2025-03-12 11:22:35 -07:00
committed by GitHub
parent 01090e8ac3
commit 10b544ae9b
6 changed files with 194 additions and 56 deletions

View File

@@ -30,6 +30,26 @@ from sglang.srt.mem_cache.memory_pool import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LayerDoneCounter:
    """Thread-safe counter of KV-cache layers whose host-to-device transfer
    has completed, letting readers block until a given layer is ready."""

    def __init__(self, num_layers):
        # Start at num_layers ("all done") so waiters never block before
        # the first batched transfer has been started via reset().
        self.counter = num_layers
        self.condition = threading.Condition()

    def increment(self):
        """Mark one more layer as transferred and wake every waiter."""
        with self.condition:
            self.counter += 1
            self.condition.notify_all()

    def wait_until(self, threshold):
        """Block until the counter exceeds *threshold* (i.e. layer
        *threshold* has finished transferring)."""
        with self.condition:
            self.condition.wait_for(lambda: self.counter > threshold)

    def reset(self):
        """Restart the count for a new batched, layer-by-layer transfer."""
        with self.condition:
            self.counter = 0
class CacheOperation: class CacheOperation:
counter = 0 counter = 0
@@ -132,6 +152,7 @@ class HiCacheController:
self, self,
token_to_kv_pool_allocator: TokenToKVPoolAllocator, token_to_kv_pool_allocator: TokenToKVPoolAllocator,
mem_pool_host: MHATokenToKVPoolHost, mem_pool_host: MHATokenToKVPoolHost,
load_cache_event: threading.Event = None,
write_policy: str = "write_through_selective", write_policy: str = "write_through_selective",
): ):
self.mem_pool_device_allocator = token_to_kv_pool_allocator self.mem_pool_device_allocator = token_to_kv_pool_allocator
@@ -139,6 +160,10 @@ class HiCacheController:
self.mem_pool_host = mem_pool_host self.mem_pool_host = mem_pool_host
self.write_policy = write_policy self.write_policy = write_policy
self.load_cache_event = load_cache_event
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
if write_policy not in [ if write_policy not in [
"write_through", "write_through",
"write_through_selective", "write_through_selective",
@@ -165,7 +190,7 @@ class HiCacheController:
target=self.write_thread_func_buffer, daemon=True target=self.write_thread_func_buffer, daemon=True
) )
self.load_thread = threading.Thread( self.load_thread = threading.Thread(
target=self.load_thread_func_buffer, daemon=True target=self.load_thread_func_layer_by_layer, daemon=True
) )
self.write_thread.start() self.write_thread.start()
self.load_thread.start() self.load_thread.start()
@@ -186,7 +211,7 @@ class HiCacheController:
target=self.write_thread_func_buffer, daemon=True target=self.write_thread_func_buffer, daemon=True
) )
self.load_thread = threading.Thread( self.load_thread = threading.Thread(
target=self.load_thread_func_buffer, daemon=True target=self.load_thread_func_layer_by_layer, daemon=True
) )
self.stop_event.clear() self.stop_event.clear()
self.write_thread.start() self.write_thread.start()
@@ -273,6 +298,42 @@ class HiCacheController:
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
def load_thread_func_layer_by_layer(self):
    """
    Load KV caches from host memory to device memory layer by layer.

    Runs as a daemon thread on a dedicated CUDA stream. Each wake-up drains
    the load queue into one merged batch operation, then transfers the batch
    one layer at a time, bumping `layer_done_counter` after every layer so
    that per-layer consumers (get_key_buffer/get_value_buffer waiters) can
    start as soon as their layer has landed on device.
    """
    with torch.cuda.stream(self.load_stream):
        while not self.stop_event.is_set():
            # timeout=1 so the loop periodically rechecks stop_event even
            # when no load has been signaled.
            self.load_cache_event.wait(timeout=1)
            if not self.load_cache_event.is_set():
                continue
            self.load_cache_event.clear()

            # Drain the queue and merge all pending ops into one batch.
            batch_operation = None
            while self.load_queue.qsize() > 0:
                op = self.load_queue.get(block=True)
                if batch_operation is None:
                    batch_operation = op
                else:
                    batch_operation.merge(op)
            if batch_operation is None:
                continue

            # Restart the per-layer progress counter for this batch.
            self.layer_done_counter.reset()
            for i in range(self.mem_pool_host.layer_num):
                flat_data = self.mem_pool_host.get_flat_data_by_layer(
                    batch_operation.host_indices, i
                )
                self.mem_pool_device.transfer_per_layer(
                    batch_operation.device_indices, flat_data, i
                )
                self.layer_done_counter.increment()

            # Release the host-side I/O hold, then ack each node.
            self.mem_pool_host.complete_io(batch_operation.host_indices)
            for node_id in batch_operation.node_ids:
                # node_id 0 appears to be a sentinel for "no ack needed" —
                # NOTE(review): confirm against CacheOperation's id scheme.
                if node_id != 0:
                    self.ack_load_queue.put(node_id)
def write_aux_func(self, no_wait=False): def write_aux_func(self, no_wait=False):
""" """
Auxiliary function to prepare the buffer for write operations. Auxiliary function to prepare the buffer for write operations.

View File

@@ -315,6 +315,7 @@ class Req:
# The relative logprob_start_len in an extend batch # The relative logprob_start_len in an extend batch
self.extend_logprob_start_len = 0 self.extend_logprob_start_len = 0
self.last_node = None self.last_node = None
self.last_node_global = None
# Whether or not if it is chunked. It increments whenever # Whether or not if it is chunked. It increments whenever
# it is chunked, and decrement whenever chunked request is # it is chunked, and decrement whenever chunked request is
@@ -389,13 +390,24 @@ class Req:
# Whether request reached finished condition # Whether request reached finished condition
return self.finished_reason is not None return self.finished_reason is not None
def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None): def init_next_round_input(
self,
tree_cache: Optional[BasePrefixCache] = None,
enable_hierarchical_cache=False,
):
self.fill_ids = self.origin_input_ids + self.output_ids self.fill_ids = self.origin_input_ids + self.output_ids
if tree_cache is not None: if tree_cache is not None:
# tree cache is None if the prefix is not computed with tree cache. # tree cache is None if the prefix is not computed with tree cache.
self.prefix_indices, self.last_node = tree_cache.match_prefix( if enable_hierarchical_cache:
rid=self.rid, key=self.adjust_max_prefix_ids() self.prefix_indices, self.last_node, self.last_node_global = (
) tree_cache.match_prefix(
key=self.adjust_max_prefix_ids(), include_evicted=True
)
)
else:
self.prefix_indices, self.last_node = tree_cache.match_prefix(
rid=self.rid, key=self.adjust_max_prefix_ids()
)
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
def adjust_max_prefix_ids(self): def adjust_max_prefix_ids(self):

View File

@@ -73,9 +73,15 @@ class CacheAgnosticPolicy(Enum):
class SchedulePolicy: class SchedulePolicy:
Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy] Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]
def __init__(self, policy: str, tree_cache: BasePrefixCache): def __init__(
self,
policy: str,
tree_cache: BasePrefixCache,
enable_hierarchical_cache: bool = False,
):
self.policy = self._validate_and_adjust_policy(policy, tree_cache) self.policy = self._validate_and_adjust_policy(policy, tree_cache)
self.tree_cache = tree_cache self.tree_cache = tree_cache
self.enable_hierarchical_cache = enable_hierarchical_cache
# It is used to find the matching prefix for in-batch prefix caching. # It is used to find the matching prefix for in-batch prefix caching.
self.waiting_queue_radix_tree = RadixCache( self.waiting_queue_radix_tree = RadixCache(
@@ -149,9 +155,14 @@ class SchedulePolicy:
prefix_ids = r.adjust_max_prefix_ids() prefix_ids = r.adjust_max_prefix_ids()
# NOTE: the prefix_indices must always be aligned with last_node # NOTE: the prefix_indices must always be aligned with last_node
r.prefix_indices, r.last_node = self.tree_cache.match_prefix( if self.enable_hierarchical_cache:
rid=r.rid, key=prefix_ids r.prefix_indices, r.last_node, r.last_node_global = (
) self.tree_cache.match_prefix(key=prefix_ids, include_evicted=True)
)
else:
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
rid=r.rid, key=prefix_ids
)
# NOTE(sang): This logic is for in-batch prefix caching; # NOTE(sang): This logic is for in-batch prefix caching;
# If there are more than 1 request that have small matching prefix from # If there are more than 1 request that have small matching prefix from
@@ -428,7 +439,9 @@ class PrefillAdder:
return self.budget_state() return self.budget_state()
def add_one_req(self, req: Req, has_chunked_req: bool): def add_one_req(
self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
):
if req.sampling_params.ignore_eos and self.tree_cache.disable: if req.sampling_params.ignore_eos and self.tree_cache.disable:
return self.add_one_req_ignore_eos(req, has_chunked_req) return self.add_one_req_ignore_eos(req, has_chunked_req)
@@ -448,6 +461,18 @@ class PrefillAdder:
if total_tokens > self.rem_total_tokens: if total_tokens > self.rem_total_tokens:
return AddReqResult.NO_TOKEN return AddReqResult.NO_TOKEN
if (
enable_hierarchical_cache
and req.last_node_global is not None
and req.last_node_global.evicted
):
req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
req.last_node_global, req.prefix_indices
)
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
input_tokens = req.extend_input_len
prefix_len = len(req.prefix_indices)
if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens: if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
# Non-chunked prefill # Non-chunked prefill
self.can_run_list.append(req) self.can_run_list.append(req)

View File

@@ -265,12 +265,10 @@ class Scheduler:
f"context_len={self.model_config.context_len}" f"context_len={self.model_config.context_len}"
) )
# Init memory pool and cache
self.init_memory_pool_and_cache() self.init_memory_pool_and_cache()
# Init running status # Init running status
self.waiting_queue: List[Req] = [] self.waiting_queue: List[Req] = []
self.staging_reqs = {}
# The running decoding batch for continuous batching # The running decoding batch for continuous batching
self.running_batch: Optional[ScheduleBatch] = None self.running_batch: Optional[ScheduleBatch] = None
# The current forward batch # The current forward batch
@@ -308,7 +306,9 @@ class Scheduler:
self.grammar_backend = None self.grammar_backend = None
# Init schedule policy and new token estimation # Init schedule policy and new token estimation
self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache) self.policy = SchedulePolicy(
self.schedule_policy, self.tree_cache, self.enable_hierarchical_cache
)
assert ( assert (
server_args.schedule_conservativeness >= 0 server_args.schedule_conservativeness >= 0
), "Invalid schedule_conservativeness" ), "Invalid schedule_conservativeness"
@@ -431,6 +431,7 @@ class Scheduler:
self.tree_cache = HiRadixCache( self.tree_cache = HiRadixCache(
req_to_token_pool=self.req_to_token_pool, req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
tp_cache_group=self.tp_worker.get_tp_cpu_group(),
) )
else: else:
self.tree_cache = RadixCache( self.tree_cache = RadixCache(
@@ -1005,6 +1006,11 @@ class Scheduler:
self.batch_is_full = True self.batch_is_full = True
return None return None
if self.enable_hierarchical_cache:
# check for completion of hierarchical cache activities to release memory
self.tree_cache.writing_check()
self.tree_cache.loading_check()
# Get priority queue # Get priority queue
prefix_computed = self.policy.calc_priority(self.waiting_queue) prefix_computed = self.policy.calc_priority(self.waiting_queue)
@@ -1048,32 +1054,14 @@ class Scheduler:
self.batch_is_full = True self.batch_is_full = True
break break
req.init_next_round_input(None if prefix_computed else self.tree_cache) req.init_next_round_input(
None if prefix_computed else self.tree_cache,
self.enable_hierarchical_cache,
)
if self.enable_hierarchical_cache and req.last_node is not None: res = adder.add_one_req(
if req.last_node.evicted: req, self.chunked_req, self.enable_hierarchical_cache
# loading KV cache for the request )
req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
req.last_node,
req.prefix_indices,
adder.rem_total_tokens,
)
if req.last_node.loading:
# to prevent frequent cache invalidation
if req.rid in self.staging_reqs:
self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
self.tree_cache.inc_lock_ref(req.last_node)
self.staging_reqs[req.rid] = req.last_node
continue
elif req.last_node.loading:
if not self.tree_cache.loading_complete(req.last_node):
continue
if req.rid in self.staging_reqs:
self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
del self.staging_reqs[req.rid]
res = adder.add_one_req(req, self.chunked_req)
if res != AddReqResult.CONTINUE: if res != AddReqResult.CONTINUE:
if res == AddReqResult.NO_TOKEN: if res == AddReqResult.NO_TOKEN:
if self.enable_hierarchical_cache: if self.enable_hierarchical_cache:
@@ -1094,6 +1082,9 @@ class Scheduler:
x for x in self.waiting_queue if x not in set(can_run_list) x for x in self.waiting_queue if x not in set(can_run_list)
] ]
if self.enable_hierarchical_cache:
self.tree_cache.read_to_load_cache()
if adder.new_chunked_req is not None: if adder.new_chunked_req is not None:
assert self.chunked_req is None assert self.chunked_req is None
self.chunked_req = adder.new_chunked_req self.chunked_req = adder.new_chunked_req

View File

@@ -1,11 +1,13 @@
import heapq import heapq
import logging import logging
import threading
import time import time
from typing import List, Optional from typing import List, Optional
import torch import torch
from sglang.srt.managers.cache_controller import HiCacheController from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.mem_cache.memory_pool import ( from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPoolHost, MHATokenToKVPoolHost,
ReqToTokenPool, ReqToTokenPool,
@@ -22,12 +24,18 @@ class HiRadixCache(RadixCache):
self, self,
req_to_token_pool: ReqToTokenPool, req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator, token_to_kv_pool_allocator: TokenToKVPoolAllocator,
tp_cache_group: torch.distributed.ProcessGroup,
): ):
self.token_to_kv_pool_host = MHATokenToKVPoolHost( self.token_to_kv_pool_host = MHATokenToKVPoolHost(
token_to_kv_pool_allocator.get_kvcache() token_to_kv_pool_allocator.get_kvcache()
) )
self.tp_group = tp_cache_group
self.load_cache_event = threading.Event()
self.cache_controller = HiCacheController( self.cache_controller = HiCacheController(
token_to_kv_pool_allocator, self.token_to_kv_pool_host token_to_kv_pool_allocator,
self.token_to_kv_pool_host,
load_cache_event=self.load_cache_event,
) )
# record the nodes with ongoing write through # record the nodes with ongoing write through
@@ -55,14 +63,12 @@ class HiRadixCache(RadixCache):
def write_backup(self, node: TreeNode): def write_backup(self, node: TreeNode):
host_indices = self.cache_controller.write( host_indices = self.cache_controller.write(
device_indices=node.value, device_indices=node.value,
priority=-self.get_height(node),
node_id=node.id, node_id=node.id,
) )
if host_indices is None: if host_indices is None:
self.evict_host(len(node.value)) self.evict_host(len(node.value))
host_indices = self.cache_controller.write( host_indices = self.cache_controller.write(
device_indices=node.value, device_indices=node.value,
priority=-self.get_height(node),
node_id=node.id, node_id=node.id,
) )
if host_indices is not None: if host_indices is not None:
@@ -83,14 +89,20 @@ class HiRadixCache(RadixCache):
node.hit_count = 0 node.hit_count = 0
def writing_check(self): def writing_check(self):
while not self.cache_controller.ack_write_queue.empty(): queue_size = torch.tensor(
try: self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
ack_id = self.cache_controller.ack_write_queue.get_nowait() )
self.dec_lock_ref(self.ongoing_write_through[ack_id]) if torch.distributed.get_world_size(group=self.tp_group) > 1:
# clear the reference # synchronize TP workers to make the same update to radix cache
del self.ongoing_write_through[ack_id] torch.distributed.all_reduce(
except Exception: queue_size,
break op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
)
for _ in range(queue_size.item()):
ack_id = self.cache_controller.ack_write_queue.get()
self.dec_lock_ref(self.ongoing_write_through[ack_id])
del self.ongoing_write_through[ack_id]
def loading_check(self): def loading_check(self):
while not self.cache_controller.ack_load_queue.empty(): while not self.cache_controller.ack_load_queue.empty():
@@ -108,8 +120,6 @@ class HiRadixCache(RadixCache):
break break
def evictable_size(self): def evictable_size(self):
self.writing_check()
self.loading_check()
return self.evictable_size_ return self.evictable_size_
def evict(self, num_tokens: int, evict_callback=None): def evict(self, num_tokens: int, evict_callback=None):
@@ -242,10 +252,6 @@ class HiRadixCache(RadixCache):
return device_indices return device_indices
def loading_complete(self, node: TreeNode):
self.loading_check()
return node.loading == False
def init_load_back( def init_load_back(
self, self,
last_node: TreeNode, last_node: TreeNode,
@@ -272,6 +278,28 @@ class HiRadixCache(RadixCache):
return last_node, prefix_indices return last_node, prefix_indices
def read_to_load_cache(self):
    # Wake the cache controller's layer-by-layer loading thread so it
    # starts transferring the load operations queued this scheduling round.
    self.load_cache_event.set()
def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
    """
    Match *key* against the radix tree, traversing through evicted nodes.

    Args:
        key: token ids to match from the root.
        include_evicted: when True, also return the deepest matched node
            regardless of eviction status.

    Returns:
        (value, last_node) by default, or
        (value, last_node, last_node_global) when include_evicted is True,
        where last_node is the deepest matched node still resident on device
        and last_node_global may be an evicted descendant of it.
    """
    if self.disable:
        # BUGFIX: keep the tuple arity consistent with the enabled path;
        # callers that pass include_evicted=True unpack three values and
        # would otherwise crash with a ValueError here.
        if include_evicted:
            return [], self.root_node, self.root_node
        return [], self.root_node

    value, last_node = self._match_prefix_helper(self.root_node, key)
    if value:
        value = torch.concat(value)
    else:
        value = torch.tensor([], dtype=torch.int32)

    last_node_global = last_node
    # Walk back up to the nearest ancestor whose KV is still on device.
    while last_node.evicted:
        last_node = last_node.parent

    if include_evicted:
        return value, last_node, last_node_global
    return value, last_node
def _match_prefix_helper(self, node: TreeNode, key: List): def _match_prefix_helper(self, node: TreeNode, key: List):
node.last_access_time = time.time() node.last_access_time = time.time()
value = [] value = []

View File

@@ -206,6 +206,8 @@ class MHATokenToKVPool(KVCache):
self.layer_num = layer_num self.layer_num = layer_num
self._create_buffers() self._create_buffers()
self.layer_transfer_counter = None
k_size, v_size = self.get_kv_size_bytes() k_size, v_size = self.get_kv_size_bytes()
logger.info( logger.info(
f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB" f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
@@ -267,12 +269,28 @@ class MHATokenToKVPool(KVCache):
self.k_buffer[i][indices] = k_data[i] self.k_buffer[i][indices] = k_data[i]
self.v_buffer[i][indices] = v_data[i] self.v_buffer[i][indices] = v_data[i]
def register_layer_transfer_counter(self, layer_transfer_counter):
    # Counter consulted by get_key_buffer/get_value_buffer to block until a
    # layer's host-to-device transfer has completed before handing it out.
    self.layer_transfer_counter = layer_transfer_counter
def transfer_per_layer(self, indices, flat_data, layer_id):
    """Copy one layer of prepared host KV data into the device buffers.

    *flat_data* holds the K entries at index 0 and the V entries at index 1
    for the tokens selected by *indices*.
    """
    # Blocking copy: the data must be resident on device before the
    # buffer writes below consume it.
    device_data = flat_data.to(device=self.device, non_blocking=False)
    self.k_buffer[layer_id][indices] = device_data[0]
    self.v_buffer[layer_id][indices] = device_data[1]
def get_key_buffer(self, layer_id: int): def get_key_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id)
if self.store_dtype != self.dtype: if self.store_dtype != self.dtype:
return self.k_buffer[layer_id].view(self.dtype) return self.k_buffer[layer_id].view(self.dtype)
return self.k_buffer[layer_id] return self.k_buffer[layer_id]
def get_value_buffer(self, layer_id: int): def get_value_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id)
if self.store_dtype != self.dtype: if self.store_dtype != self.dtype:
return self.v_buffer[layer_id].view(self.dtype) return self.v_buffer[layer_id].view(self.dtype)
return self.v_buffer[layer_id] return self.v_buffer[layer_id]
@@ -530,6 +548,9 @@ class MHATokenToKVPoolHost:
def get_flat_data(self, indices): def get_flat_data(self, indices):
return self.kv_buffer[:, :, indices] return self.kv_buffer[:, :, indices]
def get_flat_data_by_layer(self, indices, layer_id):
    # Slice one layer's entries for the given token indices; mirrors
    # get_flat_data, which takes all layers via kv_buffer[:, :, indices].
    # Dim 0 presumably distinguishes K vs V — TODO confirm against the
    # host pool's buffer layout.
    return self.kv_buffer[:, layer_id, indices]
def assign_flat_data(self, indices, flat_data): def assign_flat_data(self, indices, flat_data):
self.kv_buffer[:, :, indices] = flat_data self.kv_buffer[:, :, indices] = flat_data