From 9ab72f9895c7ebd3b7cab06b76bc84e1976ff66b Mon Sep 17 00:00:00 2001 From: shaharmor98 <17088876+shaharmor98@users.noreply.github.com> Date: Tue, 9 Sep 2025 11:47:26 +0300 Subject: [PATCH] add variable TP Decode > Prefill size support (#9960) Signed-off-by: Shahar Mor --- .../sglang/srt/disaggregation/common/conn.py | 3 - .../srt/disaggregation/mooncake/conn.py | 4 +- python/sglang/srt/disaggregation/nixl/conn.py | 192 ++++++++++++++++-- 3 files changed, 181 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index e7502d0c4..10b6093b9 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -168,9 +168,6 @@ class CommonKVReceiver(BaseKVReceiver): self.required_dst_info_num = 1 self.target_tp_ranks = [self.target_tp_rank] elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank: - assert ( - self.kv_mgr.is_mla_backend - ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models" self.target_tp_rank = ( self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank) diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 0ad7280f9..f69d29622 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -459,7 +459,9 @@ class MooncakeKVManager(BaseKVManager): dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank else: # Send KVCache from 1 prefill instance to multiple decode instances - src_head_start_offset = dst_tp_rank_in_group * dst_heads_per_rank + src_head_start_offset = ( + dst_tp_rank_in_group * dst_heads_per_rank + ) % src_heads_per_rank num_heads_to_send = dst_heads_per_rank dst_head_start_offset = 0 diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 1b427ee61..c911319ea 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -78,6 +78,9 @@ class KVArgsRegisterInfo: dst_kv_ptrs: list[int] dst_aux_ptrs: list[int] gpu_id: int + decode_tp_size: int + decode_tp_rank: int + dst_kv_item_len: int @classmethod def from_zmq(cls, msg: List[bytes]): @@ -90,6 +93,9 @@ class KVArgsRegisterInfo: dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])), dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])), gpu_id=int(msg[7].decode("ascii")), + decode_tp_size=int(msg[8].decode("ascii")), + decode_tp_rank=int(msg[9].decode("ascii")), + dst_kv_item_len=int(msg[10].decode("ascii")), ) @@ -166,7 +172,7 @@ class NixlKVManager(CommonKVManager): self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens ): kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, "")) - self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=False) + self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM") logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}") if not self.kv_descs: raise Exception("NIXL memory registration failed for kv tensors") @@ -175,7 +181,7 @@ class NixlKVManager(CommonKVManager): self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens ): aux_addrs.append((aux_data_ptr, aux_data_len, 0, "")) - self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=False) + self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM") logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}") if not self.aux_descs: raise Exception("NIXL memory registration failed for aux tensors") @@ -222,8 +228,8 @@ class NixlKVManager(CommonKVManager): logger.debug( f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}" ) - src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=False) - dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=False) + src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM") + dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM") # Transfer data xfer_handle = self.agent.initialize_xfer( "WRITE", @@ -239,6 +245,140 @@ class NixlKVManager(CommonKVManager): raise Exception("KVSender failed to post transfer") return xfer_handle + def send_kvcache_slice( + self, + peer_name: str, + prefill_kv_indices: npt.NDArray[np.int32], + dst_kv_ptrs: list[int], + dst_kv_indices: npt.NDArray[np.int32], + dst_gpu_id: int, + notif: str, + prefill_tp_size: int, + decode_tp_size: int, + decode_tp_rank: int, + dst_kv_item_len: int, + ): + # Get configuration from kv_args + local_tp_rank_in_group = self.kv_args.engine_rank % prefill_tp_size + dst_tp_rank_in_group = decode_tp_rank % decode_tp_size + num_kv_heads = self.kv_args.kv_head_num + + # Calculate head distribution + src_heads_per_rank = num_kv_heads + dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size + + src_kv_item_len = self.kv_args.kv_item_lens[0] + page_size = self.kv_args.page_size + + bytes_per_head_slice_to_send = ( + dst_kv_item_len // page_size // dst_heads_per_rank + ) + + # Determine which heads to send + if prefill_tp_size > decode_tp_size: + # Multiple prefill ranks to one decode rank + src_head_start_offset = 0 + num_heads_to_send = src_heads_per_rank + dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank + else: + # Send KVCache from 1 prefill instance to multiple decode instances + src_head_start_offset = ( + dst_tp_rank_in_group * dst_heads_per_rank + ) % src_heads_per_rank + num_heads_to_send = dst_heads_per_rank + dst_head_start_offset = 0 + + # Create transfer descriptors + src_addrs = [] + dst_addrs = [] + + bytes_per_token_on_prefill = src_kv_item_len // page_size + bytes_per_token_on_decode = dst_kv_item_len // page_size + + num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2 + src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers] + src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:] + dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)] + dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)] + + # Calculate precise byte offset and length for the sub-slice within the token + src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send + dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send + heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send + + src_dst_ptr_pairs = [ + ( + src_k_ptrs[layer_id], + dst_k_ptrs[layer_id], + ) + for layer_id in range(len(src_k_ptrs)) + ] + [ + ( + src_v_ptrs[layer_id], + dst_v_ptrs[layer_id], + ) + for layer_id in range(len(src_v_ptrs)) + ] + + src_addrs = [] + dst_addrs = [] + + # Calculate strides for a single token slot + bytes_per_token_on_prefill = src_kv_item_len // page_size + bytes_per_token_on_decode = dst_kv_item_len // page_size + + for src_ptr, dst_ptr in src_dst_ptr_pairs: + for i in range(len(prefill_kv_indices)): + prefill_page_idx = int(prefill_kv_indices[i]) + decode_page_idx = int(dst_kv_indices[i]) + + # Get the starting addresses for the current src and dst pages + src_page_start_addr = src_ptr + prefill_page_idx * src_kv_item_len + dst_page_start_addr = dst_ptr + decode_page_idx * dst_kv_item_len + + # Iterate through each valid token slot within the current page + for token_slot_in_page in range(page_size): + # Calculate the start address of the current token slot + src_token_slot_start_addr = ( + src_page_start_addr + + token_slot_in_page * bytes_per_token_on_prefill + ) + dst_token_slot_start_addr = ( + dst_page_start_addr + + token_slot_in_page * bytes_per_token_on_decode + ) + + # Calculate final src and dst addresses by applying head-slice offsets + src_slice_addr = src_token_slot_start_addr + src_head_slice_offset + dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset + + src_addrs.append( + ( + src_slice_addr, + heads_bytes_per_token_to_send, + self.kv_args.gpu_id, + ) + ) + dst_addrs.append( + (dst_slice_addr, heads_bytes_per_token_to_send, dst_gpu_id) + ) + + # Use NIXL agent for transfer + src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM") + dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM") + + xfer_handle = self.agent.initialize_xfer( + "WRITE", src_descs, dst_descs, peer_name, notif.encode("ascii") + ) + if not xfer_handle: + raise Exception("Failed to create sliced KV transfer") + + state = self.agent.transfer(xfer_handle) + if state == "ERR": + raise Exception("Failed to post sliced KV transfer") + + return xfer_handle + def send_aux( self, peer_name: str, @@ -255,8 +395,8 @@ class NixlKVManager(CommonKVManager): decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len src_addrs = [(prefill_aux_addr, aux_item_len, 0)] dst_addrs = [(decode_aux_addr, aux_item_len, 0)] - src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=False) - dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=False) + src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM") + dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM") # Transfer data xfer_handle = self.agent.initialize_xfer( "WRITE", @@ -296,14 +436,35 @@ class NixlKVManager(CommonKVManager): assert req.agent_name in self.decode_kv_args_table notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))]) - kv_xfer_handle = self.send_kvcache( - req.agent_name, - kv_indices, - self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, - chunked_dst_kv_indice, - self.decode_kv_args_table[req.agent_name].gpu_id, - notif, - ) + decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size + + if decode_tp_size == self.tp_size: + kv_xfer_handle = self.send_kvcache( + req.agent_name, + kv_indices, + self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, + chunked_dst_kv_indice, + self.decode_kv_args_table[req.agent_name].gpu_id, + notif, + ) + else: + kv_xfer_handle = self.send_kvcache_slice( + req.agent_name, + kv_indices, + self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, + chunked_dst_kv_indice, + self.decode_kv_args_table[req.agent_name].gpu_id, + notif, + prefill_tp_size=self.tp_size, + decode_tp_size=decode_tp_size, + decode_tp_rank=self.decode_kv_args_table[ + req.agent_name + ].decode_tp_rank, + dst_kv_item_len=self.decode_kv_args_table[ + req.agent_name + ].dst_kv_item_len, + ) + handles.append(kv_xfer_handle) # Only the last chunk we need to send the aux data. if is_last: @@ -521,6 +682,9 @@ class NixlKVReceiver(CommonKVReceiver): packed_kv_data_ptrs, packed_aux_data_ptrs, str(self.kv_mgr.kv_args.gpu_id).encode("ascii"), + str(self.kv_mgr.kv_args.decode_tp_size).encode("ascii"), + str(self.kv_mgr.kv_args.engine_rank).encode("ascii"), + str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"), ] )