From d98a4913eae3a38a879bdcdc8d9a3fe6c28b85c5 Mon Sep 17 00:00:00 2001
From: Shangming Cai
Date: Tue, 5 Aug 2025 11:18:11 +0800
Subject: [PATCH] [PD] Refactor parallel sizes and add pp support for mooncake
 (#8571)

Signed-off-by: Shangming Cai
---
 python/sglang/srt/disaggregation/base/conn.py |   7 +-
 python/sglang/srt/disaggregation/decode.py    |   7 +-
 .../srt/disaggregation/mooncake/conn.py       | 380 +++++++++++-------
 python/sglang/srt/disaggregation/prefill.py   |   2 +
 4 files changed, 257 insertions(+), 139 deletions(-)

diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py
index bcb5dc7b9..d37575dcf 100644
--- a/python/sglang/srt/disaggregation/base/conn.py
+++ b/python/sglang/srt/disaggregation/base/conn.py
@@ -25,10 +25,13 @@ class KVArgs:
     gpu_id: int
     # for different tp
     decode_tp_size: int
-    # for pp prefill
-    prefill_pp_size: int
     kv_head_num: int
     page_size: int
+    # for pp prefill
+    prefill_pp_size: int
+    pp_rank: int
+    # for system dp
+    system_dp_rank: int
 
 
 class KVPoll:
diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py
index febb827fa..09d0b1310 100644
--- a/python/sglang/srt/disaggregation/decode.py
+++ b/python/sglang/srt/disaggregation/decode.py
@@ -44,6 +44,7 @@ from sglang.srt.disaggregation.utils import (
     poll_and_all_reduce,
     prepare_abort,
 )
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
@@ -184,9 +185,13 @@ class DecodePreallocQueue:
         kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
         kv_args = kv_args_class()
 
-        attn_tp_size = self.tp_size // self.dp_size
+        attn_tp_size = get_attention_tp_size()
         kv_args.engine_rank = self.tp_rank % (attn_tp_size)
+        kv_args.decode_tp_size = attn_tp_size
+        # Note(shangming): pp is not supported on the decode side yet, so its rank is fixed to 0
+        kv_args.pp_rank = 0
+        kv_args.system_dp_rank = self.scheduler.dp_rank
         kv_args.prefill_pp_size = self.prefill_pp_size
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
             self.token_to_kv_pool.get_contiguous_buf_infos()
diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py
index d366b2791..25188c6a8 100644
--- a/python/sglang/srt/disaggregation/mooncake/conn.py
+++ b/python/sglang/srt/disaggregation/mooncake/conn.py
@@ -34,6 +34,12 @@ from sglang.srt.disaggregation.common.utils import (
 )
 from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
 from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.layers.dp_attention import (
+    get_attention_dp_rank,
+    get_attention_dp_size,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     format_tcp_address,
@@ -113,7 +119,7 @@ class KVArgsRegisterInfo:
     dst_kv_ptrs: list[int]
     dst_aux_ptrs: list[int]
     dst_tp_rank: int
-    dst_tp_size: int
+    dst_attn_tp_size: int
     dst_kv_item_len: int
 
     @classmethod
@@ -126,7 +132,7 @@ class KVArgsRegisterInfo:
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
             dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
             dst_tp_rank=int(msg[6].decode("ascii")),
-            dst_tp_size=int(msg[7].decode("ascii")),
+            dst_attn_tp_size=int(msg[7].decode("ascii")),
             dst_kv_item_len=int(msg[8].decode("ascii")),
         )
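For orientation before the larger hunks below, here is the decode-side view of the fields this patch adds to KVArgs, written as an annotated dataclass sketch. KVArgs itself is a plain class whose attributes are filled in by the prefill/decode schedulers; the dataclass form and the example values are illustrative assumptions only.

from dataclasses import dataclass

@dataclass
class KVArgsSketch:
    engine_rank: int      # rank within the attention-TP group
    decode_tp_size: int   # attention-TP size of the decode instance
    kv_head_num: int
    page_size: int
    prefill_pp_size: int  # PP size of the prefill instance
    pp_rank: int          # PP rank of this worker; fixed to 0 on decode
    system_dp_rank: int   # system-level DP rank (used when dp attention is off)

# e.g. a decode worker behind attn_tp_size=4, talking to a pp_size=2 prefill:
args = KVArgsSketch(
    engine_rank=2,
    decode_tp_size=4,
    kv_head_num=8,
    page_size=32,
    prefill_pp_size=2,
    pp_rank=0,
    system_dp_rank=0,
)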
@@ -147,13 +153,18 @@ class MooncakeKVManager(BaseKVManager):
         # for p/d multi node infer
         self.bootstrap_port = server_args.disaggregation_bootstrap_port
         self.dist_init_addr = server_args.dist_init_addr
-        self.tp_size = server_args.tp_size
-        self.dp_size = server_args.dp_size
-        self.enable_dp_attention = server_args.enable_dp_attention
-        if not server_args.enable_dp_attention and server_args.dp_size != 1:
-            raise ValueError(
-                "If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
-            )
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+        self.attn_dp_size = get_attention_dp_size()
+        self.attn_dp_rank = get_attention_dp_rank()
+        self.system_dp_size = (
+            1 if server_args.enable_dp_attention else server_args.dp_size
+        )
+        self.system_dp_rank = (
+            self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0
+        )
+        self.pp_size = server_args.pp_size
+        self.pp_rank = self.kv_args.pp_rank
         self.request_status: Dict[int, KVPoll] = {}
         self.rank_port = None
         self.server_socket = zmq.Context().socket(zmq.PULL)
@@ -221,8 +232,9 @@ class MooncakeKVManager(BaseKVManager):
             )
             self.start_decode_thread()
             self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
-            self.prefill_tp_size_table: Dict[str, int] = {}
+            self.prefill_attn_tp_size_table: Dict[str, int] = {}
             self.prefill_dp_size_table: Dict[str, int] = {}
+            self.prefill_pp_size_table: Dict[str, int] = {}
             # If a timeout happens on the decode side, it means decode instances
             # fail to receive the KV Cache transfer done signal after bootstrapping.
             # These timeout requests should be aborted to release the tree cache.
@@ -296,15 +308,54 @@ class MooncakeKVManager(BaseKVManager):
             prefill_kv_indices, dst_kv_indices
         )
 
-        num_layers = len(self.kv_args.kv_data_ptrs)
-        layers_params = [
-            (
-                self.kv_args.kv_data_ptrs[layer_id],
-                dst_kv_ptrs[layer_id],
-                self.kv_args.kv_item_lens[layer_id],
-            )
-            for layer_id in range(num_layers)
-        ]
+        layers_params = None
+
+        # pp is not supported on the decode side yet
+        if self.is_mla_backend:
+            src_kv_ptrs = self.kv_args.kv_data_ptrs
+            layers_per_pp_stage = len(src_kv_ptrs)
+            start_layer = self.pp_rank * layers_per_pp_stage
+            end_layer = start_layer + layers_per_pp_stage
+            dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
+            kv_item_len = self.kv_args.kv_item_lens[0]
+            layers_params = [
+                (
+                    src_kv_ptrs[layer_id],
+                    dst_kv_ptrs[layer_id],
+                    kv_item_len,
+                )
+                for layer_id in range(layers_per_pp_stage)
+            ]
+        else:
+            num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+            dst_num_kv_layers = len(dst_kv_ptrs) // 2
+            src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+            src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+            layers_per_pp_stage = len(src_k_ptrs)
+            start_layer = self.pp_rank * layers_per_pp_stage
+            end_layer = start_layer + layers_per_pp_stage
+            dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
+            dst_v_ptrs = dst_kv_ptrs[
+                dst_num_kv_layers + start_layer : dst_num_kv_layers + end_layer
+            ]
+            kv_item_len = self.kv_args.kv_item_lens[0]
+
+            layers_params = [
+                (
+                    src_k_ptrs[layer_id],
+                    dst_k_ptrs[layer_id],
+                    kv_item_len,
+                )
+                for layer_id in range(layers_per_pp_stage)
+            ] + [
+                (
+                    src_v_ptrs[layer_id],
+                    dst_v_ptrs[layer_id],
+                    kv_item_len,
+                )
+                for layer_id in range(layers_per_pp_stage)
+            ]
+        assert layers_params is not None
 
         # Worker function for processing a single layer
        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
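The heart of the pp support is this pointer windowing: a prefill stage holds only its own layers (K pointers first, then V), while the pp=1 decode side registered pointers for every layer. Below is a self-contained sketch of the arithmetic, with invented sizes and strings standing in for device pointers; note the V-region base is derived from the decode side's total layer count, matching the corrected slice above.

# Illustrative sizes: 8 model layers, prefill pp_size=2; decode runs pp=1.
num_total_layers, pp_size, pp_rank = 8, 2, 1
layers_per_pp_stage = num_total_layers // pp_size      # 4 layers per stage

# Decode registered pointers for all layers: K0..K7 then V0..V7.
dst_kv_ptrs = [f"K{i}" for i in range(num_total_layers)] + [
    f"V{i}" for i in range(num_total_layers)
]
# This prefill stage only holds its own layers, K first then V.
own = range(pp_rank * layers_per_pp_stage, (pp_rank + 1) * layers_per_pp_stage)
src_k_ptrs = [f"K{i}" for i in own]
src_v_ptrs = [f"V{i}" for i in own]

# Same windowing as send_kvcache: the K region starts at 0, the V region
# at the decode side's total layer count.
dst_num_kv_layers = len(dst_kv_ptrs) // 2              # 8
start_layer = pp_rank * layers_per_pp_stage            # 4
end_layer = start_layer + layers_per_pp_stage          # 8
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[dst_num_kv_layers + start_layer : dst_num_kv_layers + end_layer]

assert dst_k_ptrs == src_k_ptrs and dst_v_ptrs == src_v_ptrs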
@@ -343,7 +393,7 @@ class MooncakeKVManager(BaseKVManager):
         dst_kv_ptrs: list[int],
         dst_kv_indices: npt.NDArray[np.int64],
         dst_tp_rank: int,
-        dst_tp_size: int,
+        dst_attn_tp_size: int,
         dst_kv_item_len: int,
         executor: concurrent.futures.ThreadPoolExecutor,
     ):
@@ -356,23 +406,22 @@ class MooncakeKVManager(BaseKVManager):
         This may introduce performance overhead (increased TTFT) for long sequences.
         """
         # Extract configuration
-        local_tp_size = self.tp_size // self.dp_size
-        local_tp_rank_in_group = self.kv_args.engine_rank % local_tp_size
+        local_tp_rank_in_group = self.kv_args.engine_rank % self.attn_tp_size
         src_kv_item_len = self.kv_args.kv_item_lens[0]
-        dst_tp_rank_in_group = dst_tp_rank % dst_tp_size
+        dst_tp_rank_in_group = dst_tp_rank % dst_attn_tp_size
         num_kv_heads = self.kv_args.kv_head_num
         num_layers = len(self.kv_args.kv_data_ptrs)
         page_size = self.kv_args.page_size
 
         # Calculate head distribution
         src_heads_per_rank = num_kv_heads
-        dst_heads_per_rank = num_kv_heads * local_tp_size // dst_tp_size
+        dst_heads_per_rank = num_kv_heads * self.attn_tp_size // dst_attn_tp_size
         bytes_per_head_slice_to_send = (
             dst_kv_item_len // page_size // dst_heads_per_rank
         )
 
         # Determine slicing parameters based on TP configuration
-        if local_tp_size > dst_tp_size:
+        if self.attn_tp_size > dst_attn_tp_size:
             # Send KVCache from multiple prefill instances to 1 decode instance
             src_head_start_offset = 0
             num_heads_to_send = src_heads_per_rank
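As a concreteness check on the head arithmetic above, here is a worked example with invented sizes (prefill attn_tp_size=2 feeding decode attn_tp_size=4):

# Invented sizes, mirroring the quantities computed above.
num_kv_heads = 4               # KV heads held by one prefill rank
attn_tp_size = 2               # prefill attention-TP size
dst_attn_tp_size = 4           # decode attention-TP size
page_size = 1
dst_kv_item_len = 2048         # bytes per token slot on one decode rank

src_heads_per_rank = num_kv_heads                                     # 4
dst_heads_per_rank = num_kv_heads * attn_tp_size // dst_attn_tp_size  # 2
bytes_per_head_slice_to_send = (
    dst_kv_item_len // page_size // dst_heads_per_rank                # 1024
)

# attn_tp_size < dst_attn_tp_size: one prefill rank feeds two decode ranks,
# so each transfer carries only the decode side's share of heads.
num_heads_to_send = dst_heads_per_rank
heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
# The patch's sanity check: the slice must fit one decode token slot.
assert heads_bytes_per_token_to_send <= dst_kv_item_len // page_size
print(heads_bytes_per_token_to_send)  # 2048, exactly one token slot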
@@ -383,35 +432,56 @@ class MooncakeKVManager(BaseKVManager):
             num_heads_to_send = dst_heads_per_rank
             dst_head_start_offset = 0
 
-        layers_params = []
-        for layer_id in range(num_layers):
-            # Calculate precise byte offset and length for the sub-slice within the token
-            src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
-            dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
-            heads_bytes_per_token_to_send = (
-                num_heads_to_send * bytes_per_head_slice_to_send
-            )
+        # pp is not supported on the decode side yet
+        num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+        dst_num_kv_layers = len(dst_kv_ptrs) // 2
+        src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+        src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+        layers_per_pp_stage = len(src_k_ptrs)
+        start_layer = self.pp_rank * layers_per_pp_stage
+        end_layer = start_layer + layers_per_pp_stage
+        dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
+        dst_v_ptrs = dst_kv_ptrs[
+            dst_num_kv_layers + start_layer : dst_num_kv_layers + end_layer
+        ]
 
-            # Sanity check: The data sub-slice to be sent should fit into the dst buffer.
-            # This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size)
-            if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size):
-                logger.error(
-                    f"[{mooncake_session_id}] Layer {layer_id}: "
-                    f"slice size ({heads_bytes_per_token_to_send}) exceeds "
-                    f"target token slot size ({dst_kv_item_len // page_size})"
-                )
-                return -1
-            layers_params.append(
-                (
-                    self.kv_args.kv_data_ptrs[layer_id],
-                    dst_kv_ptrs[layer_id],
-                    src_kv_item_len,
-                    dst_kv_item_len,
-                    src_head_slice_offset,
-                    dst_head_slice_offset,
-                    heads_bytes_per_token_to_send,
-                )
+        # Calculate precise byte offset and length for the sub-slice within the token
+        src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
+        dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
+        heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
+
+        # Sanity check: The data sub-slice to be sent should fit into the dst buffer.
+        # This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size)
+        if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size):
+            logger.error(
+                f"[{mooncake_session_id}] slice size ({heads_bytes_per_token_to_send}) exceeds "
+                f"target token slot size ({dst_kv_item_len // page_size})"
            )
+            return -1
+
+        layers_params = [
+            (
+                src_k_ptrs[layer_id],
+                dst_k_ptrs[layer_id],
+                src_kv_item_len,
+                dst_kv_item_len,
+                src_head_slice_offset,
+                dst_head_slice_offset,
+                heads_bytes_per_token_to_send,
+            )
+            for layer_id in range(layers_per_pp_stage)
+        ] + [
+            (
+                src_v_ptrs[layer_id],
+                dst_v_ptrs[layer_id],
+                src_kv_item_len,
+                dst_kv_item_len,
+                src_head_slice_offset,
+                dst_head_slice_offset,
+                heads_bytes_per_token_to_send,
+            )
+            for layer_id in range(layers_per_pp_stage)
+        ]
 
         def process_layer_tp_aware(layer_params):
             (
@@ -562,9 +631,9 @@ class MooncakeKVManager(BaseKVManager):
                 target_rank_registration_info: KVArgsRegisterInfo = (
                     self.decode_kv_args_table[req.mooncake_session_id]
                 )
-                local_tp_size = self.tp_size // self.dp_size
                 if self.is_mla_backend or (
-                    local_tp_size == target_rank_registration_info.dst_tp_size
+                    self.attn_tp_size
+                    == target_rank_registration_info.dst_attn_tp_size
                 ):
                     ret = self.send_kvcache(
                         req.mooncake_session_id,
@@ -580,7 +649,7 @@ class MooncakeKVManager(BaseKVManager):
                         target_rank_registration_info.dst_kv_ptrs,
                         chunked_dst_kv_indice,
                         target_rank_registration_info.dst_tp_rank,
-                        target_rank_registration_info.dst_tp_size,
+                        target_rank_registration_info.dst_attn_tp_size,
                         target_rank_registration_info.dst_kv_item_len,
                         executor,
                     )
@@ -863,11 +932,16 @@ class MooncakeKVManager(BaseKVManager):
         url = f"http://{bootstrap_server_url}/route"
         payload = {
             "role": "Prefill",
-            "tp_size": self.tp_size,
-            "dp_size": self.dp_size,
+            "attn_tp_size": self.attn_tp_size,
+            "attn_tp_rank": self.attn_tp_rank,
+            "attn_dp_size": self.attn_dp_size,
+            "attn_dp_rank": self.attn_dp_rank,
+            "pp_size": self.pp_size,
+            "pp_rank": self.pp_rank,
+            "system_dp_size": self.system_dp_size,
+            "system_dp_rank": self.system_dp_rank,
             "rank_ip": self.local_ip,
             "rank_port": self.rank_port,
-            "engine_rank": self.kv_args.engine_rank,
         }
 
         try:
@@ -890,10 +964,12 @@ class MooncakeKVManager(BaseKVManager):
             ]
             for k in keys_to_remove:
                 del self.connection_pool[k]
-            if failed_bootstrap_addr in self.prefill_tp_size_table:
-                del self.prefill_tp_size_table[failed_bootstrap_addr]
+            if failed_bootstrap_addr in self.prefill_attn_tp_size_table:
+                del self.prefill_attn_tp_size_table[failed_bootstrap_addr]
             if failed_bootstrap_addr in self.prefill_dp_size_table:
                 del self.prefill_dp_size_table[failed_bootstrap_addr]
+            if failed_bootstrap_addr in self.prefill_pp_size_table:
+                del self.prefill_pp_size_table[failed_bootstrap_addr]
 
             possible_affected_rooms = self.addr_to_rooms_tracker.get(
                 failed_bootstrap_addr, []
@@ -915,7 +991,7 @@ class MooncakeKVManager(BaseKVManager):
                 self.update_status(room, KVPoll.Failed)
                 affected_rooms.append(room)
             logger.error(
-                f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), affected {len(affected_rooms)} requests"
+                f"Lost connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), {len(affected_rooms)} requests affected"
             )
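Each prefill worker now registers its full parallel coordinates instead of a flat engine_rank. A sketch of the resulting PUT is below; the address and sizes are invented, and the actual request is issued inside the try block elided by the hunk (presumably requests.put(url, json=payload), which is the natural shape):

import requests

payload = {
    "role": "Prefill",
    "attn_tp_size": 4, "attn_tp_rank": 1,
    "attn_dp_size": 1, "attn_dp_rank": 0,
    "pp_size": 2, "pp_rank": 1,
    "system_dp_size": 1, "system_dp_rank": 0,
    "rank_ip": "10.0.1.6", "rank_port": 17001,
}
# The bootstrap server keys this entry by (dp_group, attn_tp_rank, pp_rank);
# engine_rank no longer travels in the payload.
requests.put("http://prefill-bootstrap:8998/route", json=payload, timeout=5)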
@@ -1042,10 +1118,16 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.data_parallel_rank = data_parallel_rank
 
         if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
-            self.prefill_tp_size, self.prefill_dp_size = (
-                self._get_prefill_parallel_info_from_server()
-            )
-            if self.prefill_tp_size is None or self.prefill_dp_size is None:
+            (
+                self.prefill_attn_tp_size,
+                self.prefill_dp_size,
+                self.prefill_pp_size,
+            ) = self._get_prefill_parallel_info_from_server()
+            if (
+                self.prefill_attn_tp_size is None
+                or self.prefill_dp_size is None
+                or self.prefill_pp_size is None
+            ):
                 self.kv_mgr.record_failure(
                     self.bootstrap_room,
                     f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
@@ -1054,43 +1136,47 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 return
             else:
                 logger.debug(
-                    f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_tp_size}"
+                    f"Fetched prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size}, PP size:{self.prefill_pp_size}"
                 )
-                self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
-                    self.prefill_tp_size
+                self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = (
+                    self.prefill_attn_tp_size
                 )
                 self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
                     self.prefill_dp_size
                 )
+                self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = (
+                    self.prefill_pp_size
+                )
         else:
-            self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
+            self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
                 self.bootstrap_addr
             ]
             self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
                 self.bootstrap_addr
             ]
+            self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[
+                self.bootstrap_addr
+            ]
 
         # Currently, we don't allow prefill instance and decode instance to
         # have different TP sizes per DP rank, except for models using MLA.
-        local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
-        prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
-        if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
+        if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
             self.target_tp_rank = (
-                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
+                self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
             )
             self.required_dst_info_num = 1
             self.required_prefill_response_num = 1
             self.target_tp_ranks = [self.target_tp_rank]
-        elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
+        elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
             if not self.kv_mgr.is_mla_backend:
                 logger.warning_once(
                     "Performance is NOT guaranteed when using different TP sizes for non-MLA models."
                 )
             self.target_tp_rank = (
-                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
-            ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
+                self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
+            ) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
             self.required_dst_info_num = (
-                local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
+                self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
             )
             self.required_prefill_response_num = 1
             self.target_tp_ranks = [self.target_tp_rank]
@@ -1103,10 +1189,10 @@ class MooncakeKVReceiver(BaseKVReceiver):
             self.target_tp_ranks = [
                 rank
                 for rank in range(
-                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
-                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
-                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
-                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
+                    (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
+                    * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
+                    (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
+                    * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
                 )
             ]
@@ -1116,7 +1202,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
             self.target_tp_rank = self.target_tp_ranks[0]
             self.required_dst_info_num = 1
             self.required_prefill_response_num = (
-                prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank
+                self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
             )
 
         if self.data_parallel_rank is not None:
@@ -1136,31 +1222,31 @@ class MooncakeKVReceiver(BaseKVReceiver):
         if bootstrap_key not in self.kv_mgr.connection_pool:
             bootstrap_infos = []
             for target_tp_rank in self.target_tp_ranks:
-                bootstrap_info = self._get_bootstrap_info_from_server(
-                    target_tp_rank,
-                    self.target_dp_group,
-                )
-                if bootstrap_info is not None:
-                    if self.kv_mgr.is_mla_backend:
-                        # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
-                        bootstrap_info["is_dummy"] = not bool(
-                            target_tp_rank == self.target_tp_rank
-                            or self.target_tp_rank is None
+                for target_pp_rank in range(self.prefill_pp_size):
+                    bootstrap_info = self._get_bootstrap_info_from_server(
+                        target_tp_rank, self.target_dp_group, target_pp_rank
+                    )
+                    if bootstrap_info is not None:
+                        if self.kv_mgr.is_mla_backend:
+                            # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
+                            bootstrap_info["is_dummy"] = not bool(
+                                target_tp_rank == self.target_tp_rank
+                                or self.target_tp_rank is None
+                            )
+                        else:
+                            # For non-MLA: all target_tp_ranks are selected real ranks
+                            bootstrap_info["is_dummy"] = False
+                        logger.debug(
+                            f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}"
                         )
+                        bootstrap_infos.append(bootstrap_info)
                     else:
-                        # For non-MLA: all target_tp_ranks are selected real ranks
-                        bootstrap_info["is_dummy"] = False
-                    logger.debug(
-                        f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank}"
-                    )
-                    bootstrap_infos.append(bootstrap_info)
-                else:
-                    self.kv_mgr.record_failure(
-                        self.bootstrap_room,
-                        f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}",
-                    )
-                    self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
-                    return
+                        self.kv_mgr.record_failure(
+                            self.bootstrap_room,
+                            f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank: {target_pp_rank}",
+                        )
+                        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
+                        return
 
             self.bootstrap_infos = bootstrap_infos
             self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
@@ -1174,10 +1260,12 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)
         self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)
 
-    def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
+    def _get_bootstrap_info_from_server(
+        self, engine_rank, target_dp_group, target_pp_rank
+    ):
         """Fetch the bootstrap info from the bootstrap server."""
         try:
-            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
+            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}&target_pp_rank={target_pp_rank}"
             response = requests.get(url, timeout=5)
             if response.status_code == 200:
                 bootstrap_info = response.json()
@@ -1191,24 +1279,28 @@ class MooncakeKVReceiver(BaseKVReceiver):
             logger.error(f"Error fetching prefill info from bootstrap: {e}")
             return None
 
-    def _get_prefill_parallel_info_from_server(self) -> Tuple[int, int]:
+    def _get_prefill_parallel_info_from_server(
+        self,
+    ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
         """Fetch the prefill parallel info from the bootstrap server."""
         try:
-            url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
+            url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
             response = requests.get(url)
             if response.status_code == 200:
                 prefill_parallel_info = response.json()
-                return int(prefill_parallel_info["prefill_tp_size"]), int(
-                    prefill_parallel_info["prefill_dp_size"]
+                return (
+                    int(prefill_parallel_info["prefill_attn_tp_size"]),
+                    int(prefill_parallel_info["prefill_dp_size"]),
+                    int(prefill_parallel_info["prefill_pp_size"]),
                 )
             else:
                 logger.error(
                     f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
                 )
-                return None, None
+                return None, None, None
         except Exception as e:
             logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
-            return None, None
+            return None, None, None
 
     def _register_kv_args(self):
         for bootstrap_info in self.bootstrap_infos:
@@ -1218,11 +1310,11 @@ class MooncakeKVReceiver(BaseKVReceiver):
             packed_aux_data_ptrs = b"".join(
                 struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
             )
+            # Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet
             tp_rank = self.kv_mgr.kv_args.engine_rank
-            tp_size = self.kv_mgr.tp_size // self.kv_mgr.dp_size
             kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
             dst_tp_rank = str(tp_rank).encode("ascii")
-            dst_tp_size = str(tp_size).encode("ascii")
+            dst_attn_tp_size = str(self.kv_mgr.attn_tp_size).encode("ascii")
             dst_kv_item_len = str(kv_item_len).encode("ascii")
 
             sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
@@ -1236,7 +1328,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
                         packed_kv_data_ptrs,
                         packed_aux_data_ptrs,
                         dst_tp_rank,
-                        dst_tp_size,
+                        dst_attn_tp_size,
                         dst_kv_item_len,
                     ]
                 )
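For reference, the two query shapes the receiver now issues against the bootstrap server, sketched with an invented address and values (the field names are the ones used in the hunks above):

import requests

base = "http://prefill-bootstrap:8998/route"  # invented address

# Sentinel query (-1, -1, -1): ask the server for the prefill parallel sizes.
sizes = requests.get(
    f"{base}?engine_rank=-1&target_dp_group=-1&target_pp_rank=-1", timeout=5
).json()
# e.g. {"prefill_attn_tp_size": 4, "prefill_dp_size": 1, "prefill_pp_size": 2}

# Normal query: resolve one (tp, dp, pp) coordinate to a prefill worker.
info = requests.get(
    f"{base}?engine_rank=2&target_dp_group=0&target_pp_rank=1", timeout=5
).json()
# e.g. {"rank_ip": "10.0.1.6", "rank_port": 17001}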
@@ -1347,10 +1439,12 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         self.store = dict()
         self.lock = asyncio.Lock()
         self._setup_routes()
-        self.tp_size = None
+        self.pp_size = None
+        self.attn_tp_size = None
         self.dp_size = None
-        self.tp_size_per_dp_rank = None
-        self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
+        self.prefill_port_table: Dict[
+            int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]
+        ] = {}
 
         # Start bootstrap server
         self.thread = threading.Thread(target=self._run_server, daemon=True)
@@ -1380,37 +1474,45 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
     async def _handle_route_put(self, request: web.Request):
         data = await request.json()
         role = data["role"]
-        tp_size = data["tp_size"]
-        dp_size = data["dp_size"]
+        attn_tp_size = data["attn_tp_size"]
+        attn_tp_rank = data["attn_tp_rank"]
+        attn_dp_size = data["attn_dp_size"]
+        attn_dp_rank = data["attn_dp_rank"]
+        pp_size = data["pp_size"]
+        pp_rank = data["pp_rank"]
+        system_dp_size = data["system_dp_size"]
+        system_dp_rank = data["system_dp_rank"]
         rank_ip = data["rank_ip"]
         rank_port = int(data["rank_port"])
-        engine_rank = int(data["engine_rank"])
 
-        if self.tp_size is None:
-            self.tp_size = tp_size
+        if self.attn_tp_size is None:
+            self.attn_tp_size = attn_tp_size
 
         if self.dp_size is None:
-            self.dp_size = dp_size
+            self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size
 
-        tp_size_per_dp_rank = tp_size // dp_size
-        if self.tp_size_per_dp_rank is None:
-            self.tp_size_per_dp_rank = tp_size_per_dp_rank
+        if self.pp_size is None:
+            self.pp_size = pp_size
 
         if role == "Prefill":
-            dp_group = engine_rank // tp_size_per_dp_rank
-            tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank
+            if system_dp_size == 1:
+                dp_group = attn_dp_rank
+            else:
+                dp_group = system_dp_rank
 
             # Add lock to make sure thread-safe
             async with self.lock:
                 if dp_group not in self.prefill_port_table:
                     self.prefill_port_table[dp_group] = {}
+                if attn_tp_rank not in self.prefill_port_table[dp_group]:
+                    self.prefill_port_table[dp_group][attn_tp_rank] = {}
 
-            self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
+            self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = {
                 "rank_ip": rank_ip,
                 "rank_port": rank_port,
             }
             logger.debug(
-                f"Register prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
+                f"Register prefill bootstrap: DP {dp_group} TP {attn_tp_rank} PP {pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
             )
 
         return web.Response(text="OK", status=200)
@@ -1418,14 +1520,20 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
     async def _handle_route_get(self, request: web.Request):
         engine_rank = request.query.get("engine_rank")
         target_dp_group = request.query.get("target_dp_group")
-        if not engine_rank or not target_dp_group:
+        target_pp_rank = request.query.get("target_pp_rank")
+        if not engine_rank or not target_dp_group or not target_pp_rank:
             return web.Response(text="Missing inputs for bootstrap server.", status=400)
 
         # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
-        if int(engine_rank) == -1 and int(target_dp_group) == -1:
+        if (
+            int(engine_rank) == -1
+            and int(target_dp_group) == -1
+            and int(target_pp_rank) == -1
+        ):
             prefill_parallel_info = {
-                "prefill_tp_size": self.tp_size,
+                "prefill_attn_tp_size": self.attn_tp_size,
                 "prefill_dp_size": self.dp_size,
+                "prefill_pp_size": self.pp_size,
             }
             return web.json_response(prefill_parallel_info, status=200)
@@ -1433,7 +1541,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         async with self.lock:
             bootstrap_info = self.prefill_port_table[int(target_dp_group)][
                 int(engine_rank)
-            ]
+            ][int(target_pp_rank)]
 
         if bootstrap_info is not None:
             return web.json_response(bootstrap_info, status=200)
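The registry the PUT handler builds is now three levels deep. A sketch of its shape with invented entries, and the lookup the GET handler performs:

# dp_group -> attn_tp_rank -> pp_rank -> {"rank_ip", "rank_port"} (values invented)
prefill_port_table = {
    0: {
        0: {
            0: {"rank_ip": "10.0.0.5", "rank_port": 17000},
            1: {"rank_ip": "10.0.1.5", "rank_port": 17000},
        },
        1: {
            0: {"rank_ip": "10.0.0.6", "rank_port": 17001},
            1: {"rank_ip": "10.0.1.6", "rank_port": 17001},
        },
    },
}

# The GET handler indexes all three coordinates:
bootstrap_info = prefill_port_table[0][1][1]
assert bootstrap_info == {"rank_ip": "10.0.1.6", "rank_port": 17001}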
diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py
index c15c1eff0..72cf9d3f9 100644
--- a/python/sglang/srt/disaggregation/prefill.py
+++ b/python/sglang/srt/disaggregation/prefill.py
@@ -103,6 +103,8 @@ class PrefillBootstrapQueue:
         kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
         kv_args = kv_args_class()
         kv_args.engine_rank = self.tp_rank
+        kv_args.pp_rank = self.pp_rank
+        kv_args.system_dp_rank = self.scheduler.dp_rank
         kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
         kv_args.prefill_pp_size = self.pp_size
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
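Taken together, the decode receiver now pairs with every pp stage of each target prefill TP rank; a closing sketch with invented sizes:

# Invented sizes: equal attention TP on both sides, prefill pp_size=2.
target_tp_ranks = [2]      # equal-TP case: a single prefill TP peer
prefill_pp_size = 2

# One bootstrap fetch (and later one KV transfer) per (tp, pp) pair:
pairs = [(tp, pp) for tp in target_tp_ranks for pp in range(prefill_pp_size)]
print(pairs)  # [(2, 0), (2, 1)] -> each stage delivers the layers it owns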