upstream hicache fixes (#5570)

This commit is contained in:
Zhiqiang Xie
2025-04-20 23:08:30 -07:00
committed by GitHub
parent 188f0955fa
commit 70645f4d7d
8 changed files with 89 additions and 46 deletions

View File

@@ -29,15 +29,17 @@ class HiRadixCache(RadixCache):
tp_cache_group: torch.distributed.ProcessGroup,
page_size: int,
hicache_ratio: float,
hicache_size: int,
hicache_write_policy: str,
):
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
if isinstance(self.kv_cache, MHATokenToKVPool):
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
self.kv_cache, hicache_ratio, page_size
self.kv_cache, hicache_ratio, hicache_size, page_size
)
elif isinstance(self.kv_cache, MLATokenToKVPool):
self.token_to_kv_pool_host = MLATokenToKVPoolHost(
self.kv_cache, hicache_ratio, page_size
self.kv_cache, hicache_ratio, hicache_size, page_size
)
else:
raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
@@ -50,6 +52,7 @@ class HiRadixCache(RadixCache):
self.token_to_kv_pool_host,
page_size,
load_cache_event=self.load_cache_event,
write_policy=hicache_write_policy,
)
# record the nodes with ongoing write through
@@ -57,7 +60,9 @@ class HiRadixCache(RadixCache):
# record the node segments with ongoing load back
self.ongoing_load_back = {}
# todo: dynamically adjust the threshold
self.write_through_threshold = 1
self.write_through_threshold = (
1 if hicache_write_policy == "write_through" else 3
)
self.load_back_threshold = 10
super().__init__(
req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -76,7 +81,7 @@ class HiRadixCache(RadixCache):
height += 1
return height
def write_backup(self, node: TreeNode):
def write_backup(self, node: TreeNode, write_back=False):
host_indices = self.cache_controller.write(
device_indices=node.value,
node_id=node.id,
@@ -90,21 +95,29 @@ class HiRadixCache(RadixCache):
if host_indices is not None:
node.host_value = host_indices
self.ongoing_write_through[node.id] = node
self.inc_lock_ref(node)
if not write_back:
# no need to lock nodes if write back
self.inc_lock_ref(node)
else:
return 0
return len(host_indices)
def inc_hit_count(self, node: TreeNode):
if self.cache_controller.write_policy != "write_through_selective":
if node.backuped or self.cache_controller.write_policy == "write_back":
return
node.hit_count += 1
if node.host_value is None and node.hit_count > self.write_through_threshold:
if node.hit_count >= self.write_through_threshold:
self.write_backup(node)
node.hit_count = 0
def writing_check(self):
def writing_check(self, write_back=False):
if write_back:
# blocking till all write back complete
while len(self.ongoing_write_through) > 0:
ack_id = self.cache_controller.ack_write_queue.get()
del self.ongoing_write_through[ack_id]
return
queue_size = torch.tensor(
self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
)
@@ -143,29 +156,25 @@ class HiRadixCache(RadixCache):
heapq.heapify(leaves)
num_evicted = 0
pending_nodes = []
write_back_nodes = []
while num_evicted < num_tokens and len(leaves):
x = heapq.heappop(leaves)
if x.lock_ref > 0:
continue
if x.host_value is None:
if not x.backuped:
if self.cache_controller.write_policy == "write_back":
num_evicted += self.write_backup(x)
pending_nodes.append(x)
elif self.cache_controller.write_policy == "write_through_selective":
num_evicted += self._evict_write_through_selective(x)
# write to host if the node is not backuped
num_evicted += self.write_backup(x, write_back=True)
write_back_nodes.append(x)
else:
assert (
self.cache_controller.write_policy != "write_through"
), "write_through should be inclusive"
raise NotImplementedError
num_evicted += self._evict_regular(x)
else:
num_evicted += self._evict_write_through(x)
num_evicted += self._evict_backuped(x)
for child in x.parent.children.values():
if child in pending_nodes:
if child in write_back_nodes:
continue
if not child.evicted:
break
@@ -174,15 +183,12 @@ class HiRadixCache(RadixCache):
heapq.heappush(leaves, x.parent)
if self.cache_controller.write_policy == "write_back":
# blocking till all write back complete
while len(self.ongoing_write_through) > 0:
self.writing_check()
time.sleep(0.1)
for node in pending_nodes:
assert node.host_value is not None
self._evict_write_through(node)
self.writing_check(write_back=True)
for node in write_back_nodes:
assert node.backuped
self._evict_backuped(node)
def _evict_write_through(self, node: TreeNode):
def _evict_backuped(self, node: TreeNode):
# evict a node already written to host
num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
assert num_evicted > 0
@@ -190,7 +196,7 @@ class HiRadixCache(RadixCache):
node.value = None
return num_evicted
def _evict_write_through_selective(self, node: TreeNode):
def _evict_regular(self, node: TreeNode):
# evict a node not initiated write to host
self.cache_controller.mem_pool_device_allocator.free(node.value)
num_evicted = len(node.value)
@@ -339,11 +345,13 @@ class HiRadixCache(RadixCache):
prefix_len = self.key_match_fn(child.key, key)
if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len)
self.inc_hit_count(new_node)
if not new_node.evicted:
value.append(new_node.value)
node = new_node
break
else:
self.inc_hit_count(child)
if not child.evicted:
value.append(child.value)
node = child
@@ -369,7 +377,7 @@ class HiRadixCache(RadixCache):
else:
new_node.value = child.value[:split_len]
child.value = child.value[split_len:]
if child.host_value is not None:
if child.backuped:
new_node.host_value = child.host_value[:split_len]
child.host_value = child.host_value[split_len:]
child.parent = new_node
@@ -426,8 +434,8 @@ class HiRadixCache(RadixCache):
node.children[child_key] = new_node
self.evictable_size_ += len(value)
if self.cache_controller.write_policy == "write_through":
self.write_backup(new_node)
if self.cache_controller.write_policy != "write_back":
self.inc_hit_count(new_node)
return total_prefix_length
def _collect_leaves_device(self):