From 9c339d6b473b236018aa1ca87fbc0df0d1e43ec0 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Sun, 28 Sep 2025 00:10:41 +0800 Subject: [PATCH] [PD] Extract the PP transfer layer calculate logic from Mooncake to Common backend (#10565) Signed-off-by: Shangming Cai --- .../sglang/srt/disaggregation/common/conn.py | 35 +++++++++++---- .../srt/disaggregation/mooncake/conn.py | 45 ++++++------------- 2 files changed, 41 insertions(+), 39 deletions(-) diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index 096a1db59..82876066f 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -95,14 +95,6 @@ class CommonKVManager(BaseKVManager): def _bind_server_socket(self): self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port)) - @cache - def _connect(self, endpoint: str, is_ipv6: bool = False): - socket = zmq.Context().socket(zmq.PUSH) - if is_ipv6: - socket.setsockopt(zmq.IPV6, 1) - socket.connect(endpoint) - return socket - def _register_to_bootstrap(self): """Register KVSender to bootstrap server via HTTP POST.""" if self.dist_init_addr: @@ -156,6 +148,33 @@ class CommonKVManager(BaseKVManager): socket.connect(endpoint) return socket + def get_mha_kv_ptrs_with_pp( + self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int] + ) -> Tuple[List[int], List[int], List[int], List[int], int]: + # pp is not supported on the decode side yet + start_layer = self.kv_args.prefill_start_layer + num_kv_layers = len(src_kv_ptrs) // 2 + end_layer = start_layer + num_kv_layers + dst_num_total_layers = len(dst_kv_ptrs) // 2 + src_k_ptrs = src_kv_ptrs[:num_kv_layers] + src_v_ptrs = src_kv_ptrs[num_kv_layers:] + dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer] + dst_v_ptrs = dst_kv_ptrs[ + dst_num_total_layers + start_layer : dst_num_total_layers + end_layer + ] + layers_current_pp_stage = len(src_k_ptrs) + return src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage + + def get_mla_kv_ptrs_with_pp( + self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int] + ) -> Tuple[List[int], List[int], int]: + # pp is not supported on the decode side yet + start_layer = self.kv_args.prefill_start_layer + end_layer = start_layer + len(src_kv_ptrs) + sliced_dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer] + layers_current_pp_stage = len(src_kv_ptrs) + return src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage + class CommonKVSender(BaseKVSender): diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index f779e1fee..b6f12e46e 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -264,12 +264,10 @@ class MooncakeKVManager(CommonKVManager): layers_params = None # pp is not supported on the decode side yet - start_layer = self.kv_args.prefill_start_layer - end_layer = start_layer + len(self.kv_args.kv_data_ptrs) if self.is_mla_backend: - src_kv_ptrs = self.kv_args.kv_data_ptrs - layers_per_pp_stage = len(src_kv_ptrs) - dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer] + src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = ( + self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs) + ) kv_item_len = self.kv_args.kv_item_lens[0] layers_params = [ ( @@ -277,18 +275,12 @@ class MooncakeKVManager(CommonKVManager): dst_kv_ptrs[layer_id], kv_item_len, ) - for layer_id in range(layers_per_pp_stage) + for layer_id in range(layers_current_pp_stage) ] else: - num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2 - dst_num_total_layers = num_kv_layers * self.pp_size - src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers] - src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:] - layers_per_pp_stage = len(src_k_ptrs) - dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer] - dst_v_ptrs = dst_kv_ptrs[ - dst_num_total_layers + start_layer : dst_num_total_layers + end_layer - ] + src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = ( + self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs) + ) kv_item_len = self.kv_args.kv_item_lens[0] layers_params = [ ( @@ -296,14 +288,14 @@ class MooncakeKVManager(CommonKVManager): dst_k_ptrs[layer_id], kv_item_len, ) - for layer_id in range(layers_per_pp_stage) + for layer_id in range(layers_current_pp_stage) ] + [ ( src_v_ptrs[layer_id], dst_v_ptrs[layer_id], kv_item_len, ) - for layer_id in range(layers_per_pp_stage) + for layer_id in range(layers_current_pp_stage) ] assert layers_params is not None @@ -401,18 +393,9 @@ class MooncakeKVManager(CommonKVManager): num_heads_to_send = dst_heads_per_rank dst_head_start_offset = 0 - # pp is not supported on the decode side yet - num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2 - dst_num_total_layers = num_kv_layers * self.pp_size - src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers] - src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:] - layers_per_pp_stage = len(src_k_ptrs) - start_layer = self.pp_rank * layers_per_pp_stage - end_layer = start_layer + layers_per_pp_stage - dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer] - dst_v_ptrs = dst_kv_ptrs[ - dst_num_total_layers + start_layer : dst_num_total_layers + end_layer - ] + src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = ( + self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs) + ) # Calculate precise byte offset and length for the sub-slice within the token src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send @@ -438,7 +421,7 @@ class MooncakeKVManager(CommonKVManager): dst_head_slice_offset, heads_bytes_per_token_to_send, ) - for layer_id in range(layers_per_pp_stage) + for layer_id in range(layers_current_pp_stage) ] + [ ( src_v_ptrs[layer_id], @@ -449,7 +432,7 @@ class MooncakeKVManager(CommonKVManager): dst_head_slice_offset, heads_bytes_per_token_to_send, ) - for layer_id in range(layers_per_pp_stage) + for layer_id in range(layers_current_pp_stage) ] def process_layer_tp_aware(layer_params):