[PD] Support decode overlap schedule (#5608)

This commit is contained in:
Byron Hsu
2025-04-21 12:06:16 -07:00
committed by GitHub
parent 4dce1cc608
commit e65b9f21e3
3 changed files with 49 additions and 5 deletions

View File

@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
from __future__ import annotations
import logging
from collections import deque
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple
@@ -475,6 +476,48 @@ class SchedulerDisaggregationDecodeMixin:
self.last_batch = batch
@torch.no_grad()
def event_loop_overlap_disagg_decode(self):
result_queue = deque()
self.last_batch: Optional[ScheduleBatch] = None
self.last_batch_is_extend = False # last batch is modifed in-place, so we need another variable to track if it's extend
while True:
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
# polling and allocating kv cache
self.process_decode_queue()
batch = self.get_next_disagg_decode_batch_to_run()
self.cur_batch = batch
last_batch_is_extend = False
if batch:
# Generate fake extend output.
if batch.forward_mode.is_extend():
# Note: Logprobs should be handled on the prefill engine.
self.stream_output(batch.reqs, False)
last_batch_is_extend = True
else:
result = self.run_batch(batch)
result_queue.append((batch.copy(), result))
# Process the results of the previous batch but skip if the last batch is extend
if self.last_batch and not self.last_batch_is_extend:
tmp_batch, tmp_result = result_queue.popleft()
self.process_batch_result(tmp_batch, tmp_result)
if batch is None and (
len(self.disagg_decode_transfer_queue.queue)
+ len(self.disagg_decode_prealloc_queue.queue)
== 0
):
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
self.last_batch = batch
self.last_batch_is_extend = last_batch_is_extend
def get_next_disagg_decode_batch_to_run(
self: Scheduler,
) -> Optional[Tuple[ScheduleBatch, bool]]: