Integrating PD disaggregation with DP attention and DeepEP (#5435)
Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
@@ -444,6 +444,15 @@ class ScheduleBatchDisaggregationDecodeMixin:
 class SchedulerDisaggregationDecodeMixin:

+    def _prepare_idle_batch_and_run(self, batch, delay_process=False):
+        batch, _ = self.prepare_dp_attn_batch(batch)
+        result = None
+        if batch:
+            result = self.run_batch(batch)
+            if not delay_process:
+                self.process_batch_result(batch, result)
+        return batch, result
+
     @torch.no_grad()
     def event_loop_normal_disagg_decode(self):
         """A normal scheduler loop for decode worker in disaggregation mode."""
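This first hunk adds a small helper to the decode-side scheduler mixin. DP attention (and the DeepEP-style expert parallelism this PR targets) performs collective communication inside the forward pass, so every data-parallel rank has to step through a forward on every scheduler iteration, even a rank that currently has no requests. The sketch below is a simplified, self-contained illustration of that pattern under assumed semantics, not the SGLang implementation: a rank with no work substitutes an empty idle batch, runs it, and can defer result processing.

# Minimal illustrative sketch (not the SGLang code): a rank without work still
# steps through the forward pass so collective ops stay aligned across DP ranks.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToyBatch:
    reqs: list = field(default_factory=list)
    is_idle: bool = False


def prepare_dp_attn_batch(batch: Optional[ToyBatch]) -> ToyBatch:
    # If this rank has nothing to run, substitute an empty idle batch so it can
    # still join the collectives that DP attention / DeepEP requires.
    return batch if batch is not None else ToyBatch(is_idle=True)


def run_batch(batch: ToyBatch) -> dict:
    # Stand-in for the real forward pass; an idle batch produces no tokens.
    return {"tokens": len(batch.reqs)}


def prepare_idle_batch_and_run(batch: Optional[ToyBatch], delay_process: bool = False):
    batch = prepare_dp_attn_batch(batch)
    result = run_batch(batch)
    if not delay_process:
        # In the real scheduler, process_batch_result(batch, result) would run here.
        pass
    return batch, result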
@@ -456,14 +465,25 @@ class SchedulerDisaggregationDecodeMixin:
             batch = self.get_next_disagg_decode_batch_to_run()
             self.cur_batch = batch

+            prepare_dp_attn_flag = (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            )
+
             if batch:
                 # Generate fake extend output.
                 if batch.forward_mode.is_extend():
                     # Note: Logprobs should be handled on the prefill engine.
                     self.stream_output(batch.reqs, False)
+                    if prepare_dp_attn_flag:
+                        self._prepare_idle_batch_and_run(None)
                 else:
+                    if prepare_dp_attn_flag:
+                        self.prepare_dp_attn_batch(batch)
                     result = self.run_batch(batch)
                     self.process_batch_result(batch, result)
+            elif prepare_dp_attn_flag:
+                batch, _ = self._prepare_idle_batch_and_run(None)

             if batch is None and (
                 len(self.disagg_decode_transfer_queue.queue)
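In the normal (non-overlap) decode loop, the per-iteration behavior now depends on whether this rank has a batch and whether that batch is the fake extend produced by the prefill engine: a fake extend only streams the already-computed output, a real decode batch runs normally, and a rank with nothing to do still launches an idle forward when DP attention or SP layernorm is enabled so its peers are not blocked. The toy decision table below summarizes that logic; it is an illustration with assumed names, not code from the patch.

# Toy decision table (illustration only, assumed semantics): what each DP rank
# does per scheduler iteration once DP attention / SP layernorm is enabled.
def step_action(has_batch: bool, is_fake_extend: bool, dp_attn: bool) -> str:
    if has_batch and is_fake_extend:
        # Output was already produced by the prefill engine; stream it, and run
        # an idle forward only so the other DP ranks are not left waiting.
        return "stream_output + idle_forward" if dp_attn else "stream_output"
    if has_batch:
        return "run_batch"
    return "idle_forward" if dp_attn else "skip"

assert step_action(False, False, dp_attn=True) == "idle_forward"
assert step_action(True, True, dp_attn=False) == "stream_output"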
@@ -480,7 +500,7 @@ class SchedulerDisaggregationDecodeMixin:
     def event_loop_overlap_disagg_decode(self):
         result_queue = deque()
         self.last_batch: Optional[ScheduleBatch] = None
-        self.last_batch_is_extend = False  # the last batch is modified in-place, so we need another variable to track if it was an extend batch
+        self.last_batch_in_queue = False  # the last batch is modified in-place, so we need another variable to track whether its result was queued

         while True:
             recv_reqs = self.recv_requests()
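The overlap loop replaces the `last_batch_is_extend` flag with `last_batch_in_queue`. The reason a separate boolean is kept at all (rather than re-inspecting `self.last_batch` on the next iteration) is that `ScheduleBatch` objects are modified in place by subsequent steps, so state read from the stale reference is unreliable. A tiny, hypothetical illustration of that aliasing issue:

# Hypothetical illustration of the aliasing issue: the object referenced by
# last_batch keeps changing, so any state read from it later is unreliable.
class Batch:
    def __init__(self, mode):
        self.mode = mode

batch = Batch("extend")
last_batch = batch                   # same object, not a snapshot
batch.mode = "decode"                # in-place modification by a later step
assert last_batch.mode == "decode"   # the "remembered" batch changed too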
@@ -489,20 +509,41 @@ class SchedulerDisaggregationDecodeMixin:
             self.process_decode_queue()
             batch = self.get_next_disagg_decode_batch_to_run()
             self.cur_batch = batch
-            last_batch_is_extend = False
+            last_batch_in_queue = False

+            prepare_dp_attn_flag = (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            )
+
             if batch:
                 # Generate fake extend output.
                 if batch.forward_mode.is_extend():
                     # Note: Logprobs should be handled on the prefill engine.
                     self.stream_output(batch.reqs, False)
-                    last_batch_is_extend = True
+                    if prepare_dp_attn_flag:
+                        batch_, result = self._prepare_idle_batch_and_run(
+                            None, delay_process=True
+                        )
+                        if batch_:
+                            result_queue.append((batch_.copy(), result))
+                            last_batch_in_queue = True
                 else:
+                    if prepare_dp_attn_flag:
+                        self.prepare_dp_attn_batch(batch)
                     result = self.run_batch(batch)
                     result_queue.append((batch.copy(), result))
+                    last_batch_in_queue = True
+            elif prepare_dp_attn_flag:
+                batch, result = self._prepare_idle_batch_and_run(
+                    None, delay_process=True
+                )
+                if batch:
+                    result_queue.append((batch.copy(), result))
+                    last_batch_in_queue = True

             # Process the results of the previous batch but skip if the last batch is extend
-            if self.last_batch and not self.last_batch_is_extend:
+            if self.last_batch and self.last_batch_in_queue:
                 tmp_batch, tmp_result = result_queue.popleft()
                 self.process_batch_result(tmp_batch, tmp_result)
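With overlap scheduling, results are consumed one iteration after they are produced: whatever ran in iteration N is pushed onto `result_queue` (as a copy, since the live batch keeps being mutated) and popped in iteration N+1, but only if iteration N actually queued something, which is exactly what `last_batch_in_queue` records; a fake-extend iteration that launched no idle batch enqueues nothing. A minimal stand-alone sketch of that deferred-processing pattern, with toy data:

# Small sketch of the deferred-processing pattern (illustrative, not SGLang):
# results are consumed one iteration after they were produced, and only when
# the previous iteration really pushed something onto the queue.
from collections import deque

result_queue = deque()
last_in_queue = False

for step, work in enumerate([3, 0, 5]):    # tokens to run per iteration (toy data)
    if last_in_queue:
        prev = result_queue.popleft()      # process last iteration's result now
        print(f"step {step}: processed {prev}")
    ran_something = work > 0
    if ran_something:
        result_queue.append(f"result-of-step-{step}")
    last_in_queue = ran_something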
@@ -516,7 +557,7 @@ class SchedulerDisaggregationDecodeMixin:
                 self.new_token_ratio = self.init_new_token_ratio

             self.last_batch = batch
-            self.last_batch_is_extend = last_batch_is_extend
+            self.last_batch_in_queue = last_batch_in_queue

     def get_next_disagg_decode_batch_to_run(
         self: Scheduler,
@@ -187,6 +187,14 @@ class SchedulerDisaggregationPrefillMixin:
             )
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()
+
+            # Handle DP attention
+            if (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            ):
+                batch, _ = self.prepare_dp_attn_batch(batch)
+
             self.cur_batch = batch

             if batch:
@@ -217,6 +225,14 @@ class SchedulerDisaggregationPrefillMixin:
             )
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()
+
+            # Handle DP attention
+            if (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            ):
+                batch, _ = self.prepare_dp_attn_batch(batch)
+
             self.cur_batch = batch

             if batch:
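Both prefill event loops (the normal and the overlap variant) get the same addition: after a new prefill batch is formed, it is passed through `prepare_dp_attn_batch` whenever DP attention or SP layernorm is enabled, so an idle batch can be substituted when this rank received nothing to prefill. Since the guard is duplicated verbatim in both loops, it could be expressed once; the helper below is hypothetical, not part of the patch:

# Hypothetical helper (not in the patch) expressing the duplicated guard once.
def dp_sync_required(server_args) -> bool:
    """True when every DP rank must run a (possibly idle) batch each step."""
    return bool(
        getattr(server_args, "enable_dp_attention", False)
        or getattr(server_args, "enable_sp_layernorm", False)
    )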
@@ -23,11 +23,13 @@ import psutil
 import setproctitle
 import zmq

+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
+from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -226,9 +228,14 @@ class DataParallelController:
         self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
         self.max_req_input_len = scheduler_info[0]["max_req_input_len"]

-    def round_robin_scheduler(self, req):
-        self.workers[self.round_robin_counter].send_pyobj(req)
-        self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
+    def round_robin_scheduler(self, req: Req):
+        if self.server_args.disaggregation_mode == "null":
+            self.workers[self.round_robin_counter].send_pyobj(req)
+            self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                self.workers
+            )
+        else:
+            self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)

     def shortest_queue_scheduler(self, input_requests):
         raise NotImplementedError()
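Finally, the `DataParallelController` change routes requests by `bootstrap_room` instead of round-robin whenever disaggregation is active. Presumably this makes worker selection deterministic per request, which matters because the prefill and decode engines run separate controllers whose round-robin counters could otherwise diverge, while `bootstrap_room % len(self.workers)` yields the same rank on both sides for a given request (assuming equal DP sizes). A toy check of that property, for illustration only:

# Toy check (illustration): bootstrap_room-based routing is deterministic, so a
# prefill controller and a decode controller with the same DP size agree on the
# target rank for any given request.
def route(bootstrap_room: int, num_workers: int) -> int:
    return bootstrap_room % num_workers

dp_size = 4
for room in (0, 7, 123456, 2**63 - 1):
    assert route(room, dp_size) == route(room, dp_size)  # same result on both sides
    print(f"bootstrap_room={room} -> dp rank {route(room, dp_size)}")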