diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index 4d66c18af..e6a6ad445 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -109,10 +109,12 @@ class CommonKVReceiver(BaseKVReceiver): mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, + data_parallel_rank: Optional[int] = None, ): self.bootstrap_room = bootstrap_room self.bootstrap_addr = bootstrap_addr self.kv_mgr = mgr + self.data_parallel_rank = data_parallel_rank if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table: self.prefill_tp_size, self.prefill_dp_size = ( @@ -180,7 +182,11 @@ class CommonKVReceiver(BaseKVReceiver): self.target_tp_rank = self.target_tp_ranks[0] self.required_dst_info_num = 1 - self.target_dp_group = bootstrap_room % self.prefill_dp_size + if self.data_parallel_rank is not None: + logger.debug(f"Targeting DP rank: {self.data_parallel_rank}") + self.target_dp_group = self.data_parallel_rank + else: + self.target_dp_group = bootstrap_room % self.prefill_dp_size # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank bootstrap_key = ( diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 7982f7b63..e206450b6 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -156,6 +156,7 @@ class DecodePreallocQueue: mgr=self.kv_manager, bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}", bootstrap_room=req.bootstrap_room, + data_parallel_rank=req.data_parallel_rank, ) self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver)) diff --git a/python/sglang/srt/disaggregation/fake/conn.py b/python/sglang/srt/disaggregation/fake/conn.py index 1e650753e..d080c8e2e 100644 --- a/python/sglang/srt/disaggregation/fake/conn.py +++ b/python/sglang/srt/disaggregation/fake/conn.py @@ -56,6 +56,7 @@ class 
FakeKVReceiver(BaseKVReceiver): mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, + data_parallel_rank: Optional[int] = None, ): self.has_init = False diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index eb8ad44e2..5d55eb468 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -765,6 +765,7 @@ class MooncakeKVReceiver(BaseKVReceiver): mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, + data_parallel_rank: Optional[int] = None, ): self.bootstrap_room = bootstrap_room self.bootstrap_addr = bootstrap_addr @@ -772,6 +773,7 @@ class MooncakeKVReceiver(BaseKVReceiver): self.session_id = self.kv_mgr.get_session_id() self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping) self.conclude_state = None + self.data_parallel_rank = data_parallel_rank if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table: self.prefill_tp_size, self.prefill_dp_size = ( @@ -845,7 +847,11 @@ class MooncakeKVReceiver(BaseKVReceiver): self.target_tp_rank = self.target_tp_ranks[0] self.required_dst_info_num = 1 - self.target_dp_group = self.bootstrap_room % self.prefill_dp_size + if self.data_parallel_rank is not None: + logger.debug(f"Targeting DP rank: {self.data_parallel_rank}") + self.target_dp_group = self.data_parallel_rank + else: + self.target_dp_group = self.bootstrap_room % self.prefill_dp_size # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank bootstrap_key = ( diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 3ed021a6b..f9a0e931c 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -407,9 +407,10 @@ class NixlKVReceiver(CommonKVReceiver): mgr: NixlKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, + 
data_parallel_rank: Optional[int] = None, ): self.started_transfer = False - super().__init__(mgr, bootstrap_addr, bootstrap_room) + super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank) def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None): for bootstrap_info in self.bootstrap_infos: diff --git a/python/sglang/srt/entrypoints/EngineBase.py b/python/sglang/srt/entrypoints/EngineBase.py index c7dfafd41..9ac68faa7 100644 --- a/python/sglang/srt/entrypoints/EngineBase.py +++ b/python/sglang/srt/entrypoints/EngineBase.py @@ -23,6 +23,12 @@ class EngineBase(ABC): token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None, lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None, custom_logit_processor: Optional[Union[List[str], str]] = None, + return_hidden_states: Optional[bool] = None, + stream: Optional[bool] = None, + bootstrap_host: Optional[Union[List[str], str]] = None, + bootstrap_port: Optional[Union[List[int], int]] = None, + bootstrap_room: Optional[Union[List[int], int]] = None, + data_parallel_rank: Optional[int] = None, ) -> Union[Dict, Iterator[Dict]]: """Generate outputs based on given inputs.""" pass diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index c9ae3a3e0..75bccc9dc 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -167,11 +167,22 @@ class Engine(EngineBase): bootstrap_host: Optional[Union[List[str], str]] = None, bootstrap_port: Optional[Union[List[int], int]] = None, bootstrap_room: Optional[Union[List[int], int]] = None, + data_parallel_rank: Optional[int] = None, ) -> Union[Dict, Iterator[Dict]]: """ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. Please refer to `GenerateReqInput` for the documentation. 
""" + if self.server_args.enable_dp_attention: + if data_parallel_rank is None: + logger.info("data_parallel_rank not provided, using default dispatch") + elif data_parallel_rank < 0: + raise ValueError("data_parallel_rank must be non-negative") + elif data_parallel_rank >= self.server_args.dp_size: + raise ValueError( + f"data_parallel_rank must be less than dp_size: {self.server_args.dp_size}" + ) + obj = GenerateReqInput( text=prompt, input_ids=input_ids, @@ -188,6 +199,7 @@ class Engine(EngineBase): bootstrap_host=bootstrap_host, bootstrap_port=bootstrap_port, bootstrap_room=bootstrap_room, + data_parallel_rank=data_parallel_rank, ) loop = asyncio.get_event_loop() generator = self.tokenizer_manager.generate_request(obj, None) @@ -237,11 +249,24 @@ class Engine(EngineBase): bootstrap_host: Optional[Union[List[str], str]] = None, bootstrap_port: Optional[Union[List[int], int]] = None, bootstrap_room: Optional[Union[List[int], int]] = None, + data_parallel_rank: Optional[int] = None, ) -> Union[Dict, AsyncIterator[Dict]]: """ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. Please refer to `GenerateReqInput` for the documentation. 
""" + + if self.server_args.enable_dp_attention: + if data_parallel_rank is None: + logger.info("data_parallel_rank not provided, using default dispatch") + elif data_parallel_rank < 0: + raise ValueError("data_parallel_rank must be non-negative") + elif data_parallel_rank >= self.server_args.dp_size: + raise ValueError( + f"data_parallel_rank must be in range [0, {self.server_args.dp_size-1}]" + ) + + logger.info(f"data_parallel_rank: {data_parallel_rank}") obj = GenerateReqInput( text=prompt, input_ids=input_ids, @@ -257,6 +282,7 @@ class Engine(EngineBase): bootstrap_host=bootstrap_host, bootstrap_port=bootstrap_port, bootstrap_room=bootstrap_room, + data_parallel_rank=data_parallel_rank, ) generator = self.tokenizer_manager.generate_request(obj, None) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 876472312..62c3800c2 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -248,12 +248,20 @@ class DataParallelController: def round_robin_scheduler(self, req: Req): if self.server_args.disaggregation_mode == "null": - self.workers[self.round_robin_counter].send_pyobj(req) - self.round_robin_counter = (self.round_robin_counter + 1) % len( - self.workers - ) + if req.data_parallel_rank is not None: + logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}") + self.workers[req.data_parallel_rank].send_pyobj(req) + else: + self.workers[self.round_robin_counter].send_pyobj(req) + self.round_robin_counter = (self.round_robin_counter + 1) % len( + self.workers + ) else: - self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req) + if req.data_parallel_rank is not None: + logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}") + self.workers[req.data_parallel_rank].send_pyobj(req) + else: + self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req) def 
shortest_queue_scheduler(self, input_requests): raise NotImplementedError() diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 2e73bf2a4..40c220c4e 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -106,6 +106,9 @@ class GenerateReqInput: bootstrap_port: Optional[Union[List[Optional[int]], int]] = None bootstrap_room: Optional[Union[List[int], int]] = None + # For data parallel rank routing + data_parallel_rank: Optional[int] = None + def contains_mm_input(self) -> bool: return has_valid_data(self.image_data) or has_valid_data(self.audio_data) @@ -417,6 +420,9 @@ class GenerateReqInput: bootstrap_room=( self.bootstrap_room[i] if self.bootstrap_room is not None else None ), + data_parallel_rank=( + self.data_parallel_rank + ), ) @@ -464,6 +470,9 @@ class TokenizedGenerateReqInput: bootstrap_port: Optional[int] = None bootstrap_room: Optional[int] = None + # For data parallel rank routing + data_parallel_rank: Optional[int] = None + @dataclass class EmbeddingReqInput: diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index c5a4ff31a..e61ac7aec 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -451,6 +451,7 @@ class Req: bootstrap_host: Optional[str] = None, bootstrap_port: Optional[int] = None, bootstrap_room: Optional[int] = None, + data_parallel_rank: Optional[int] = None, ): # Input and output info self.rid = rid @@ -605,6 +606,9 @@ class Req: self.bootstrap_room: Optional[int] = bootstrap_room self.disagg_kv_sender: Optional[BaseKVSender] = None + # For data parallel rank routing + self.data_parallel_rank: Optional[int] = data_parallel_rank + # the start index of the sent kv cache # We want to send it chunk by chunk for chunked prefill. 
# After every chunk forward, we do the following: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e5f40a78e..27223a6a4 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -949,6 +949,7 @@ class Scheduler( bootstrap_host=recv_req.bootstrap_host, bootstrap_port=recv_req.bootstrap_port, bootstrap_room=recv_req.bootstrap_room, + data_parallel_rank=recv_req.data_parallel_rank, ) req.tokenizer = self.tokenizer diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 5c6033c11..d71bbdf07 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -570,6 +570,7 @@ class TokenizerManager: session_params=session_params, custom_logit_processor=obj.custom_logit_processor, return_hidden_states=obj.return_hidden_states, + data_parallel_rank=obj.data_parallel_rank, ) elif isinstance(obj, EmbeddingReqInput): tokenized_obj = TokenizedEmbeddingReqInput(