Integrating PD disaggregation with DP attention and DeepEP (#5435)
Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
This commit is contained in:
@@ -23,11 +23,13 @@ import psutil
|
||||
import setproctitle
|
||||
import zmq
|
||||
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
||||
from sglang.srt.managers.io_struct import (
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
@@ -226,9 +228,14 @@ class DataParallelController:
|
||||
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
|
||||
self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
|
||||
|
||||
def round_robin_scheduler(self, req):
    """Dispatch *req* to the next data-parallel worker in round-robin order.

    Sends the request to the worker at the current round-robin position,
    then advances the position, wrapping back to the first worker after
    the last one.
    """
    target = self.workers[self.round_robin_counter]
    target.send_pyobj(req)
    advanced = self.round_robin_counter + 1
    self.round_robin_counter = advanced % len(self.workers)
|
||||
def round_robin_scheduler(self, req: Req):
    """Dispatch *req* to one of the data-parallel workers.

    When disaggregation is disabled (``disaggregation_mode == "null"``),
    workers receive requests in round-robin order. Otherwise the worker
    index is derived from the request's ``bootstrap_room`` modulo the
    worker count, so the choice is deterministic per bootstrap room.
    """
    num_workers = len(self.workers)
    if self.server_args.disaggregation_mode != "null":
        # Disaggregated prefill/decode: pick the worker from the bootstrap room.
        self.workers[req.bootstrap_room % num_workers].send_pyobj(req)
        return
    self.workers[self.round_robin_counter].send_pyobj(req)
    self.round_robin_counter = (self.round_robin_counter + 1) % num_workers
|
||||
|
||||
def shortest_queue_scheduler(self, input_requests):
    """Placeholder for shortest-queue dispatch; not yet implemented."""
    # Raising the exception class is equivalent to raising a no-arg instance.
    raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user