feature(pd-hicache): Prefill instances support reusing the RemoteStorage Cache via HiCache. (#8516)
Co-authored-by: Shangming Cai <csmthu@gmail.com> Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -1185,23 +1185,28 @@ class Scheduler(
|
|||||||
def _add_request_to_queue(self, req: Req):
    """Stamp the request's queue start time and hand it to the right queue.

    The destination depends on ``self.disaggregation_mode``:

    * ``PREFILL``: run ``_prefetch_kvcache`` on the request, then add it to
      ``disagg_prefill_bootstrap_queue`` together with the model's
      ``num_key_value_heads``.
    * ``DECODE``: add it to ``disagg_decode_prealloc_queue`` (no prefetch).
    * anything else: run ``_prefetch_kvcache`` on the request, then append
      it to ``waiting_queue``.

    Args:
        req: The request to enqueue; its ``queue_time_start`` is overwritten
            with the current ``time.perf_counter()`` reading.
    """
    req.queue_time_start = time.perf_counter()
    mode = self.disaggregation_mode

    if mode == DisaggregationMode.PREFILL:
        # Prefetch is issued before the request enters the bootstrap queue.
        self._prefetch_kvcache(req)
        self.disagg_prefill_bootstrap_queue.add(
            req, self.model_config.num_key_value_heads
        )
        return

    if mode == DisaggregationMode.DECODE:
        self.disagg_decode_prealloc_queue.add(req)
        return

    # Neither prefill nor decode mode: prefetch, then use the waiting queue.
    self._prefetch_kvcache(req)
    self.waiting_queue.append(req)
|
||||||
|
|
||||||
|
def _prefetch_kvcache(self, req: Req):
    """Ask the tree cache to prefetch KV cache for a request's unmatched tokens.

    Does nothing unless ``self.enable_hicache_storage`` is truthy. Otherwise
    it calls ``req.init_next_round_input(self.tree_cache)``, computes how many
    tokens are already matched (``len(req.prefix_indices)`` plus
    ``req.host_hit_length``), and calls
    ``self.tree_cache.prefetch_from_storage`` for the remaining tail of
    ``req.fill_ids``.

    Args:
        req: The request whose remaining input tokens should be prefetched.
    """
    if not self.enable_hicache_storage:
        return

    # NOTE(review): presumably this call is what populates prefix_indices,
    # host_hit_length and last_host_node read below -- confirm against Req.
    req.init_next_round_input(self.tree_cache)
    last_hash = req.last_host_node.get_last_hash_value()
    matched_len = len(req.prefix_indices) + req.host_hit_length

    # TODO: free-form fetching, calculating hash keys on the fly.
    # Prefetch only when nothing has matched yet, or when something matched
    # and the last host node reports a hash value.
    can_prefetch = matched_len == 0 or (
        matched_len > 0 and last_hash is not None
    )
    if not can_prefetch:
        return

    unmatched_tokens = req.fill_ids[matched_len:]
    self.tree_cache.prefetch_from_storage(
        req.rid, req.last_host_node, unmatched_tokens, last_hash
    )
|
||||||
|
|
||||||
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
|
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
self.disagg_prefill_bootstrap_queue.extend(
|
self.disagg_prefill_bootstrap_queue.extend(
|
||||||
|
|||||||
Reference in New Issue
Block a user