diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index e6436952d..9e8fa476f 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -109,6 +109,13 @@ class MooncakeKVManager(BaseKVManager): # for p/d multi node infer self.bootstrap_port = server_args.disaggregation_bootstrap_port self.dist_init_addr = server_args.dist_init_addr + self.tp_size = server_args.tp_size + self.dp_size = server_args.dp_size + self.enable_dp_attention = server_args.enable_dp_attention + if not server_args.enable_dp_attention and server_args.dp_size != 1: + raise ValueError( + "If dp_attention is not enabled, dp size must be 1 in disaggregation mode." + ) self.request_status: Dict[int, KVPoll] = {} self.rank_port = None self.server_socket = zmq.Context().socket(zmq.PULL) @@ -121,6 +128,7 @@ class MooncakeKVManager(BaseKVManager): elif self.disaggregation_mode == DisaggregationMode.DECODE: self.start_decode_thread() self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {} + self.prefill_dp_size_table: Dict[str, int] = {} else: raise ValueError( f"Unsupported DisaggregationMode: {self.disaggregation_mode}" @@ -331,6 +339,8 @@ class MooncakeKVManager(BaseKVManager): url = f"http://{bootstrap_server_url}/route" payload = { "role": "Prefill", + "tp_size": self.tp_size, + "dp_size": self.dp_size, "rank_ip": get_local_ip_by_remote(), "rank_port": self.rank_port, "engine_rank": self.kv_args.engine_rank, @@ -408,12 +418,41 @@ class MooncakeKVReceiver(BaseKVReceiver): self.session_id = self.kv_mgr.get_session_id() self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping) + if not self.kv_mgr.enable_dp_attention: + # We assume dp_attention should be activated simultaneously for + # both prefill role and decode role. If the decode instance does + # not enable dp_attention, then dp_attention is not enabled on the + # prefill instance as well. Therefore, we should skip questioning + # the prefill dp size to reduce bootstrap overhead. + self.prefill_dp_size = 1 + elif self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table: + self.prefill_dp_size, tp_size_per_dp_rank = ( + self._get_prefill_dp_size_from_server() + ) + # Currently, we don't allow prefill instance and decode instance to + # have different TP sizes per DP rank. + assert tp_size_per_dp_rank == self.kv_mgr.tp_size // self.kv_mgr.dp_size + if self.prefill_dp_size is None: + logger.error( + f"Could not fetch prefill dp_size for bootstrap_addr: {self.bootstrap_addr}" + ) + else: + self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = ( + self.prefill_dp_size + ) + else: + self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[ + self.bootstrap_addr + ] + # NOTE: key distinguished by bootstrap_addr and engine_rank + self.target_dp_group = bootstrap_room % self.prefill_dp_size bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}" if bootstrap_key not in self.kv_mgr.connection_pool: self.bootstrap_info = self._get_bootstrap_info_from_server( - self.kv_mgr.kv_args.engine_rank + self.kv_mgr.kv_args.engine_rank, + self.target_dp_group, ) if self.bootstrap_info is None: logger.error( @@ -427,10 +466,10 @@ class MooncakeKVReceiver(BaseKVReceiver): assert self.bootstrap_info is not None self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput) - def _get_bootstrap_info_from_server(self, engine_rank): + def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group): """Fetch the bootstrap info from the bootstrap server.""" try: - url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}" + url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}" response = requests.get(url) if response.status_code == 200: bootstrap_info = response.json() @@ -444,6 +483,25 @@ class MooncakeKVReceiver(BaseKVReceiver): logger.error(f"Error fetching prefill info from bootstrap: {e}") return None + def _get_prefill_dp_size_from_server(self) -> int: + """Fetch the prefill parallel info from the bootstrap server.""" + try: + url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}" + response = requests.get(url) + if response.status_code == 200: + prefill_parallel_info = response.json() + return int(prefill_parallel_info["prefill_dp_size"]), int( + prefill_parallel_info["tp_size_per_dp_rank"] + ) + else: + logger.error( + f"Failed to get prefill parallel info: {response.status_code}, {response.text}" + ) + return None + except Exception as e: + logger.error(f"Error fetching prefill parallel info from bootstrap: {e}") + return None + @classmethod def _connect(cls, endpoint: str): with cls._global_lock: @@ -497,7 +555,9 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): self.store = dict() self.lock = asyncio.Lock() self._setup_routes() - self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {} + self.dp_size = None + self.tp_size_per_dp_rank = None + self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {} # Start bootstrap server self.thread = threading.Thread(target=self._run_server, daemon=True) @@ -523,35 +583,64 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): async def _handle_route_put(self, request: web.Request): data = await request.json() role = data["role"] + tp_size = data["tp_size"] + dp_size = data["dp_size"] rank_ip = data["rank_ip"] rank_port = int(data["rank_port"]) engine_rank = int(data["engine_rank"]) + if self.dp_size is None: + self.dp_size = dp_size + + tp_size_per_dp_rank = tp_size // dp_size + if self.tp_size_per_dp_rank == None: + self.tp_size_per_dp_rank = tp_size_per_dp_rank + # Add lock to make sure thread-safe if role == "Prefill": - self.prefill_port_table[engine_rank] = { + dp_group = engine_rank // tp_size_per_dp_rank + tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank + + async with self.lock: + if dp_group not in self.prefill_port_table: + self.prefill_port_table[dp_group] = {} + + self.prefill_port_table[dp_group][tp_rank_in_dp_group] = { "rank_ip": rank_ip, "rank_port": rank_port, } logger.debug( - f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}" + f"Registered Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}" ) return web.Response(text="OK", status=200) async def _handle_route_get(self, request: web.Request): engine_rank = request.query.get("engine_rank") - if not engine_rank: - return web.Response(text="Missing rank", status=400) + target_dp_group = request.query.get("target_dp_group") + if not engine_rank or not target_dp_group: + return web.Response(text="Missing inputs for bootstrap server.", status=400) + + # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size + if int(engine_rank) == -1 and int(target_dp_group) == -1: + prefill_parallel_info = { + "prefill_dp_size": self.dp_size, + "tp_size_per_dp_rank": self.tp_size_per_dp_rank, + } + return web.json_response(prefill_parallel_info, status=200) # Find corresponding prefill info + tp_rank_in_dp_group = int(engine_rank) % self.tp_size_per_dp_rank + async with self.lock: - bootstrap_info = self.prefill_port_table.get(int(engine_rank)) + bootstrap_info = self.prefill_port_table[int(target_dp_group)][ + tp_rank_in_dp_group + ] if bootstrap_info is not None: return web.json_response(bootstrap_info, status=200) else: - return web.Response(text="Not Found", status=404) + return web.Response(text="Bootstrap info not Found", status=404) def _run_server(self): try: