[HiCache] Resolve conflict between chunked prefill and HiCache hit count (#9776)
This commit is contained in:
@@ -567,7 +567,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
# Move the chunked request out of the batch so that we can merge
|
# Move the chunked request out of the batch so that we can merge
|
||||||
# only finished requests to running_batch.
|
# only finished requests to running_batch.
|
||||||
self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
|
self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
|
||||||
self.tree_cache.cache_unfinished_req(self.chunked_req)
|
self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
|
||||||
if self.enable_overlap:
|
if self.enable_overlap:
|
||||||
# Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
|
# Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
|
||||||
self.chunked_req.tmp_end_idx = min(
|
self.chunked_req.tmp_end_idx = min(
|
||||||
|
|||||||
@@ -1503,7 +1503,7 @@ class Scheduler(
|
|||||||
# Move the chunked request out of the batch so that we can merge
|
# Move the chunked request out of the batch so that we can merge
|
||||||
# only finished requests to running_batch.
|
# only finished requests to running_batch.
|
||||||
chunked_req_to_exclude.add(self.chunked_req)
|
chunked_req_to_exclude.add(self.chunked_req)
|
||||||
self.tree_cache.cache_unfinished_req(self.chunked_req)
|
self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
|
||||||
# chunked request keeps its rid but will get a new req_pool_idx
|
# chunked request keeps its rid but will get a new req_pool_idx
|
||||||
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
|
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
|
||||||
if self.last_batch and self.last_batch.forward_mode.is_extend():
|
if self.last_batch and self.last_batch.forward_mode.is_extend():
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class ChunkCache(BasePrefixCache):
|
|||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
self.token_to_kv_pool_allocator.free(kv_indices)
|
self.token_to_kv_pool_allocator.free(kv_indices)
|
||||||
|
|
||||||
def cache_unfinished_req(self, req: Req):
|
def cache_unfinished_req(self, req: Req, chunked=False):
|
||||||
kv_indices = self.req_to_token_pool.req_to_token[
|
kv_indices = self.req_to_token_pool.req_to_token[
|
||||||
req.req_pool_idx, : len(req.fill_ids)
|
req.req_pool_idx, : len(req.fill_ids)
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ class HiRadixCache(RadixCache):
|
|||||||
self.ongoing_backup = {}
|
self.ongoing_backup = {}
|
||||||
# todo: dynamically adjust the threshold
|
# todo: dynamically adjust the threshold
|
||||||
self.write_through_threshold = (
|
self.write_through_threshold = (
|
||||||
1 if hicache_write_policy == "write_through" else 3
|
1 if hicache_write_policy == "write_through" else 2
|
||||||
)
|
)
|
||||||
self.write_through_threshold_storage = (
|
self.write_through_threshold_storage = (
|
||||||
1 if hicache_write_policy == "write_through" else 3
|
1 if hicache_write_policy == "write_through" else 3
|
||||||
@@ -155,8 +155,9 @@ class HiRadixCache(RadixCache):
|
|||||||
self.ongoing_backup[operation_id] = node
|
self.ongoing_backup[operation_id] = node
|
||||||
node.protect_host()
|
node.protect_host()
|
||||||
|
|
||||||
def inc_hit_count(self, node: TreeNode):
|
def _inc_hit_count(self, node: TreeNode, chunked=False):
|
||||||
if self.cache_controller.write_policy == "write_back":
|
# skip the hit count update for chunked requests
|
||||||
|
if self.cache_controller.write_policy == "write_back" or chunked:
|
||||||
return
|
return
|
||||||
node.hit_count += 1
|
node.hit_count += 1
|
||||||
|
|
||||||
@@ -672,11 +673,11 @@ class HiRadixCache(RadixCache):
|
|||||||
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
||||||
return new_node
|
return new_node
|
||||||
|
|
||||||
def _insert_helper(self, node: TreeNode, key: List, value):
|
def insert(self, key: List, value, chunked=False):
|
||||||
node.last_access_time = time.monotonic()
|
|
||||||
if len(key) == 0:
|
if len(key) == 0:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
node = self.root_node
|
||||||
child_key = self.get_child_key_fn(key)
|
child_key = self.get_child_key_fn(key)
|
||||||
total_prefix_length = 0
|
total_prefix_length = 0
|
||||||
|
|
||||||
@@ -693,7 +694,7 @@ class HiRadixCache(RadixCache):
|
|||||||
self.token_to_kv_pool_host.update_synced(node.host_value)
|
self.token_to_kv_pool_host.update_synced(node.host_value)
|
||||||
self.evictable_size_ += len(node.value)
|
self.evictable_size_ += len(node.value)
|
||||||
else:
|
else:
|
||||||
self.inc_hit_count(node)
|
self._inc_hit_count(node, chunked)
|
||||||
total_prefix_length += prefix_len
|
total_prefix_length += prefix_len
|
||||||
else:
|
else:
|
||||||
# partial match, split the node
|
# partial match, split the node
|
||||||
@@ -703,7 +704,7 @@ class HiRadixCache(RadixCache):
|
|||||||
self.token_to_kv_pool_host.update_synced(new_node.host_value)
|
self.token_to_kv_pool_host.update_synced(new_node.host_value)
|
||||||
self.evictable_size_ += len(new_node.value)
|
self.evictable_size_ += len(new_node.value)
|
||||||
else:
|
else:
|
||||||
self.inc_hit_count(new_node)
|
self._inc_hit_count(new_node, chunked)
|
||||||
total_prefix_length += prefix_len
|
total_prefix_length += prefix_len
|
||||||
node = new_node
|
node = new_node
|
||||||
|
|
||||||
@@ -737,7 +738,7 @@ class HiRadixCache(RadixCache):
|
|||||||
last_hash = new_node.hash_value[-1]
|
last_hash = new_node.hash_value[-1]
|
||||||
|
|
||||||
if self.cache_controller.write_policy != "write_back":
|
if self.cache_controller.write_policy != "write_back":
|
||||||
self.inc_hit_count(new_node)
|
self._inc_hit_count(new_node, chunked)
|
||||||
return total_prefix_length
|
return total_prefix_length
|
||||||
|
|
||||||
def _collect_leaves_device(self):
|
def _collect_leaves_device(self):
|
||||||
|
|||||||
@@ -183,7 +183,7 @@ class LoRARadixCache(BasePrefixCache):
|
|||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
self.dec_lock_ref(req.last_node)
|
self.dec_lock_ref(req.last_node)
|
||||||
|
|
||||||
def cache_unfinished_req(self, req: Req):
|
def cache_unfinished_req(self, req: Req, chunked=False):
|
||||||
"""Cache request when it is unfinished."""
|
"""Cache request when it is unfinished."""
|
||||||
if self.disable:
|
if self.disable:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -195,7 +195,7 @@ class RadixCache(BasePrefixCache):
|
|||||||
last_host_node=last_node,
|
last_host_node=last_node,
|
||||||
)
|
)
|
||||||
|
|
||||||
def insert(self, key: List, value=None):
|
def insert(self, key: List, value=None, chunked=False):
|
||||||
if self.disable:
|
if self.disable:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
@@ -240,7 +240,7 @@ class RadixCache(BasePrefixCache):
|
|||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
self.dec_lock_ref(req.last_node)
|
self.dec_lock_ref(req.last_node)
|
||||||
|
|
||||||
def cache_unfinished_req(self, req: Req):
|
def cache_unfinished_req(self, req: Req, chunked=False):
|
||||||
"""Cache request when it is unfinished."""
|
"""Cache request when it is unfinished."""
|
||||||
if self.disable:
|
if self.disable:
|
||||||
return
|
return
|
||||||
@@ -261,7 +261,9 @@ class RadixCache(BasePrefixCache):
|
|||||||
page_aligned_token_ids = token_ids[:page_aligned_len]
|
page_aligned_token_ids = token_ids[:page_aligned_len]
|
||||||
|
|
||||||
# Radix Cache takes one ref in memory pool
|
# Radix Cache takes one ref in memory pool
|
||||||
new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices)
|
new_prefix_len = self.insert(
|
||||||
|
page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked
|
||||||
|
)
|
||||||
self.token_to_kv_pool_allocator.free(
|
self.token_to_kv_pool_allocator.free(
|
||||||
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -181,7 +181,7 @@ class RadixCacheCpp(BasePrefixCache):
|
|||||||
self.dec_lock_ref(req.last_node)
|
self.dec_lock_ref(req.last_node)
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
|
|
||||||
def cache_unfinished_req(self, req: Req):
|
def cache_unfinished_req(self, req: Req, chunked=False):
|
||||||
"""Cache request when it is unfinished."""
|
"""Cache request when it is unfinished."""
|
||||||
assert req.req_pool_idx is not None
|
assert req.req_pool_idx is not None
|
||||||
token_ids = req.fill_ids
|
token_ids = req.fill_ids
|
||||||
|
|||||||
@@ -464,7 +464,7 @@ class SWARadixCache(BasePrefixCache):
|
|||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
|
self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
|
||||||
|
|
||||||
def cache_unfinished_req(self, req: Req) -> None:
|
def cache_unfinished_req(self, req: Req, chunked=False) -> None:
|
||||||
"""Cache request when it is unfinished."""
|
"""Cache request when it is unfinished."""
|
||||||
if self.disable:
|
if self.disable:
|
||||||
kv_indices = self.req_to_token_pool.req_to_token[
|
kv_indices = self.req_to_token_pool.req_to_token[
|
||||||
|
|||||||
Reference in New Issue
Block a user