[PD] Extract the PP transfer layer calculate logic from Mooncake to Common backend (#10565)

Signed-off-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
Shangming Cai
2025-09-28 00:10:41 +08:00
committed by GitHub
parent e23e280e16
commit 9c339d6b47
2 changed files with 41 additions and 39 deletions

View File

@@ -95,14 +95,6 @@ class CommonKVManager(BaseKVManager):
def _bind_server_socket(self): def _bind_server_socket(self):
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port)) self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
@cache
def _connect(self, endpoint: str, is_ipv6: bool = False):
socket = zmq.Context().socket(zmq.PUSH)
if is_ipv6:
socket.setsockopt(zmq.IPV6, 1)
socket.connect(endpoint)
return socket
def _register_to_bootstrap(self): def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST.""" """Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr: if self.dist_init_addr:
@@ -156,6 +148,33 @@ class CommonKVManager(BaseKVManager):
socket.connect(endpoint) socket.connect(endpoint)
return socket return socket
def get_mha_kv_ptrs_with_pp(
self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
) -> Tuple[List[int], List[int], List[int], List[int], int]:
# pp is not supported on the decode side yet
start_layer = self.kv_args.prefill_start_layer
num_kv_layers = len(src_kv_ptrs) // 2
end_layer = start_layer + num_kv_layers
dst_num_total_layers = len(dst_kv_ptrs) // 2
src_k_ptrs = src_kv_ptrs[:num_kv_layers]
src_v_ptrs = src_kv_ptrs[num_kv_layers:]
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
]
layers_current_pp_stage = len(src_k_ptrs)
return src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage
def get_mla_kv_ptrs_with_pp(
self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
) -> Tuple[List[int], List[int], int]:
# pp is not supported on the decode side yet
start_layer = self.kv_args.prefill_start_layer
end_layer = start_layer + len(src_kv_ptrs)
sliced_dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
layers_current_pp_stage = len(src_kv_ptrs)
return src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage
class CommonKVSender(BaseKVSender): class CommonKVSender(BaseKVSender):

View File

@@ -264,12 +264,10 @@ class MooncakeKVManager(CommonKVManager):
layers_params = None layers_params = None
# pp is not supported on the decode side yet # pp is not supported on the decode side yet
start_layer = self.kv_args.prefill_start_layer
end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
if self.is_mla_backend: if self.is_mla_backend:
src_kv_ptrs = self.kv_args.kv_data_ptrs src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
layers_per_pp_stage = len(src_kv_ptrs) self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer] )
kv_item_len = self.kv_args.kv_item_lens[0] kv_item_len = self.kv_args.kv_item_lens[0]
layers_params = [ layers_params = [
( (
@@ -277,18 +275,12 @@ class MooncakeKVManager(CommonKVManager):
dst_kv_ptrs[layer_id], dst_kv_ptrs[layer_id],
kv_item_len, kv_item_len,
) )
for layer_id in range(layers_per_pp_stage) for layer_id in range(layers_current_pp_stage)
] ]
else: else:
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2 src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
dst_num_total_layers = num_kv_layers * self.pp_size self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers] )
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
layers_per_pp_stage = len(src_k_ptrs)
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
]
kv_item_len = self.kv_args.kv_item_lens[0] kv_item_len = self.kv_args.kv_item_lens[0]
layers_params = [ layers_params = [
( (
@@ -296,14 +288,14 @@ class MooncakeKVManager(CommonKVManager):
dst_k_ptrs[layer_id], dst_k_ptrs[layer_id],
kv_item_len, kv_item_len,
) )
for layer_id in range(layers_per_pp_stage) for layer_id in range(layers_current_pp_stage)
] + [ ] + [
( (
src_v_ptrs[layer_id], src_v_ptrs[layer_id],
dst_v_ptrs[layer_id], dst_v_ptrs[layer_id],
kv_item_len, kv_item_len,
) )
for layer_id in range(layers_per_pp_stage) for layer_id in range(layers_current_pp_stage)
] ]
assert layers_params is not None assert layers_params is not None
@@ -401,18 +393,9 @@ class MooncakeKVManager(CommonKVManager):
num_heads_to_send = dst_heads_per_rank num_heads_to_send = dst_heads_per_rank
dst_head_start_offset = 0 dst_head_start_offset = 0
# pp is not supported on the decode side yet src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2 self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
dst_num_total_layers = num_kv_layers * self.pp_size )
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
layers_per_pp_stage = len(src_k_ptrs)
start_layer = self.pp_rank * layers_per_pp_stage
end_layer = start_layer + layers_per_pp_stage
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
]
# Calculate precise byte offset and length for the sub-slice within the token # Calculate precise byte offset and length for the sub-slice within the token
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
@@ -438,7 +421,7 @@ class MooncakeKVManager(CommonKVManager):
dst_head_slice_offset, dst_head_slice_offset,
heads_bytes_per_token_to_send, heads_bytes_per_token_to_send,
) )
for layer_id in range(layers_per_pp_stage) for layer_id in range(layers_current_pp_stage)
] + [ ] + [
( (
src_v_ptrs[layer_id], src_v_ptrs[layer_id],
@@ -449,7 +432,7 @@ class MooncakeKVManager(CommonKVManager):
dst_head_slice_offset, dst_head_slice_offset,
heads_bytes_per_token_to_send, heads_bytes_per_token_to_send,
) )
for layer_id in range(layers_per_pp_stage) for layer_id in range(layers_current_pp_stage)
] ]
def process_layer_tp_aware(layer_params): def process_layer_tp_aware(layer_params):