[PD] Support prefill overlap + Ensure no race condition (#5609)
This commit is contained in:
@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
@@ -204,6 +205,40 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
# Otherwise, it hangs under high concurrency
|
||||
self.running_batch.batch_is_full = False
|
||||
|
||||
@torch.no_grad()
|
||||
def event_loop_overlap_disagg_prefill(self):
|
||||
self.result_queue = deque()
|
||||
|
||||
while True:
|
||||
recv_reqs = self.recv_requests()
|
||||
self.process_input_requests(recv_reqs)
|
||||
self.waiting_queue.extend(
|
||||
self.disagg_prefill_pending_queue.pop_bootstrapped()
|
||||
)
|
||||
self.process_prefill_chunk()
|
||||
batch = self.get_new_batch_prefill()
|
||||
self.cur_batch = batch
|
||||
|
||||
if batch:
|
||||
result = self.run_batch(batch)
|
||||
self.result_queue.append((batch.copy(), result))
|
||||
|
||||
if self.last_batch:
|
||||
tmp_batch, tmp_result = self.result_queue.popleft()
|
||||
self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
|
||||
|
||||
if len(self.disagg_prefill_inflight_queue) > 0:
|
||||
self.process_disagg_prefill_inflight_queue()
|
||||
|
||||
if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
|
||||
self.check_memory()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
|
||||
self.last_batch = batch
|
||||
# HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
|
||||
# Otherwise, it hangs under high concurrency
|
||||
self.running_batch.batch_is_full = False
|
||||
|
||||
def process_batch_result_disagg_prefill(
|
||||
self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
|
||||
) -> None:
|
||||
@@ -212,7 +247,26 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
Adapted from process_batch_result_prefill
|
||||
"""
|
||||
|
||||
next_token_ids = result.next_token_ids.tolist()
|
||||
(
|
||||
logits_output,
|
||||
next_token_ids,
|
||||
extend_input_len_per_req,
|
||||
extend_logprob_start_len_per_req,
|
||||
bid,
|
||||
) = (
|
||||
result.logits_output,
|
||||
result.next_token_ids,
|
||||
result.extend_input_len_per_req,
|
||||
result.extend_logprob_start_len_per_req,
|
||||
result.bid,
|
||||
)
|
||||
|
||||
# Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
|
||||
if self.enable_overlap:
|
||||
# wait
|
||||
_, next_token_ids = self.tp_worker.resolve_batch_result(bid)
|
||||
else:
|
||||
next_token_ids = result.next_token_ids.tolist()
|
||||
|
||||
for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
|
||||
req: Req
|
||||
@@ -226,12 +280,8 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
# being chunked reqs' prefill is not finished
|
||||
req.is_chunked -= 1
|
||||
|
||||
# TODO: Not sure if this is necessary
|
||||
if batch.next_batch_sampling_info:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
# We need to remove this for overlap schedule.
|
||||
self.current_stream.synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
if self.enable_overlap:
|
||||
self.send_kv_chunk(req, end_idx=req.tmp_end_idx)
|
||||
|
||||
def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
|
||||
"""
|
||||
@@ -276,20 +326,37 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
# only finished requests to running_batch.
|
||||
self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
|
||||
self.tree_cache.cache_unfinished_req(self.chunked_req)
|
||||
self.send_kv_chunk(self.chunked_req)
|
||||
if (
|
||||
self.enable_overlap
|
||||
): # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
|
||||
self.chunked_req.tmp_end_idx = min(
|
||||
len(self.chunked_req.fill_ids),
|
||||
len(self.chunked_req.origin_input_ids),
|
||||
)
|
||||
else:
|
||||
self.send_kv_chunk(self.chunked_req)
|
||||
# chunked request keeps its rid but will get a new req_pool_idx
|
||||
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
|
||||
self.running_batch.batch_is_full = False
|
||||
|
||||
def send_kv_chunk(
|
||||
self: Scheduler, req: Req, token_id: Optional[int] = None
|
||||
self: Scheduler,
|
||||
req: Req,
|
||||
token_id: Optional[int] = None,
|
||||
end_idx: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Send a prefilled chunk to the decode server
|
||||
"""
|
||||
page_size = self.token_to_kv_pool_allocator.page_size
|
||||
start_idx = req.start_send_idx
|
||||
end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
|
||||
# if end_idx is specified, use it as the end index of the kv chunk because in overlap schedule,
|
||||
# the resolved length is not the same as fill_ids's length
|
||||
end_idx = (
|
||||
end_idx
|
||||
if end_idx is not None
|
||||
else min(len(req.fill_ids), len(req.origin_input_ids))
|
||||
)
|
||||
last_chunk = token_id is not None
|
||||
|
||||
if (not last_chunk) and (
|
||||
@@ -302,7 +369,7 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
req.start_send_idx = end_idx
|
||||
|
||||
kv_indices = (
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx][start_idx:end_idx]
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
|
||||
@@ -539,6 +539,11 @@ class Req:
|
||||
# The first output_id transferred from prefill instance.
|
||||
self.transferred_output_id: Optional[int] = None
|
||||
|
||||
# For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
|
||||
# This is because kv is not ready in `process_prefill_chunk`.
|
||||
# We use `tmp_end_idx` to store the end index of the kv cache to send.
|
||||
self.tmp_end_idx: int = -1
|
||||
|
||||
@property
|
||||
def seqlen(self):
|
||||
return len(self.origin_input_ids) + len(self.output_ids)
|
||||
|
||||
@@ -2014,7 +2014,10 @@ def run_scheduler_process(
|
||||
else:
|
||||
scheduler.event_loop_normal()
|
||||
elif disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
scheduler.event_loop_normal_disagg_prefill()
|
||||
if scheduler.enable_overlap:
|
||||
scheduler.event_loop_overlap_disagg_prefill()
|
||||
else:
|
||||
scheduler.event_loop_normal_disagg_prefill()
|
||||
elif disaggregation_mode == DisaggregationMode.DECODE:
|
||||
if scheduler.enable_overlap:
|
||||
scheduler.event_loop_overlap_disagg_decode()
|
||||
|
||||
@@ -388,8 +388,6 @@ class ServerArgs:
|
||||
if self.disaggregation_mode == "prefill":
|
||||
self.disable_cuda_graph = True
|
||||
logger.warning("Cuda graph is disabled for prefill server")
|
||||
self.disable_overlap_schedule = True
|
||||
logger.warning("Overlap scheduler is disabled for prefill server")
|
||||
elif self.disaggregation_mode == "decode":
|
||||
self.disable_radix_cache = True
|
||||
logger.warning("KV cache is forced as chunk cache for decode server")
|
||||
|
||||
Reference in New Issue
Block a user