[PD] Support prefill overlap + Ensure no race condition (#5609)

This commit is contained in:
Byron Hsu
2025-04-21 12:12:56 -07:00
committed by GitHub
parent e65b9f21e3
commit bf98d2e377
5 changed files with 107 additions and 18 deletions

View File

@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
from __future__ import annotations
import logging
from collections import deque
from typing import TYPE_CHECKING, List, Optional
import torch
@@ -204,6 +205,40 @@ class SchedulerDisaggregationPrefillMixin:
# Otherwise, it hangs under high concurrency
self.running_batch.batch_is_full = False
@torch.no_grad()
def event_loop_overlap_disagg_prefill(self):
    """Event loop for the disaggregated prefill server with overlap scheduling.

    Like the non-overlap prefill loop, but the result of each batch is
    consumed one iteration late (staged in ``self.result_queue``) so that
    the GPU forward pass of the current batch can overlap with the
    CPU-side result processing of the previous batch. Runs forever.
    """
    # FIFO of (batch_copy, result) pairs; each entry is processed on the
    # *next* loop iteration to achieve compute/processing overlap.
    self.result_queue = deque()

    while True:
        recv_reqs = self.recv_requests()
        self.process_input_requests(recv_reqs)
        # Requests whose KV bootstrap handshake has completed move from
        # the pending queue into the regular waiting queue.
        self.waiting_queue.extend(
            self.disagg_prefill_pending_queue.pop_bootstrapped()
        )
        self.process_prefill_chunk()
        batch = self.get_new_batch_prefill()
        self.cur_batch = batch

        if batch:
            result = self.run_batch(batch)
            # Copy the batch before queueing: the live batch object may be
            # mutated (filtered/chunked) before its result is processed.
            self.result_queue.append((batch.copy(), result))

        if self.last_batch:
            # Process the result enqueued on the previous iteration.
            tmp_batch, tmp_result = self.result_queue.popleft()
            self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)

        if len(self.disagg_prefill_inflight_queue) > 0:
            # Poll/advance KV transfers for requests already sent to decode.
            self.process_disagg_prefill_inflight_queue()

        # Idle (no batch and nothing in flight): do housekeeping.
        if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
            self.check_memory()
            self.new_token_ratio = self.init_new_token_ratio

        self.last_batch = batch
        # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
        # Otherwise, it hangs under high concurrency
        self.running_batch.batch_is_full = False
def process_batch_result_disagg_prefill(
self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
) -> None:
@@ -212,7 +247,26 @@ class SchedulerDisaggregationPrefillMixin:
Adapted from process_batch_result_prefill
"""
next_token_ids = result.next_token_ids.tolist()
(
logits_output,
next_token_ids,
extend_input_len_per_req,
extend_logprob_start_len_per_req,
bid,
) = (
result.logits_output,
result.next_token_ids,
result.extend_input_len_per_req,
result.extend_logprob_start_len_per_req,
result.bid,
)
# Transfer KV for prefill-completed requests and add them to the disagg_prefill_inflight_queue
if self.enable_overlap:
# wait
_, next_token_ids = self.tp_worker.resolve_batch_result(bid)
else:
next_token_ids = result.next_token_ids.tolist()
for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
req: Req
@@ -226,12 +280,8 @@ class SchedulerDisaggregationPrefillMixin:
# being chunked reqs' prefill is not finished
req.is_chunked -= 1
# TODO: Not sure if this is necessary
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
# We need to remove this for overlap schedule.
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
if self.enable_overlap:
self.send_kv_chunk(req, end_idx=req.tmp_end_idx)
def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
"""
@@ -276,20 +326,37 @@ class SchedulerDisaggregationPrefillMixin:
# only finished requests to running_batch.
self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
self.tree_cache.cache_unfinished_req(self.chunked_req)
self.send_kv_chunk(self.chunked_req)
if (
self.enable_overlap
): # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
self.chunked_req.tmp_end_idx = min(
len(self.chunked_req.fill_ids),
len(self.chunked_req.origin_input_ids),
)
else:
self.send_kv_chunk(self.chunked_req)
# chunked request keeps its rid but will get a new req_pool_idx
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
self.running_batch.batch_is_full = False
def send_kv_chunk(
self: Scheduler, req: Req, token_id: Optional[int] = None
self: Scheduler,
req: Req,
token_id: Optional[int] = None,
end_idx: Optional[int] = None,
) -> None:
"""
Send a prefilled chunk to the decode server
"""
page_size = self.token_to_kv_pool_allocator.page_size
start_idx = req.start_send_idx
end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
# if end_idx is specified, use it as the end index of the kv chunk because in overlap schedule,
# the resolved length is not the same as fill_ids's length
end_idx = (
end_idx
if end_idx is not None
else min(len(req.fill_ids), len(req.origin_input_ids))
)
last_chunk = token_id is not None
if (not last_chunk) and (
@@ -302,7 +369,7 @@ class SchedulerDisaggregationPrefillMixin:
req.start_send_idx = end_idx
kv_indices = (
self.req_to_token_pool.req_to_token[req.req_pool_idx][start_idx:end_idx]
self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
.cpu()
.numpy()
)

View File

@@ -539,6 +539,11 @@ class Req:
# The first output_id transferred from prefill instance.
self.transferred_output_id: Optional[int] = None
# For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
# This is because kv is not ready in `process_prefill_chunk`.
# We use `tmp_end_idx` to store the end index of the kv cache to send.
self.tmp_end_idx: int = -1
@property
def seqlen(self):
return len(self.origin_input_ids) + len(self.output_ids)

View File

@@ -2014,7 +2014,10 @@ def run_scheduler_process(
else:
scheduler.event_loop_normal()
elif disaggregation_mode == DisaggregationMode.PREFILL:
scheduler.event_loop_normal_disagg_prefill()
if scheduler.enable_overlap:
scheduler.event_loop_overlap_disagg_prefill()
else:
scheduler.event_loop_normal_disagg_prefill()
elif disaggregation_mode == DisaggregationMode.DECODE:
if scheduler.enable_overlap:
scheduler.event_loop_overlap_disagg_decode()

View File

@@ -388,8 +388,6 @@ class ServerArgs:
if self.disaggregation_mode == "prefill":
self.disable_cuda_graph = True
logger.warning("Cuda graph is disabled for prefill server")
self.disable_overlap_schedule = True
logger.warning("Overlap scheduler is disabled for prefill server")
elif self.disaggregation_mode == "decode":
self.disable_radix_cache = True
logger.warning("KV cache is forced as chunk cache for decode server")