Fix several minor issues in PD disaggregation (#5444)
This commit is contained in:
@@ -419,6 +419,38 @@ class ScheduleBatchDisaggregationDecodeMixin:
|
||||
|
||||
class SchedulerDisaggregationDecodeMixin:
|
||||
|
||||
@torch.no_grad()
|
||||
def event_loop_normal_disagg_decode(self):
|
||||
"""A normal scheduler loop for decode worker in disaggregation mode."""
|
||||
|
||||
while True:
|
||||
recv_reqs = self.recv_requests()
|
||||
self.process_input_requests(recv_reqs)
|
||||
# polling and allocating kv cache
|
||||
self.process_decode_queue()
|
||||
batch = self.get_next_disagg_decode_batch_to_run()
|
||||
self.cur_batch = batch
|
||||
|
||||
if batch:
|
||||
# Generate fake extend output.
|
||||
if batch.forward_mode.is_extend():
|
||||
# Note: Logprobs should be handled on the prefill engine.
|
||||
self.stream_output(batch.reqs, False)
|
||||
else:
|
||||
result = self.run_batch(batch)
|
||||
self.process_batch_result(batch, result)
|
||||
|
||||
if batch is None and (
|
||||
len(self.disagg_decode_transfer_queue.queue)
|
||||
+ len(self.disagg_decode_prealloc_queue.queue)
|
||||
== 0
|
||||
):
|
||||
# When the server is idle, do self-check and re-init some states
|
||||
self.check_memory()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
|
||||
self.last_batch = batch
|
||||
|
||||
def get_next_disagg_decode_batch_to_run(
|
||||
self: Scheduler,
|
||||
) -> Optional[Tuple[ScheduleBatch, bool]]:
|
||||
|
||||
@@ -171,6 +171,36 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
Mixin for Scheduler to handle disaggregation prefill
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def event_loop_normal_disagg_prefill(self):
|
||||
"""A normal scheduler loop for prefill worker in disaggregation mode."""
|
||||
|
||||
while True:
|
||||
recv_reqs = self.recv_requests()
|
||||
self.process_input_requests(recv_reqs)
|
||||
self.waiting_queue.extend(
|
||||
self.disagg_prefill_pending_queue.pop_bootstrapped()
|
||||
)
|
||||
self.process_prefill_chunk()
|
||||
batch = self.get_new_batch_prefill()
|
||||
self.cur_batch = batch
|
||||
|
||||
if batch:
|
||||
result = self.run_batch(batch)
|
||||
self.process_batch_result_disagg_prefill(batch, result)
|
||||
|
||||
if len(self.disagg_prefill_inflight_queue) > 0:
|
||||
self.process_disagg_prefill_inflight_queue()
|
||||
|
||||
if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
|
||||
self.check_memory()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
|
||||
self.last_batch = batch
|
||||
# HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
|
||||
# Otherwise, it hangs under high concurrency
|
||||
self.running_batch.batch_is_full = False
|
||||
|
||||
def process_batch_result_disagg_prefill(
|
||||
self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
|
||||
) -> None:
|
||||
@@ -210,7 +240,7 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
|
||||
polls = poll_and_all_reduce(
|
||||
[req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
|
||||
self.tp_worker.get_tp_cpu_group(),
|
||||
self.attn_tp_cpu_group,
|
||||
)
|
||||
|
||||
undone_reqs: List[Req] = []
|
||||
|
||||
Reference in New Issue
Block a user