From 8f2cd177afbc863c8648479121e65ee8864483f0 Mon Sep 17 00:00:00 2001 From: shaharmor98 <17088876+shaharmor98@users.noreply.github.com> Date: Thu, 9 Oct 2025 14:24:32 +0300 Subject: [PATCH] add code pp support for nixl (#11375) Signed-off-by: Shahar Mor --- python/sglang/srt/disaggregation/nixl/conn.py | 78 +++++++++++++------ 1 file changed, 55 insertions(+), 23 deletions(-) diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index b76a1cb15..df5f9e49c 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -319,14 +319,44 @@ class NixlKVManager(CommonKVManager): logger.debug(f"sending kvcache to {peer_name} with notif {notif}") # Make descs - num_layers = len(self.kv_args.kv_data_ptrs) + if self.is_mla_backend: + src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = ( + self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs) + ) + kv_item_len = self.kv_args.kv_item_lens[0] + layers_params = [ + ( + src_kv_ptrs[layer_id], + dst_kv_ptrs[layer_id], + kv_item_len, + ) + for layer_id in range(layers_current_pp_stage) + ] + else: + src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = ( + self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs) + ) + + kv_item_len = self.kv_args.kv_item_lens[0] + layers_params = [ + ( + src_k_ptrs[layer_id], + dst_k_ptrs[layer_id], + kv_item_len, + ) + for layer_id in range(layers_current_pp_stage) + ] + [ + ( + src_v_ptrs[layer_id], + dst_v_ptrs[layer_id], + kv_item_len, + ) + for layer_id in range(layers_current_pp_stage) + ] + src_addrs = [] dst_addrs = [] - for layer_id in range(num_layers): - src_ptr = self.kv_args.kv_data_ptrs[layer_id] - dst_ptr = dst_kv_ptrs[layer_id] - item_len = self.kv_args.kv_item_lens[layer_id] - + for src_ptr, dst_ptr, item_len in layers_params: for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks): src_addr = src_ptr + int(prefill_index[0]) * item_len dst_addr = dst_ptr + int(decode_index[0]) * item_len @@ -397,6 +427,9 @@ class NixlKVManager(CommonKVManager): num_heads_to_send = dst_heads_per_rank dst_head_start_offset = 0 + src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = ( + self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs) + ) # Create transfer descriptors src_addrs = [] dst_addrs = [] @@ -404,12 +437,6 @@ class NixlKVManager(CommonKVManager): bytes_per_token_on_prefill = src_kv_item_len // page_size bytes_per_token_on_decode = dst_kv_item_len // page_size - num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2 - src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers] - src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:] - dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)] - dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)] - # Calculate precise byte offset and length for the sub-slice within the token src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send @@ -420,13 +447,13 @@ class NixlKVManager(CommonKVManager): src_k_ptrs[layer_id], dst_k_ptrs[layer_id], ) - for layer_id in range(len(src_k_ptrs)) + for layer_id in range(layers_current_pp_stage) ] + [ ( src_v_ptrs[layer_id], dst_v_ptrs[layer_id], ) - for layer_id in range(len(src_v_ptrs)) + for layer_id in range(layers_current_pp_stage) ] src_addrs = [] @@ -496,14 +523,19 @@ class NixlKVManager(CommonKVManager): dst_aux_index: int, notif: str, ): - # Make descs - aux_item_len = self.kv_args.aux_item_lens[0] - prefill_aux_addr = ( - self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len - ) - decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len - src_addrs = [(prefill_aux_addr, aux_item_len, 0)] - dst_addrs = [(decode_aux_addr, aux_item_len, 0)] + src_addrs = [] + dst_addrs = [] + + prefill_aux_ptrs = self.kv_args.aux_data_ptrs + prefill_aux_item_lens = self.kv_args.aux_item_lens + + for i, _ in enumerate(dst_aux_ptrs): + length = prefill_aux_item_lens[i] + src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index + dst_addr = dst_aux_ptrs[i] + length * dst_aux_index + src_addrs.append((src_addr, length, 0)) + dst_addrs.append((dst_addr, length, 0)) + src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM") dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM") # Transfer data @@ -576,7 +608,7 @@ class NixlKVManager(CommonKVManager): handles.append(kv_xfer_handle) # Only the last chunk we need to send the aux data. - if is_last: + if is_last and self.pp_group.is_last_rank: assert aux_index is not None aux_xfer_handle = self.send_aux( req.agent_name,