Fix several minor issues in PD disaggregation (#5444)

This commit is contained in:
Cheng Wan
2025-04-15 23:04:41 -07:00
committed by GitHub
parent 5b5c7237c8
commit 6aca583420
3 changed files with 67 additions and 69 deletions

View File

@@ -419,6 +419,38 @@ class ScheduleBatchDisaggregationDecodeMixin:
class SchedulerDisaggregationDecodeMixin:
@torch.no_grad()
def event_loop_normal_disagg_decode(self):
    """Run the scheduler loop of a decode worker in disaggregation mode.

    Every iteration receives and processes incoming requests, services the
    decode queue, fetches the next batch and either streams its output
    (extend mode) or runs it. The loop has no exit condition, so this
    method does not return normally.
    """
    while True:
        incoming = self.recv_requests()
        self.process_input_requests(incoming)

        # Poll the decode queue and allocate kv cache.
        self.process_decode_queue()

        batch = self.get_next_disagg_decode_batch_to_run()
        self.cur_batch = batch

        if batch:
            if not batch.forward_mode.is_extend():
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
            else:
                # Extend batches only produce a fake extend output here;
                # logprobs should be handled on the prefill engine.
                self.stream_output(batch.reqs, False)

        if batch is None:
            pending = len(self.disagg_decode_transfer_queue.queue) + len(
                self.disagg_decode_prealloc_queue.queue
            )
            if pending == 0:
                # Server is idle: run the self-check and re-init some states.
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

        self.last_batch = batch
def get_next_disagg_decode_batch_to_run(
self: Scheduler,
) -> Optional[Tuple[ScheduleBatch, bool]]:

View File

@@ -171,6 +171,36 @@ class SchedulerDisaggregationPrefillMixin:
Mixin for Scheduler to handle disaggregation prefill
"""
@torch.no_grad()
def event_loop_normal_disagg_prefill(self):
    """Run the scheduler loop of a prefill worker in disaggregation mode.

    Every iteration receives and processes incoming requests, moves
    bootstrapped requests from the pending queue to the waiting queue,
    runs the next prefill batch (if any) and then services the inflight
    queue. The loop has no exit condition, so this method does not return
    normally.
    """
    while True:
        incoming = self.recv_requests()
        self.process_input_requests(incoming)

        bootstrapped = self.disagg_prefill_pending_queue.pop_bootstrapped()
        self.waiting_queue.extend(bootstrapped)
        self.process_prefill_chunk()

        batch = self.get_new_batch_prefill()
        self.cur_batch = batch

        if batch:
            result = self.run_batch(batch)
            self.process_batch_result_disagg_prefill(batch, result)

        if len(self.disagg_prefill_inflight_queue) > 0:
            self.process_disagg_prefill_inflight_queue()

        # Re-read the inflight queue length here; it is checked again after
        # the processing call above rather than cached.
        idle = batch is None and len(self.disagg_prefill_inflight_queue) == 0
        if idle:
            self.check_memory()
            self.new_token_ratio = self.init_new_token_ratio

        self.last_batch = batch

        # HACK (byronhsu): this loop never enters update_running_batch, which
        # is what normally resets batch_is_full, so clear the flag manually.
        # Otherwise, it hangs under high concurrency.
        self.running_batch.batch_is_full = False
def process_batch_result_disagg_prefill(
self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
) -> None:
@@ -210,7 +240,7 @@ class SchedulerDisaggregationPrefillMixin:
polls = poll_and_all_reduce(
[req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
self.tp_worker.get_tp_cpu_group(),
self.attn_tp_cpu_group,
)
undone_reqs: List[Req] = []