diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py
index 42e0b2ae5..c91b4b813 100644
--- a/python/sglang/srt/disaggregation/base/conn.py
+++ b/python/sglang/srt/disaggregation/base/conn.py
@@ -37,6 +37,7 @@ class BaseKVManager(ABC):
         args: KVArgs,
         disaggregation_mode: DisaggregationMode,
         server_args: ServerArgs,
+        is_mla_backend: Optional[bool] = False,
     ): ...
diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py
index 4499afcfe..0aef85ba5 100644
--- a/python/sglang/srt/disaggregation/decode.py
+++ b/python/sglang/srt/disaggregation/decode.py
@@ -38,6 +38,7 @@ from sglang.srt.disaggregation.utils import (
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
+    is_mla_backend,
     kv_to_page_indices,
     poll_and_all_reduce,
 )
@@ -87,6 +88,7 @@ class DecodePreallocQueue:
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
+        self.is_mla_backend = is_mla_backend(self.token_to_kv_pool)
         self.aux_dtype = aux_dtype
         self.metadata_buffers = metadata_buffers
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
@@ -131,7 +133,10 @@ class DecodePreallocQueue:
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
         kv_manager = kv_manager_class(
-            kv_args, DisaggregationMode.DECODE, self.scheduler.server_args
+            kv_args,
+            DisaggregationMode.DECODE,
+            self.scheduler.server_args,
+            self.is_mla_backend,
         )
         return kv_manager
diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py
index 4cf1ad9f1..7226805bc 100644
--- a/python/sglang/srt/disaggregation/mooncake/conn.py
+++ b/python/sglang/srt/disaggregation/mooncake/conn.py
@@ -68,16 +68,28 @@ class TransferInfo:
     mooncake_session_id: str
     dst_kv_indices: npt.NDArray[np.int64]
     dst_aux_index: int
+    required_dst_info_num: int
+    is_dummy: bool

     @classmethod
     def from_zmq(cls, msg: List[bytes]):
+        if msg[4] == b"" and msg[5] == b"":
+            is_dummy = True
+            dst_kv_indices = np.array([], dtype=np.int64)
+            dst_aux_index = None
+        else:
+            dst_kv_indices = np.frombuffer(msg[4], dtype=np.int64)
+            dst_aux_index = int(msg[5].decode("ascii"))
+            is_dummy = False
         return cls(
             room=int(msg[0].decode("ascii")),
             endpoint=msg[1].decode("ascii"),
             dst_port=int(msg[2].decode("ascii")),
             mooncake_session_id=msg[3].decode("ascii"),
-            dst_kv_indices=np.frombuffer(msg[4], dtype=np.int64),
-            dst_aux_index=int(msg[5].decode("ascii")),
+            dst_kv_indices=dst_kv_indices,
+            dst_aux_index=dst_aux_index,
+            required_dst_info_num=int(msg[6].decode("ascii")),
+            is_dummy=is_dummy,
         )
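For reference, a minimal sketch of the seven-field handshake message that `from_zmq` now parses. All concrete values here are made up; the layout mirrors what `MooncakeKVReceiver.init` sends further down, with fields 4 and 5 left empty for dummy requests:

```python
import numpy as np

# Illustrative multipart message, in the order parsed by TransferInfo.from_zmq:
#   0: bootstrap room  1: endpoint IP  2: dst port  3: mooncake session id
#   4: dst kv indices (raw int64 bytes; b"" for a dummy request)
#   5: dst aux index (ascii int; b"" for a dummy request)
#   6: required_dst_info_num (ascii int)
real_msg = [
    b"42",
    b"10.0.0.2",
    b"17000",
    b"session-0",
    np.array([7, 8, 9], dtype=np.int64).tobytes(),
    b"3",
    b"1",
]
dummy_msg = real_msg[:4] + [b"", b"", b"1"]

for msg in (real_msg, dummy_msg):
    is_dummy = msg[4] == b"" and msg[5] == b""
    dst_kv_indices = (
        np.array([], dtype=np.int64)
        if is_dummy
        else np.frombuffer(msg[4], dtype=np.int64)
    )
    print(is_dummy, dst_kv_indices, int(msg[6].decode("ascii")))
```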
@@ -108,6 +120,7 @@ class MooncakeKVManager(BaseKVManager):
         args: KVArgs,
         disaggregation_mode: DisaggregationMode,
         server_args: ServerArgs,
+        is_mla_backend: Optional[bool] = False,
     ):
         self.kv_args = args
         self.engine = MooncakeTransferEngine(
@@ -115,6 +128,7 @@ class MooncakeKVManager(BaseKVManager):
             gpu_id=self.kv_args.gpu_id,
             ib_device=self.kv_args.ib_device,
         )
+        self.is_mla_backend = is_mla_backend
         self.disaggregation_mode = disaggregation_mode
         # for p/d multi node infer
         self.bootstrap_port = server_args.disaggregation_bootstrap_port
@@ -132,7 +146,7 @@ class MooncakeKVManager(BaseKVManager):
         self.register_buffer_to_engine()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self.transfer_queue = queue.Queue()
-            self.transfer_infos: Dict[int, TransferInfo] = {}
+            self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
             self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
             self.start_prefill_thread()
             self._register_to_bootstrap()
@@ -145,6 +159,7 @@ class MooncakeKVManager(BaseKVManager):
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.start_decode_thread()
             self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
+            self.prefill_tp_size_table: Dict[str, int] = {}
             self.prefill_dp_size_table: Dict[str, int] = {}
         else:
             raise ValueError(
@@ -218,7 +233,7 @@ class MooncakeKVManager(BaseKVManager):
                 status = future.result()
                 if status != 0:
                     # Immediate shutdown on first error (existing tasks will finish)
-                    executor.shutdown(wait=False)
+                    self.executor.shutdown(wait=False)
                     for f in futures:
                         f.cancel()
                     return status
@@ -250,7 +265,7 @@ class MooncakeKVManager(BaseKVManager):
         self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
             [
                 str(room).encode("ascii"),
-                str(self.request_status[room]).encode("ascii"),
+                str(self.check_status(room)).encode("ascii"),
             ]
         )
@@ -264,8 +279,8 @@ class MooncakeKVManager(BaseKVManager):
             while True:
                 waiting_req_bytes = self.server_socket.recv_multipart()
                 room = waiting_req_bytes[0].decode("ascii")
+                mooncake_session_id = waiting_req_bytes[3].decode("ascii")
                 if room == "None":
-                    mooncake_session_id = waiting_req_bytes[3].decode("ascii")
                     self.decode_kv_args_table[mooncake_session_id] = (
                         KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
                     )
@@ -273,53 +288,84 @@ class MooncakeKVManager(BaseKVManager):
                     logger.debug(
                         f"Register KVArgs from {mooncake_session_id} successfully"
                     )
                     continue
-                room = int(room)
-                self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
+                else:
+                    required_dst_info_num = int(waiting_req_bytes[6].decode("ascii"))
+                    room = int(room)
+                    if room not in self.transfer_infos:
+                        self.transfer_infos[room] = {}

-                # NOTE: after bootstrapping we can mark the req as waiting for input
-                self.request_status[room] = KVPoll.WaitingForInput
+                    self.transfer_infos[room][mooncake_session_id] = (
+                        TransferInfo.from_zmq(waiting_req_bytes)
+                    )
+                    # NOTE: after bootstrapping we can mark the req as waiting for input
+                    if len(self.transfer_infos[room]) == required_dst_info_num:
+                        self.update_status(room, KVPoll.WaitingForInput)

         def transfer_thread():
             # TODO: Shall we use KVPoll.Transferring state?
             while True:
                 try:
                     kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01)
-                    req = self.transfer_infos[kv_chunk.room]
-                    chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
-                    assert len(chunked_dst_kv_indice) == len(
-                        kv_chunk.prefill_kv_indices
-                    ), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
+                    reqs_to_be_processed = self.transfer_infos[kv_chunk.room].values()
+                    polls = []
+                    dst_ranks_infos = []
+                    for req in reqs_to_be_processed:
+                        if not req.is_dummy:
+                            chunked_dst_kv_indice = req.dst_kv_indices[
+                                kv_chunk.index_slice
+                            ]
+                            assert len(chunked_dst_kv_indice) == len(
+                                kv_chunk.prefill_kv_indices
+                            ), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"

-                    ret = self.send_kvcache(
-                        req.mooncake_session_id,
-                        kv_chunk.prefill_kv_indices,
-                        self.decode_kv_args_table[req.mooncake_session_id].dst_kv_ptrs,
-                        chunked_dst_kv_indice,
-                    )
-                    if ret != 0:
-                        self.request_status[kv_chunk.room] = KVPoll.Failed
-                        self.sync_status_to_decode_endpoint(
-                            req.endpoint, req.dst_port, req.room
-                        )
-                        continue
+                            ret = self.send_kvcache(
+                                req.mooncake_session_id,
+                                kv_chunk.prefill_kv_indices,
+                                self.decode_kv_args_table[
+                                    req.mooncake_session_id
+                                ].dst_kv_ptrs,
+                                chunked_dst_kv_indice,
+                            )
+                            if ret != 0:
+                                self.update_status(kv_chunk.room, KVPoll.Failed)
+                                self.sync_status_to_decode_endpoint(
+                                    req.endpoint, req.dst_port, req.room
+                                )
+                                continue

-                    if kv_chunk.is_last:
-                        # Only the last chunk we need to send the aux data
-                        ret = self.send_aux(
-                            req.mooncake_session_id,
-                            kv_chunk.prefill_aux_index,
-                            self.decode_kv_args_table[
-                                req.mooncake_session_id
-                            ].dst_aux_ptrs,
-                            req.dst_aux_index,
-                        )
-                        self.request_status[req.room] = (
-                            KVPoll.Success if ret == 0 else KVPoll.Failed
-                        )
-                        self.sync_status_to_decode_endpoint(
-                            req.endpoint, req.dst_port, req.room
-                        )
-                        self.transfer_infos.pop(req.room)
+                            if kv_chunk.is_last:
+                                # Only the last chunk we need to send the aux data
+                                ret = self.send_aux(
+                                    req.mooncake_session_id,
+                                    kv_chunk.prefill_aux_index,
+                                    self.decode_kv_args_table[
+                                        req.mooncake_session_id
+                                    ].dst_aux_ptrs,
+                                    req.dst_aux_index,
+                                )
+                                polls.append(True if ret == 0 else False)
+                                dst_ranks_infos.append(
+                                    (req.endpoint, req.dst_port, req.room)
+                                )

+                                # Only sync status when all the dst ranks have received the kvcache
+                                if len(polls) == req.required_dst_info_num:
+                                    self.update_status(
+                                        req.room,
+                                        KVPoll.Success if all(polls) else KVPoll.Failed,
+                                    )
+                                    for endpoint, dst_port, room in dst_ranks_infos:
+                                        self.sync_status_to_decode_endpoint(
+                                            endpoint, dst_port, room
+                                        )
+                        else:
+                            # Dummy request means the decode instance is not used, so its status can be marked as success directly
+                            # Dummy request does not need to sync status to decode endpoint
+                            if kv_chunk.is_last:
+                                self.update_status(req.room, KVPoll.Success)

+                    if self.check_status(kv_chunk.room) == KVPoll.Success:
+                        self.transfer_infos.pop(kv_chunk.room)
                 except queue.Empty:
                     continue
@@ -336,7 +382,7 @@ class MooncakeKVManager(BaseKVManager):
                 (bootstrap_room, status) = self.server_socket.recv_multipart()
                 status = int(status.decode("ascii"))
                 bootstrap_room = int(bootstrap_room.decode("ascii"))
-                self.request_status[bootstrap_room] = status
+                self.update_status(bootstrap_room, status)

         threading.Thread(target=decode_thread).start()
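The completion rule in `transfer_thread` above can be summarized in isolation. A minimal sketch (with a stand-in for the real `KVPoll` constants): a room only flips to Success once every registered destination rank has received the last chunk, and a single failure poisons the whole room:

```python
# Stand-in for the real KVPoll enum, just for this sketch.
class KVPoll:
    Failed = 0
    Success = 4

def resolve_room_status(send_results, required_dst_info_num):
    """send_results: last-chunk return codes, one per non-dummy dst rank (0 == ok)."""
    polls = [ret == 0 for ret in send_results]
    if len(polls) == required_dst_info_num:
        return KVPoll.Success if all(polls) else KVPoll.Failed
    return None  # still waiting on the remaining dst ranks

assert resolve_room_status([0, 0], 2) == KVPoll.Success  # all ranks ok
assert resolve_room_status([0, 1], 2) == KVPoll.Failed   # one rank failed
assert resolve_room_status([0], 2) is None               # not all reported yet
```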
@@ -360,11 +406,9 @@ class MooncakeKVManager(BaseKVManager):
                 prefill_aux_index=aux_index,
             )
         )
-
-        self.request_status[bootstrap_room] = KVPoll.WaitingForInput
+        self.update_status(bootstrap_room, KVPoll.WaitingForInput)

     def check_status(self, bootstrap_room: int):
-        # TODO: do we really need the poll()?
-
         return self.request_status[bootstrap_room]

     def update_status(self, bootstrap_room: int, status: KVPoll):
@@ -469,54 +513,111 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.session_id = self.kv_mgr.get_session_id()
         self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)

-        if not self.kv_mgr.enable_dp_attention:
-            # We assume dp_attention should be activated simultaneously for
-            # both prefill role and decode role. If the decode instance does
-            # not enable dp_attention, then dp_attention is not enabled on the
-            # prefill instance as well. Therefore, we should skip questioning
-            # the prefill dp size to reduce bootstrap overhead.
-            self.prefill_dp_size = 1
-        elif self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
-            self.prefill_dp_size, tp_size_per_dp_rank = (
+        if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
+            self.prefill_tp_size, self.prefill_dp_size = (
                 self._get_prefill_dp_size_from_server()
             )
-            # Currently, we don't allow prefill instance and decode instance to
-            # have different TP sizes per DP rank.
-            assert tp_size_per_dp_rank == self.kv_mgr.tp_size // self.kv_mgr.dp_size
-            if self.prefill_dp_size is None:
+            if self.prefill_tp_size is None or self.prefill_dp_size is None:
                 logger.error(
-                    f"Could not fetch prefill dp_size for bootstrap_addr: {self.bootstrap_addr}"
+                    f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
                 )
             else:
+                self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
+                    self.prefill_tp_size
+                )
                 self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
                     self.prefill_dp_size
                 )
         else:
+            self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
+                self.bootstrap_addr
+            ]
             self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
                 self.bootstrap_addr
             ]

-        # NOTE: key distinguished by bootstrap_addr and engine_rank
+        # Currently, we don't allow prefill instance and decode instance to
+        # have different TP sizes per DP rank, except for models using MLA.
+        local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
+        prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
+        if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
+            self.target_tp_rank = (
+                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
+            )
+            self.required_dst_info_num = 1
+            self.target_tp_ranks = [self.target_tp_rank]
+        elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
+            assert (
+                self.kv_mgr.is_mla_backend
+            ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
+            self.target_tp_rank = (
+                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
+            ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
+            self.required_dst_info_num = (
+                local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
+            )
+            self.target_tp_ranks = [self.target_tp_rank]
+        else:
+            assert (
+                self.kv_mgr.is_mla_backend
+            ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"

+            # For non-MLA models, one decode rank would need to retrieve KVCache from
+            # multiple prefill ranks, so enumerate every prefill rank this decode rank maps to
+            self.target_tp_ranks = [
+                rank
+                for rank in range(
+                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
+                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
+                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
+                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
+                )
+            ]

+            # For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
+            # multiple connections in the connection pool and have to send dummy requests to the other prefill ranks,
+            # or the KVPoll will never be set correctly
+            self.target_tp_rank = self.target_tp_ranks[0]
+            self.required_dst_info_num = 1

         self.target_dp_group = bootstrap_room % self.prefill_dp_size
-        bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
+
+        # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
+        bootstrap_key = (
+            f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
+        )

         if bootstrap_key not in self.kv_mgr.connection_pool:
-            self.bootstrap_info = self._get_bootstrap_info_from_server(
-                self.kv_mgr.kv_args.engine_rank,
-                self.target_dp_group,
-            )
-            if self.bootstrap_info is None:
+            bootstrap_infos = []
+            for target_tp_rank in self.target_tp_ranks:
+                bootstrap_info = self._get_bootstrap_info_from_server(
+                    target_tp_rank,
+                    self.target_dp_group,
+                )
+                if bootstrap_info is not None:
+                    # NOTE: only MLA is supported for now: select one prefill rank as the real rank
+                    bootstrap_info["is_dummy"] = not bool(
+                        target_tp_rank == self.target_tp_rank
+                        or self.target_tp_rank is None
+                    )
+                    bootstrap_infos.append(bootstrap_info)
+                else:
+                    logger.error(
+                        f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
+                    )
+            self.bootstrap_infos = bootstrap_infos

+            if len(self.bootstrap_infos) == 0:
                 logger.error(
                     f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
                 )
             else:
-                self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
+                self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
                 # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
                 self._register_kv_args()
         else:
-            self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
+            self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]

-        assert self.bootstrap_info is not None
+        assert len(self.bootstrap_infos) > 0
         self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)

     def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
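To make the rank mapping concrete, here is a self-contained sketch of the three branches above, stripped of the surrounding class. The helper name is ours, not part of the patch; the assertions use the topology exercised by the new test (prefill `--tp 4`, decode `--tp 2`, DP 1 on both sides) plus the reverse case:

```python
def map_decode_rank_to_prefill(engine_rank, local_tp, local_dp, prefill_tp, prefill_dp):
    """Mirrors the branch logic above; returns (target_tp_ranks, required_dst_info_num)."""
    local_per_dp = local_tp // local_dp
    prefill_per_dp = prefill_tp // prefill_dp
    local_rank = engine_rank % local_per_dp
    if local_per_dp == prefill_per_dp:
        return [local_rank], 1
    elif local_per_dp > prefill_per_dp:
        ratio = local_per_dp // prefill_per_dp
        return [local_rank // ratio], ratio
    else:
        ratio = prefill_per_dp // local_per_dp
        return list(range(local_rank * ratio, (local_rank + 1) * ratio)), 1

# Decode TP=2 pulling from prefill TP=4 (the MLA case): each decode rank maps
# to two prefill ranks, but only the first one receives a real request.
assert map_decode_rank_to_prefill(0, 2, 1, 4, 1) == ([0, 1], 1)
assert map_decode_rank_to_prefill(1, 2, 1, 4, 1) == ([2, 3], 1)
# Decode TP=4 pulling from prefill TP=2: two decode ranks share one prefill
# rank, so the prefill side waits for required_dst_info_num == 2 registrations.
assert map_decode_rank_to_prefill(3, 4, 1, 2, 1) == ([1], 2)
```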
@@ -543,8 +644,8 @@ class MooncakeKVReceiver(BaseKVReceiver):
             response = requests.get(url)
             if response.status_code == 200:
                 prefill_parallel_info = response.json()
-                return int(prefill_parallel_info["prefill_dp_size"]), int(
-                    prefill_parallel_info["tp_size_per_dp_rank"]
+                return int(prefill_parallel_info["prefill_tp_size"]), int(
+                    prefill_parallel_info["prefill_dp_size"]
                 )
             else:
                 logger.error(
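Note that `_get_prefill_dp_size_from_server` now returns `(prefill_tp_size, prefill_dp_size)` in that order, matching the reworked JSON payload served by the bootstrap server (shown further down). A small sketch of the new payload; the numbers are illustrative:

```python
# Illustrative JSON body returned for the engine_rank == -1 /
# target_dp_group == -1 parallel-info query (values are made up):
prefill_parallel_info = {"prefill_tp_size": 4, "prefill_dp_size": 1}

# The receiver unpacks it in (tp, dp) order:
prefill_tp_size = int(prefill_parallel_info["prefill_tp_size"])
prefill_dp_size = int(prefill_parallel_info["prefill_dp_size"])
assert (prefill_tp_size, prefill_dp_size) == (4, 1)
```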
@@ -556,28 +657,29 @@ class MooncakeKVReceiver(BaseKVReceiver):
         return None

     def _register_kv_args(self):
-        self.prefill_server_url = (
-            f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
-        )
-
-        packed_kv_data_ptrs = b"".join(
-            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
-        )
-        packed_aux_data_ptrs = b"".join(
-            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
-        )
-        sock, lock = self._connect("tcp://" + self.prefill_server_url)
-        with lock:
-            sock.send_multipart(
-                [
-                    "None".encode("ascii"),
-                    get_local_ip_by_remote().encode("ascii"),
-                    str(self.kv_mgr.rank_port).encode("ascii"),
-                    self.session_id.encode("ascii"),
-                    packed_kv_data_ptrs,
-                    packed_aux_data_ptrs,
-                ]
+        for bootstrap_info in self.bootstrap_infos:
+            self.prefill_server_url = (
+                f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
             )
+            packed_kv_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
+            )
+            packed_aux_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
+            )
+
+            sock, lock = self._connect("tcp://" + self.prefill_server_url)
+            with lock:
+                sock.send_multipart(
+                    [
+                        "None".encode("ascii"),
+                        get_local_ip_by_remote().encode("ascii"),
+                        str(self.kv_mgr.rank_port).encode("ascii"),
+                        self.session_id.encode("ascii"),
+                        packed_kv_data_ptrs,
+                        packed_aux_data_ptrs,
+                    ]
+                )

     @classmethod
     def _connect(cls, endpoint: str):
@@ -590,25 +692,28 @@ class MooncakeKVReceiver(BaseKVReceiver):
         return cls._socket_cache[endpoint], cls._socket_locks[endpoint]

     def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
-        self.prefill_server_url = (
-            f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
-        )
-        logger.debug(
-            f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
-        )
-
-        sock, lock = self._connect("tcp://" + self.prefill_server_url)
-        with lock:
-            sock.send_multipart(
-                [
-                    str(self.bootstrap_room).encode("ascii"),
-                    get_local_ip_by_remote().encode("ascii"),
-                    str(self.kv_mgr.rank_port).encode("ascii"),
-                    self.session_id.encode("ascii"),
-                    kv_indices.tobytes(),
-                    str(aux_index).encode("ascii"),
-                ]
+        for bootstrap_info in self.bootstrap_infos:
+            self.prefill_server_url = (
+                f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
             )
+            logger.debug(
+                f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
+            )
+            is_dummy = bootstrap_info["is_dummy"]
+
+            sock, lock = self._connect("tcp://" + self.prefill_server_url)
+            with lock:
+                sock.send_multipart(
+                    [
+                        str(self.bootstrap_room).encode("ascii"),
+                        get_local_ip_by_remote().encode("ascii"),
+                        str(self.kv_mgr.rank_port).encode("ascii"),
+                        self.session_id.encode("ascii"),
+                        kv_indices.tobytes() if not is_dummy else b"",
+                        str(aux_index).encode("ascii") if not is_dummy else b"",
+                        str(self.required_dst_info_num).encode("ascii"),
+                    ]
+                )

     def poll(self) -> KVPoll:
         return self.kv_mgr.check_status(self.bootstrap_room)
@@ -624,6 +729,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         self.store = dict()
         self.lock = asyncio.Lock()
         self._setup_routes()
+        self.tp_size = None
         self.dp_size = None
         self.tp_size_per_dp_rank = None
         self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
@@ -658,6 +764,9 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         rank_port = int(data["rank_port"])
         engine_rank = int(data["engine_rank"])

+        if self.tp_size is None:
+            self.tp_size = tp_size
+
         if self.dp_size is None:
             self.dp_size = dp_size

@@ -693,17 +802,15 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
         if int(engine_rank) == -1 and int(target_dp_group) == -1:
             prefill_parallel_info = {
+                "prefill_tp_size": self.tp_size,
                 "prefill_dp_size": self.dp_size,
-                "tp_size_per_dp_rank": self.tp_size_per_dp_rank,
             }
             return web.json_response(prefill_parallel_info, status=200)

         # Find corresponding prefill info
-        tp_rank_in_dp_group = int(engine_rank) % self.tp_size_per_dp_rank
-
         async with self.lock:
             bootstrap_info = self.prefill_port_table[int(target_dp_group)][
-                tp_rank_in_dp_group
+                int(engine_rank)
             ]

         if bootstrap_info is not None:
diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py
index 745e0c26f..78df3a5ad 100644
--- a/python/sglang/srt/disaggregation/nixl/conn.py
+++ b/python/sglang/srt/disaggregation/nixl/conn.py
@@ -132,6 +132,7 @@ class NixlKVManager(BaseKVManager):
         args: KVArgs,
         disaggregation_mode: DisaggregationMode,
         server_args: ServerArgs,
+        is_mla_backend: Optional[bool] = False,
     ):
         try:
             from nixl._api import nixl_agent
diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py
index abcc707df..6af1928ff 100644
--- a/python/sglang/srt/disaggregation/prefill.py
+++ b/python/sglang/srt/disaggregation/prefill.py
@@ -34,6 +34,7 @@ from sglang.srt.disaggregation.utils import (
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
+    is_mla_backend,
     kv_to_page_indices,
     kv_to_page_num,
     poll_and_all_reduce,
@@ -69,6 +70,7 @@ class PrefillBootstrapQueue:
         scheduler: Scheduler,
     ):
         self.token_to_kv_pool = token_to_kv_pool
+        self.is_mla_backend = is_mla_backend(token_to_kv_pool)
         self.aux_dtype = aux_dtype

         self.metadata_buffers = metadata_buffers
@@ -112,7 +114,10 @@ class PrefillBootstrapQueue:
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
         kv_manager = kv_manager_class(
-            kv_args, DisaggregationMode.PREFILL, self.scheduler.server_args
+            kv_args,
+            DisaggregationMode.PREFILL,
+            self.scheduler.server_args,
+            self.is_mla_backend,
         )
         return kv_manager
diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py
index 40d63d6a3..1eb7dfec9 100644
--- a/python/sglang/srt/disaggregation/utils.py
+++ b/python/sglang/srt/disaggregation/utils.py
@@ -162,3 +162,9 @@ def register_disaggregation_server(
         warnings.warn(
             f"Failed to register disaggregation server: {res.status_code} {res.text}"
         )
+
+
+def is_mla_backend(target_kv_pool) -> bool:
+    from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
+
+    return isinstance(target_kv_pool, MLATokenToKVPool)
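The new helper is deliberately just an `isinstance` check against the pool object both queues already hold, so no extra model-config plumbing is needed. A usage sketch (the pool variable is hypothetical):

```python
from sglang.srt.disaggregation.utils import is_mla_backend

# kv_pool would be whatever token_to_kv_pool_allocator.get_kvcache() returns;
# it is an MLATokenToKVPool only for MLA models such as DeepSeek-V3/R1.
kv_pool = ...  # hypothetical pool instance
if is_mla_backend(kv_pool):
    # Under MLA the KV cache is identical across TP ranks, so a decode rank
    # can fetch it from a single prefill rank even when TP sizes differ.
    pass
```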
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 15f9f091b..ccd5ba15c 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -103,6 +103,7 @@ suites = {
         # TestFile("test_moe_deepep_eval_accuracy_large.py", 250),
         TestFile("test_disaggregation.py", 210),
         TestFile("test_local_attn.py", 250),
+        TestFile("test_disaggregation_different_tp.py", 210),
         TestFile("test_full_deepseek_v3.py", 250),
         TestFile("test_pp_single_node.py", 150),
     ],
diff --git a/test/srt/test_disaggregation_different_tp.py b/test/srt/test_disaggregation_different_tp.py
new file mode 100644
index 000000000..116fdb175
--- /dev/null
+++ b/test/srt/test_disaggregation_different_tp.py
@@ -0,0 +1,151 @@
+import os
+import subprocess
+import time
+import unittest
+from types import SimpleNamespace
+
+import requests
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST_MLA,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_pd_server,
+    run_with_timeout,
+)
+
+
+class TestDisaggregationMooncakeDifferentTP(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        # Temporarily disable JIT DeepGEMM
+        cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM")
+        os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
+
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
+        cls.base_host = "127.0.0.1"
+        cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1])
+        cls.lb_url = DEFAULT_URL_FOR_TEST
+        cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}"
+        cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}"
+
+        run_with_timeout(cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
+        run_with_timeout(cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
+
+        cls.wait_server_ready(cls.prefill_url + "/health")
+        cls.wait_server_ready(cls.decode_url + "/health")
+
+        lb_command = [
+            "python3",
+            "-m",
+            "sglang.srt.disaggregation.mini_lb",
+            "--prefill",
+            cls.prefill_url,
+            "--decode",
+            cls.decode_url,
+            "--host",
+            cls.base_host,
+            "--port",
+            str(cls.base_port),
+        ]
+
+        print("Starting load balancer:", " ".join(lb_command))
+        cls.process_lb = subprocess.Popen(
+            lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+        )
+        cls.wait_server_ready(cls.lb_url + "/health")
+
+    @classmethod
+    def start_prefill(cls):
+        prefill_args = [
+            "--trust-remote-code",
+            "--disaggregation-mode",
+            "prefill",
+            "--host",
+            cls.base_host,
+            "--port",
+            str(cls.base_port + 100),
+            "--tp",
+            "4",
+        ]
+        cls.process_prefill = popen_launch_pd_server(
+            cls.model,
+            cls.prefill_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=prefill_args,
+        )
+
+    @classmethod
+    def start_decode(cls):
+        decode_args = [
+            "--trust-remote-code",
+            "--disaggregation-mode",
+            "decode",
+            "--host",
+            cls.base_host,
+            "--port",
+            str(cls.base_port + 200),
+            "--tp",
+            "2",
+            "--base-gpu-id",
+            "4",
+        ]
+        cls.process_decode = popen_launch_pd_server(
+            cls.model,
+            cls.decode_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=decode_args,
+        )
+
+    @classmethod
+    def wait_server_ready(cls, url, timeout=60):
+        start_time = time.time()
+        while True:
+            try:
+                response = requests.get(url)
+                if response.status_code == 200:
+                    print(f"Server {url} is ready")
+                    return
+            except Exception:
+                pass
+
+            if time.time() - start_time > timeout:
+                raise RuntimeError(f"Server {url} failed to start in {timeout}s")
+            time.sleep(1)
+
+    @classmethod
+    def tearDownClass(cls):
+        # Restore JIT DeepGEMM environment variable
+        if cls.original_jit_deepgemm is not None:
+            os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = cls.original_jit_deepgemm
+        else:
+            os.environ.pop("SGL_ENABLE_JIT_DEEPGEMM", None)
+
+        for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
+            if process:
+                try:
+                    kill_process_tree(process.pid)
+                except Exception as e:
+                    print(f"Error killing process {process.pid}: {e}")
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.lb_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(f"Evaluation metrics: {metrics}")
+
+        self.assertGreater(metrics["accuracy"], 0.60)
+
+
+if __name__ == "__main__":
+    unittest.main()
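The test can also be run standalone via `python3 test/srt/test_disaggregation_different_tp.py`. Note that the topology above assumes at least 6 GPUs: prefill occupies GPUs 0-3 with `--tp 4`, while decode takes GPUs 4-5 via `--tp 2 --base-gpu-id 4`, and the 0.60 GSM8K accuracy bar checks end-to-end correctness of the transferred KV cache rather than peak model quality.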