[PD] Support decode overlap schedule (#5608)
This commit is contained in:
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from collections import deque
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||||
|
|
||||||
@@ -475,6 +476,48 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def event_loop_overlap_disagg_decode(self):
|
||||||
|
result_queue = deque()
|
||||||
|
self.last_batch: Optional[ScheduleBatch] = None
|
||||||
|
self.last_batch_is_extend = False # last batch is modifed in-place, so we need another variable to track if it's extend
|
||||||
|
|
||||||
|
while True:
|
||||||
|
recv_reqs = self.recv_requests()
|
||||||
|
self.process_input_requests(recv_reqs)
|
||||||
|
# polling and allocating kv cache
|
||||||
|
self.process_decode_queue()
|
||||||
|
batch = self.get_next_disagg_decode_batch_to_run()
|
||||||
|
self.cur_batch = batch
|
||||||
|
last_batch_is_extend = False
|
||||||
|
|
||||||
|
if batch:
|
||||||
|
# Generate fake extend output.
|
||||||
|
if batch.forward_mode.is_extend():
|
||||||
|
# Note: Logprobs should be handled on the prefill engine.
|
||||||
|
self.stream_output(batch.reqs, False)
|
||||||
|
last_batch_is_extend = True
|
||||||
|
else:
|
||||||
|
result = self.run_batch(batch)
|
||||||
|
result_queue.append((batch.copy(), result))
|
||||||
|
|
||||||
|
# Process the results of the previous batch but skip if the last batch is extend
|
||||||
|
if self.last_batch and not self.last_batch_is_extend:
|
||||||
|
tmp_batch, tmp_result = result_queue.popleft()
|
||||||
|
self.process_batch_result(tmp_batch, tmp_result)
|
||||||
|
|
||||||
|
if batch is None and (
|
||||||
|
len(self.disagg_decode_transfer_queue.queue)
|
||||||
|
+ len(self.disagg_decode_prealloc_queue.queue)
|
||||||
|
== 0
|
||||||
|
):
|
||||||
|
# When the server is idle, do self-check and re-init some states
|
||||||
|
self.check_memory()
|
||||||
|
self.new_token_ratio = self.init_new_token_ratio
|
||||||
|
|
||||||
|
self.last_batch = batch
|
||||||
|
self.last_batch_is_extend = last_batch_is_extend
|
||||||
|
|
||||||
def get_next_disagg_decode_batch_to_run(
|
def get_next_disagg_decode_batch_to_run(
|
||||||
self: Scheduler,
|
self: Scheduler,
|
||||||
) -> Optional[Tuple[ScheduleBatch, bool]]:
|
) -> Optional[Tuple[ScheduleBatch, bool]]:
|
||||||
|
|||||||
@@ -2016,7 +2016,10 @@ def run_scheduler_process(
|
|||||||
elif disaggregation_mode == DisaggregationMode.PREFILL:
|
elif disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
scheduler.event_loop_normal_disagg_prefill()
|
scheduler.event_loop_normal_disagg_prefill()
|
||||||
elif disaggregation_mode == DisaggregationMode.DECODE:
|
elif disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
scheduler.event_loop_normal_disagg_decode()
|
if scheduler.enable_overlap:
|
||||||
|
scheduler.event_loop_overlap_disagg_decode()
|
||||||
|
else:
|
||||||
|
scheduler.event_loop_normal_disagg_decode()
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
traceback = get_exception_traceback()
|
traceback = get_exception_traceback()
|
||||||
|
|||||||
@@ -387,14 +387,12 @@ class ServerArgs:
|
|||||||
# PD disaggregation
|
# PD disaggregation
|
||||||
if self.disaggregation_mode == "prefill":
|
if self.disaggregation_mode == "prefill":
|
||||||
self.disable_cuda_graph = True
|
self.disable_cuda_graph = True
|
||||||
logger.warning("KV cache is forced as chunk cache for decode server")
|
logger.warning("Cuda graph is disabled for prefill server")
|
||||||
self.disable_overlap_schedule = True
|
self.disable_overlap_schedule = True
|
||||||
logger.warning("Overlap scheduler is disabled for prefill server")
|
logger.warning("Overlap scheduler is disabled for prefill server")
|
||||||
elif self.disaggregation_mode == "decode":
|
elif self.disaggregation_mode == "decode":
|
||||||
self.disable_radix_cache = True
|
self.disable_radix_cache = True
|
||||||
logger.warning("Cuda graph is disabled for prefill server")
|
logger.warning("KV cache is forced as chunk cache for decode server")
|
||||||
self.disable_overlap_schedule = True
|
|
||||||
logger.warning("Overlap scheduler is disabled for decode server")
|
|
||||||
|
|
||||||
os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
|
os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
|
||||||
"1" if self.enable_torch_compile else "0"
|
"1" if self.enable_torch_compile else "0"
|
||||||
|
|||||||
Reference in New Issue
Block a user