[PD] Support decode overlap schedule (#5608)
This commit is contained in:
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
@@ -475,6 +476,48 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
|
||||
self.last_batch = batch
|
||||
|
||||
@torch.no_grad()
|
||||
def event_loop_overlap_disagg_decode(self):
|
||||
result_queue = deque()
|
||||
self.last_batch: Optional[ScheduleBatch] = None
|
||||
self.last_batch_is_extend = False # last batch is modified in-place, so we need another variable to track if it's extend
|
||||
|
||||
while True:
|
||||
recv_reqs = self.recv_requests()
|
||||
self.process_input_requests(recv_reqs)
|
||||
# polling and allocating kv cache
|
||||
self.process_decode_queue()
|
||||
batch = self.get_next_disagg_decode_batch_to_run()
|
||||
self.cur_batch = batch
|
||||
last_batch_is_extend = False
|
||||
|
||||
if batch:
|
||||
# Generate fake extend output.
|
||||
if batch.forward_mode.is_extend():
|
||||
# Note: Logprobs should be handled on the prefill engine.
|
||||
self.stream_output(batch.reqs, False)
|
||||
last_batch_is_extend = True
|
||||
else:
|
||||
result = self.run_batch(batch)
|
||||
result_queue.append((batch.copy(), result))
|
||||
|
||||
# Process the results of the previous batch, but skip if the last batch was an extend batch
|
||||
if self.last_batch and not self.last_batch_is_extend:
|
||||
tmp_batch, tmp_result = result_queue.popleft()
|
||||
self.process_batch_result(tmp_batch, tmp_result)
|
||||
|
||||
if batch is None and (
|
||||
len(self.disagg_decode_transfer_queue.queue)
|
||||
+ len(self.disagg_decode_prealloc_queue.queue)
|
||||
== 0
|
||||
):
|
||||
# When the server is idle, do self-check and re-init some states
|
||||
self.check_memory()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
|
||||
self.last_batch = batch
|
||||
self.last_batch_is_extend = last_batch_is_extend
|
||||
|
||||
def get_next_disagg_decode_batch_to_run(
|
||||
self: Scheduler,
|
||||
) -> Optional[Tuple[ScheduleBatch, bool]]:
|
||||
|
||||
@@ -2016,7 +2016,10 @@ def run_scheduler_process(
|
||||
elif disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
scheduler.event_loop_normal_disagg_prefill()
|
||||
elif disaggregation_mode == DisaggregationMode.DECODE:
|
||||
scheduler.event_loop_normal_disagg_decode()
|
||||
if scheduler.enable_overlap:
|
||||
scheduler.event_loop_overlap_disagg_decode()
|
||||
else:
|
||||
scheduler.event_loop_normal_disagg_decode()
|
||||
|
||||
except Exception:
|
||||
traceback = get_exception_traceback()
|
||||
|
||||
@@ -387,14 +387,12 @@ class ServerArgs:
|
||||
# PD disaggregation
|
||||
if self.disaggregation_mode == "prefill":
|
||||
self.disable_cuda_graph = True
|
||||
logger.warning("KV cache is forced as chunk cache for decode server")
|
||||
logger.warning("Cuda graph is disabled for prefill server")
|
||||
self.disable_overlap_schedule = True
|
||||
logger.warning("Overlap scheduler is disabled for prefill server")
|
||||
elif self.disaggregation_mode == "decode":
|
||||
self.disable_radix_cache = True
|
||||
logger.warning("Cuda graph is disabled for prefill server")
|
||||
self.disable_overlap_schedule = True
|
||||
logger.warning("Overlap scheduler is disabled for decode server")
|
||||
logger.warning("KV cache is forced as chunk cache for decode server")
|
||||
|
||||
os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
|
||||
"1" if self.enable_torch_compile else "0"
|
||||
|
||||
Reference in New Issue
Block a user