diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py
index 8e9487be6..bcb5dc7b9 100644
--- a/python/sglang/srt/disaggregation/base/conn.py
+++ b/python/sglang/srt/disaggregation/base/conn.py
@@ -27,6 +27,8 @@ class KVArgs:
     decode_tp_size: int
     # for pp prefill
     prefill_pp_size: int
+    kv_head_num: int
+    page_size: int
 
 
 class KVPoll:
diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py
index faccd9d3d..1d6219089 100644
--- a/python/sglang/srt/disaggregation/mooncake/conn.py
+++ b/python/sglang/srt/disaggregation/mooncake/conn.py
@@ -103,6 +103,9 @@ class KVArgsRegisterInfo:
     mooncake_session_id: str
     dst_kv_ptrs: list[int]
     dst_aux_ptrs: list[int]
+    dst_tp_rank: int
+    dst_tp_size: int
+    dst_kv_item_len: int
 
     @classmethod
     def from_zmq(cls, msg: List[bytes]):
@@ -113,6 +116,9 @@ class KVArgsRegisterInfo:
             mooncake_session_id=msg[3].decode("ascii"),
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
             dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
+            dst_tp_rank=int(msg[6].decode("ascii")),
+            dst_tp_size=int(msg[7].decode("ascii")),
+            dst_kv_item_len=int(msg[8].decode("ascii")),
         )
 
 
@@ -189,6 +195,8 @@ class MooncakeKVManager(BaseKVManager):
         self.session_pool_lock = threading.Lock()
         self.addr_to_rooms_tracker = defaultdict(set)
         self.connection_lock = threading.Lock()
+        self.required_prefill_info_num_map: Dict[int, int] = {}
+        self.decode_kv_arrive_state: Dict[int, Set[int]] = defaultdict(set)
         # Heartbeat interval should be at least 2 seconds
         self.heartbeat_interval = max(
             float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
@@ -284,6 +292,163 @@ class MooncakeKVManager(BaseKVManager):
 
         return 0
 
+    def send_kvcache_slice(
+        self,
+        mooncake_session_id: str,
+        prefill_kv_indices: npt.NDArray[np.int64],
+        dst_kv_ptrs: list[int],
+        dst_kv_indices: npt.NDArray[np.int64],
+        dst_tp_rank: int,
+        dst_tp_size: int,
+        dst_kv_item_len: int,
+        executor: concurrent.futures.ThreadPoolExecutor,
+    ):
+        """
+        Sends KV cache slices from this Prefill rank to a target Decode rank,
+        supporting generic M-to-N TP size configurations.
+
+        NOTE: This implementation calls the transfer engine for each token slot within
+        each page to ensure correctness for any page_size and head-slicing configuration.
+        This may introduce performance overhead (increased TTFT) for long sequences.
+ """ + # rank/kv_head config + local_tp_rank = self.kv_args.engine_rank + local_tp_size = self.tp_size // self.dp_size + num_kv_heads = self.kv_args.kv_head_num + num_layers = len(self.kv_args.kv_data_ptrs) + page_size = self.kv_args.page_size + + heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size + heads_per_prefill_rank = num_kv_heads + decode_global_head_start = dst_tp_rank * heads_per_decode_rank + prefill_global_head_start = local_tp_rank * heads_per_prefill_rank + bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size + + # decode config + decode_rank_item_lens = [dst_kv_item_len for _ in range(num_layers)] + + if local_tp_size > dst_tp_size: + src_head_offset = 0 + num_heads_to_send = heads_per_prefill_rank + dst_head_offset = prefill_global_head_start - decode_global_head_start + else: + src_head_offset = decode_global_head_start - prefill_global_head_start + num_heads_to_send = heads_per_decode_rank + dst_head_offset = 0 + + layer_transfer_params = [] + for layer_id in range(num_layers): + item_len_of_prefill_rank_page = self.kv_args.kv_item_lens[layer_id] + + # Page stride on the target Decode rank for its slice pages + item_len_of_decode_rank_page = decode_rank_item_lens[layer_id] + + if item_len_of_prefill_rank_page == 0 or num_kv_heads == 0: + logger.error( + f"Invalid item_len_of_prefill_rank_page or num_kv_heads for layer {layer_id}" + ) + return -1 + + # Calculate precise byte offset and length for the sub-slice within Prefill page data + src_slice_offset = src_head_offset * bytes_per_head + dst_slice_offset = dst_head_offset * bytes_per_head + slice_lens_per_page = num_heads_to_send * bytes_per_head + + # Sanity check: The data sub-slice we intend to send should fit into D_n's page. + # This means slice_lens_per_page <= item_len_of_decode_rank_page + if slice_lens_per_page > item_len_of_decode_rank_page: + logger.error( + f"[{mooncake_session_id}] Layer {layer_id}: " + f"slice size ({slice_lens_per_page}) exceeds " + f"target page size ({item_len_of_decode_rank_page})" + ) + return -1 + layer_transfer_params.append( + ( + self.kv_args.kv_data_ptrs[layer_id], # Prefill base ptr (all heads) + dst_kv_ptrs[ + layer_id + ], # Decode base ptr (for its slice for this layer) + item_len_of_prefill_rank_page, # Prefill page size (all heads)2048 + item_len_of_decode_rank_page, # Decode page stride (for its slice page) 1024 + src_slice_offset, # Offset to slice data in Prefill page + dst_slice_offset, # Offset to slice data in Decode page + slice_lens_per_page, # Length of slice data per page (actual data to send) + ) + ) + + def process_layer_tp_aware(layer_params): + ( + src_ptr, + dst_ptr, + src_item_len, + dst_item_len, + src_offset, + dst_offset, + slice_lens_per_page, + ) = layer_params + src_addr_list = [] + dst_addr_list = [] + length_list = [] + + # Calculate strides for a single token slot + bytes_per_token_on_prefill = src_item_len // page_size + bytes_per_token_on_decode = dst_item_len // page_size + + for i in range(len(prefill_kv_indices)): + prefill_page_idx = int(prefill_kv_indices[i]) + decode_page_idx = int(dst_kv_indices[i]) + + # Get the starting memory address for the current source and destination pages + src_page_start_addr = src_ptr + prefill_page_idx * src_item_len + dst_page_start_addr = dst_ptr + decode_page_idx * dst_item_len + + # Iterate through each valid token slot within the current page + for token_slot_in_page in range(page_size): + # Calculate start address of the current token slot + src_token_slot_start_addr = ( + 
+                        src_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_prefill
+                    )
+                    dst_token_slot_start_addr = (
+                        dst_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_decode
+                    )
+
+                    # Calculate final source and destination addresses by applying head-slice offsets
+                    src_slice_addr = src_token_slot_start_addr + src_offset
+                    dst_slice_addr = dst_token_slot_start_addr + dst_offset
+
+                    src_addr_list.append(src_slice_addr)
+                    dst_addr_list.append(dst_slice_addr)
+                    length_list.append(slice_lens_per_page)
+
+                    logger.debug(
+                        f"SYNC: sid={mooncake_session_id}, "
+                        f"src={src_slice_addr}, dst={dst_slice_addr}, len={slice_lens_per_page}"
+                    )
+
+            return self.engine.batch_transfer_sync(
+                mooncake_session_id, src_addr_list, dst_addr_list, length_list
+            )
+
+        futures = [
+            executor.submit(
+                process_layer_tp_aware,
+                layer_params,
+            )
+            for layer_params in layer_transfer_params
+        ]
+
+        for future in concurrent.futures.as_completed(futures):
+            status = future.result()
+            if status != 0:
+                for f in futures:
+                    f.cancel()
+                return status
+
+        return 0
+
     def send_aux(
         self,
         mooncake_session_id: str,
@@ -308,7 +473,7 @@ class MooncakeKVManager(BaseKVManager):
         )
 
     def sync_status_to_decode_endpoint(
-        self, remote: str, dst_port: int, room: int, status: int
+        self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
     ):
         if ":" in remote:
             remote = remote.split(":")[0]
@@ -316,6 +481,7 @@ class MooncakeKVManager(BaseKVManager):
             [
                 str(room).encode("ascii"),
                 str(status).encode("ascii"),
+                str(prefill_rank).encode("ascii"),
             ]
         )
 
@@ -332,6 +498,7 @@ class MooncakeKVManager(BaseKVManager):
                 )
                 polls = []
                 dst_ranks_infos = []
+                local_rank = self.kv_args.engine_rank
                 for req in reqs_to_be_processed:
                     if not req.is_dummy:
                         # Early exit if the request has failed
@@ -347,6 +514,7 @@ class MooncakeKVManager(BaseKVManager):
                                     req.dst_port,
                                     req.room,
                                     KVPoll.Failed,
+                                    local_rank,
                                 )
                                 break
 
@@ -364,15 +532,31 @@ class MooncakeKVManager(BaseKVManager):
                            f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
                         )
 
-                        ret = self.send_kvcache(
-                            req.mooncake_session_id,
-                            kv_chunk.prefill_kv_indices,
-                            self.decode_kv_args_table[
-                                req.mooncake_session_id
-                            ].dst_kv_ptrs,
-                            chunked_dst_kv_indice,
-                            executor,
+                        target_rank_registration_info: KVArgsRegisterInfo = (
+                            self.decode_kv_args_table[req.mooncake_session_id]
                         )
+                        local_tp_size = self.tp_size // self.dp_size
+                        if self.is_mla_backend or (
+                            local_tp_size == target_rank_registration_info.dst_tp_size
+                        ):
+                            ret = self.send_kvcache(
+                                req.mooncake_session_id,
+                                kv_chunk.prefill_kv_indices,
+                                target_rank_registration_info.dst_kv_ptrs,
+                                chunked_dst_kv_indice,
+                                executor,
+                            )
+                        else:
+                            ret = self.send_kvcache_slice(
+                                req.mooncake_session_id,
+                                kv_chunk.prefill_kv_indices,
+                                target_rank_registration_info.dst_kv_ptrs,
+                                chunked_dst_kv_indice,
+                                target_rank_registration_info.dst_tp_rank,
+                                target_rank_registration_info.dst_tp_size,
+                                target_rank_registration_info.dst_kv_item_len,
+                                executor,
+                            )
                         if ret != 0:
                             with self.session_lock:
                                 self.session_failures[req.mooncake_session_id] += 1
@@ -388,7 +572,11 @@ class MooncakeKVManager(BaseKVManager):
                             )
                             self.update_status(kv_chunk.room, KVPoll.Failed)
                             self.sync_status_to_decode_endpoint(
-                                req.endpoint, req.dst_port, req.room, KVPoll.Failed
+                                req.endpoint,
+                                req.dst_port,
+                                req.room,
+                                KVPoll.Failed,
+                                local_rank,
                             )
                             break
 
@@ -413,7 +601,7 @@ class MooncakeKVManager(BaseKVManager):
                            self.update_status(req.room, status)
                            for endpoint, dst_port, room in dst_ranks_infos:
                                self.sync_status_to_decode_endpoint(
-                                    endpoint, dst_port, room, status
+                                    endpoint, dst_port, room, status, local_rank
                                )
                    else:
                        # Dummy request means the decode instance is not used, so its status can be marked as success directly
@@ -479,15 +667,33 @@ class MooncakeKVManager(BaseKVManager):
 
         def decode_thread():
             while True:
-                (bootstrap_room, status) = self.server_socket.recv_multipart()
+                (bootstrap_room, status, prefill_rank) = (
+                    self.server_socket.recv_multipart()
+                )
                 status = int(status.decode("ascii"))
                 bootstrap_room = int(bootstrap_room.decode("ascii"))
-                if status == KVPoll.Failed:
+                prefill_rank = int(prefill_rank.decode("ascii"))
+
+                if status == KVPoll.Success:
+                    # record arrived prefill_rank
+                    self.decode_kv_arrive_state[bootstrap_room].add(prefill_rank)
+                    expected_prefill_num = self.required_prefill_info_num_map[
+                        bootstrap_room
+                    ]
+                    arrived_prefill_num = len(
+                        self.decode_kv_arrive_state[bootstrap_room]
+                    )
+                    if (
+                        self.is_mla_backend
+                        or arrived_prefill_num == expected_prefill_num
+                    ):
+                        self.update_status(bootstrap_room, KVPoll.Success)
+                elif status == KVPoll.Failed:
                     self.record_failure(
                         bootstrap_room,
                         f"Failed to get kvcache from prefill instance, it might be dead",
                     )
-                self.update_status(bootstrap_room, status)
+                    self.update_status(bootstrap_room, status)
 
         def heartbeat_checker():
             while True:
@@ -713,7 +919,10 @@ class MooncakeKVSender(BaseKVSender):
 
         if not is_last:
             self.kv_mgr.add_transfer_request(
-                self.bootstrap_room, kv_indices, index_slice, False
+                self.bootstrap_room,
+                kv_indices,
+                index_slice,
+                False,
             )
         else:
             self.kv_mgr.add_transfer_request(
@@ -822,23 +1031,26 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
             )
             self.required_dst_info_num = 1
+            self.required_prefill_info_num = 1
             self.target_tp_ranks = [self.target_tp_rank]
         elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
-            assert (
-                self.kv_mgr.is_mla_backend
-            ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
+            if not self.kv_mgr.is_mla_backend:
+                logger.warning_once(
+                    "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
+                )
             self.target_tp_rank = (
                 self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
             ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
             self.required_dst_info_num = (
                 local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
             )
+            self.required_prefill_info_num = 1
             self.target_tp_ranks = [self.target_tp_rank]
         else:
-            assert (
-                self.kv_mgr.is_mla_backend
-            ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
-
+            if not self.kv_mgr.is_mla_backend:
+                logger.warning_once(
+                    "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
+                )
" + ) # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models; self.target_tp_ranks = [ rank @@ -855,6 +1067,9 @@ class MooncakeKVReceiver(BaseKVReceiver): # or the KVPoll will never be set correctly self.target_tp_rank = self.target_tp_ranks[0] self.required_dst_info_num = 1 + self.required_prefill_info_num = ( + prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank + ) if self.data_parallel_rank is not None: logger.debug(f"Targeting DP rank: {self.data_parallel_rank}") @@ -862,6 +1077,9 @@ class MooncakeKVReceiver(BaseKVReceiver): else: self.target_dp_group = bootstrap_room % self.prefill_dp_size + self.kv_mgr.required_prefill_info_num_map[self.bootstrap_room] = ( + self.required_prefill_info_num + ) # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank bootstrap_key = ( f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}" @@ -875,11 +1093,15 @@ class MooncakeKVReceiver(BaseKVReceiver): self.target_dp_group, ) if bootstrap_info is not None: - # NOTE: only support MLA for now: select one prefill rank as real rank - bootstrap_info["is_dummy"] = not bool( - target_tp_rank == self.target_tp_rank - or self.target_tp_rank is None - ) + if self.kv_mgr.is_mla_backend: + # MLA :select one prefill rank as real rank + bootstrap_info["is_dummy"] = not bool( + target_tp_rank == self.target_tp_rank + or self.target_tp_rank is None + ) + else: + # no-MLA:select all prefill ranks + bootstrap_info["is_dummy"] = False logger.debug( f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank}" ) @@ -951,6 +1173,12 @@ class MooncakeKVReceiver(BaseKVReceiver): packed_aux_data_ptrs = b"".join( struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs ) + tp_rank = self.kv_mgr.kv_args.engine_rank + tp_size = self.kv_mgr.tp_size // self.kv_mgr.dp_size + kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0] + dst_tp_rank = str(tp_rank).encode("ascii") + dst_tp_size = str(tp_size).encode("ascii") + dst_kv_item_len = str(kv_item_len).encode("ascii") sock, lock = self._connect("tcp://" + self.prefill_server_url) with lock: @@ -962,6 +1190,9 @@ class MooncakeKVReceiver(BaseKVReceiver): self.session_id.encode("ascii"), packed_kv_data_ptrs, packed_aux_data_ptrs, + dst_tp_rank, + dst_tp_size, + dst_kv_item_len, ] ) @@ -1009,6 +1240,8 @@ class MooncakeKVReceiver(BaseKVReceiver): def clear(self) -> None: if self.bootstrap_room in self.kv_mgr.request_status: self.kv_mgr.request_status.pop(self.bootstrap_room) + self.kv_mgr.required_prefill_info_num_map.pop(self.bootstrap_room) + self.kv_mgr.decode_kv_arrive_state.pop(self.bootstrap_room) def failure_exception(self): # Explicitly set the status to failure since this request has failed in another rank diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 7b78c3776..231b648cc 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -122,6 +122,9 @@ class PrefillBootstrapQueue: kv_args.kv_data_ptrs = kv_data_ptrs kv_args.kv_data_lens = kv_data_lens kv_args.kv_item_lens = kv_item_lens + if not self.is_mla_backend: + kv_args.kv_head_num = self.token_to_kv_pool.head_num + kv_args.page_size = self.token_to_kv_pool.page_size kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = ( self.metadata_buffers.get_buf_infos()