diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index b23cb2d68..e7502d0c4 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -128,12 +128,11 @@ class CommonKVReceiver(BaseKVReceiver): mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, - data_parallel_rank: Optional[int] = None, + prefill_dp_rank: Optional[int] = None, ): self.bootstrap_room = bootstrap_room self.bootstrap_addr = bootstrap_addr self.kv_mgr = mgr - self.data_parallel_rank = data_parallel_rank if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table: self.prefill_tp_size, self.prefill_dp_size = ( @@ -201,11 +200,14 @@ class CommonKVReceiver(BaseKVReceiver): self.target_tp_rank = self.target_tp_ranks[0] self.required_dst_info_num = 1 - if self.data_parallel_rank is not None: - logger.debug(f"Targeting DP rank: {self.data_parallel_rank}") - self.target_dp_group = self.data_parallel_rank + if prefill_dp_rank is not None: + logger.debug(f"Targeting DP rank: {prefill_dp_rank}") + self.prefill_dp_rank = prefill_dp_rank else: - self.target_dp_group = bootstrap_room % self.prefill_dp_size + self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size + + # FIXME: alias here: target_dp_group -> prefill_dp_rank + self.target_dp_group = self.prefill_dp_rank # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank bootstrap_key = ( diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 528719f28..b79c8ca87 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -250,7 +250,7 @@ class DecodePreallocQueue: mgr=self.kv_manager, bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}", bootstrap_room=req.bootstrap_room, - data_parallel_rank=req.data_parallel_rank, + prefill_dp_rank=req.data_parallel_rank, ) self.queue.append( diff --git a/python/sglang/srt/disaggregation/fake/conn.py b/python/sglang/srt/disaggregation/fake/conn.py index d25f47a38..120633824 100644 --- a/python/sglang/srt/disaggregation/fake/conn.py +++ b/python/sglang/srt/disaggregation/fake/conn.py @@ -62,7 +62,7 @@ class FakeKVReceiver(BaseKVReceiver): mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, - data_parallel_rank: Optional[int] = None, + prefill_dp_rank: Optional[int] = None, ): self.has_init = False diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index c744e110d..0ad7280f9 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -1212,7 +1212,7 @@ class MooncakeKVReceiver(BaseKVReceiver): mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, - data_parallel_rank: Optional[int] = None, + prefill_dp_rank: Optional[int] = None, ): self.bootstrap_room = bootstrap_room self.bootstrap_addr = bootstrap_addr @@ -1221,7 +1221,6 @@ class MooncakeKVReceiver(BaseKVReceiver): self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping) self.conclude_state = None self.init_time = None - self.data_parallel_rank = data_parallel_rank if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table: ( @@ -1320,11 +1319,14 @@ class MooncakeKVReceiver(BaseKVReceiver): self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size ) * (self.prefill_pp_size // self.kv_mgr.pp_size) - if self.data_parallel_rank is not None: - logger.debug(f"Targeting DP rank: {self.data_parallel_rank}") - self.target_dp_group = self.data_parallel_rank + if prefill_dp_rank is not None: + logger.debug(f"Targeting DP rank: {prefill_dp_rank}") + self.prefill_dp_rank = prefill_dp_rank else: - self.target_dp_group = bootstrap_room % self.prefill_dp_size + self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size + + # FIXME: alias here: target_dp_group -> prefill_dp_rank + self.target_dp_group = self.prefill_dp_rank self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = ( self.required_prefill_response_num diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 7a75d79b7..1b427ee61 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -454,11 +454,11 @@ class NixlKVReceiver(CommonKVReceiver): mgr: NixlKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, - data_parallel_rank: Optional[int] = None, + prefill_dp_rank: Optional[int] = None, ): self.started_transfer = False self.conclude_state = None - super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank) + super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank) def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None): for bootstrap_info in self.bootstrap_infos: diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 677712a57..a7bb6d13a 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -106,7 +106,7 @@ class DataParallelController: # Launch data parallel workers self.scheduler_procs = [] - self.workers = [None] * server_args.dp_size + self.workers: List[zmq.Socket] = [None] * server_args.dp_size if server_args.enable_dp_attention: dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args) @@ -272,27 +272,34 @@ class DataParallelController: self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"] self.max_req_input_len = scheduler_info[0]["max_req_input_len"] + def maybe_external_dp_rank_routing(self, req: Req): + if req.data_parallel_rank is not None: + logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}") + self.workers[req.data_parallel_rank].send_pyobj(req) + return True + return False + def round_robin_scheduler(self, req: Req): + if self.maybe_external_dp_rank_routing(req): + return + if self.server_args.disaggregation_mode == "null": - if req.data_parallel_rank is not None: - logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}") - self.workers[req.data_parallel_rank].send_pyobj(req) - else: - self.workers[self.round_robin_counter].send_pyobj(req) - self.round_robin_counter = (self.round_robin_counter + 1) % len( - self.workers - ) + self.workers[self.round_robin_counter].send_pyobj(req) + self.round_robin_counter = (self.round_robin_counter + 1) % len( + self.workers + ) else: - if req.data_parallel_rank is not None: - logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}") - self.workers[req.data_parallel_rank].send_pyobj(req) - else: - self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req) + self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req) def shortest_queue_scheduler(self, input_requests): + if self.maybe_external_dp_rank_routing(req): + return raise NotImplementedError() def minimum_tokens_scheduler(self, req): + if self.maybe_external_dp_rank_routing(req): + return + # This variable corresponds to the balance_id in TokenizedGenerateReqInput. # We use it to to control the number of onfly tokens (requests dispatched to workers but not yet received). def get_next_global_balance_id() -> int: diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py index 8274003ad..4ab2e6a6f 100644 --- a/python/sglang/srt/managers/multi_tokenizer_mixin.py +++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py @@ -450,9 +450,7 @@ class MultiTokenizerManager(TokenizerManager): server_args: ServerArgs, port_args: PortArgs, ): - setproctitle.setproctitle( - f"sglang::http_server/multi_tokenizer_manager:{os.getpid()}" - ) + setproctitle.setproctitle(f"sglang::tokenizer_worker:{os.getpid()}") # prevent init prefill bootstrapserver again disaggregation_mode = server_args.disaggregation_mode server_args.disaggregation_mode = "null" diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index efe690750..22d344cc6 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -44,6 +44,7 @@ from sglang.srt.utils import ( is_valid_ipv6_address, nullable_str, ) +from sglang.utils import is_in_ci logger = logging.getLogger(__name__) @@ -223,6 +224,8 @@ class ServerArgs: # Data parallelism dp_size: int = 1 load_balance_method: str = "round_robin" + # FIXME: remove this after dp rank scheduling is fully supported with PD-Disaggregation + prefill_round_robin_balance: bool = False # Multi-node distributed serving dist_init_addr: Optional[str] = None @@ -623,12 +626,12 @@ class ServerArgs: if self.grammar_backend is None: self.grammar_backend = "xgrammar" + if self.dp_size == 1: + self.enable_dp_attention = False + # Data parallelism attention if self.enable_dp_attention: self.schedule_conservativeness = self.schedule_conservativeness * 0.3 - assert ( - self.dp_size > 1 - ), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size " assert self.tp_size % self.dp_size == 0 self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size logger.warning( @@ -807,6 +810,13 @@ class ServerArgs: self.disable_radix_cache = True logger.warning("KV cache is forced as chunk cache for decode server") + + if self.dp_size > 1 and not is_in_ci(): + assert self.prefill_round_robin_balance, ( + "Prefill round robin balance is required when dp size > 1. " + "Please make sure that the prefill instance is launched with `--load-balance-method round_robin`" + " and `--prefill-round-robin-balance` is set for decode server." + ) elif self.disaggregation_mode == "prefill": if self.disaggregation_decode_tp is None: self.disaggregation_decode_tp = self.tp_size @@ -1384,6 +1394,12 @@ class ServerArgs: "minimum_tokens", ], ) + parser.add_argument( + "--prefill-round-robin-balance", + default=ServerArgs.prefill_round_robin_balance, + action="store_true", + help="Prefill is round robin balanced. This is used to promise decode server can get the correct dp rank.", + ) # Multi-node distributed serving parser.add_argument(