[PD] Extract the PP transfer layer calculate logic from Mooncake to Common backend (#10565)
Signed-off-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
@@ -95,14 +95,6 @@ class CommonKVManager(BaseKVManager):
|
|||||||
def _bind_server_socket(self):
|
def _bind_server_socket(self):
|
||||||
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
||||||
|
|
||||||
@cache
|
|
||||||
def _connect(self, endpoint: str, is_ipv6: bool = False):
|
|
||||||
socket = zmq.Context().socket(zmq.PUSH)
|
|
||||||
if is_ipv6:
|
|
||||||
socket.setsockopt(zmq.IPV6, 1)
|
|
||||||
socket.connect(endpoint)
|
|
||||||
return socket
|
|
||||||
|
|
||||||
def _register_to_bootstrap(self):
|
def _register_to_bootstrap(self):
|
||||||
"""Register KVSender to bootstrap server via HTTP POST."""
|
"""Register KVSender to bootstrap server via HTTP POST."""
|
||||||
if self.dist_init_addr:
|
if self.dist_init_addr:
|
||||||
@@ -156,6 +148,33 @@ class CommonKVManager(BaseKVManager):
|
|||||||
socket.connect(endpoint)
|
socket.connect(endpoint)
|
||||||
return socket
|
return socket
|
||||||
|
|
||||||
|
def get_mha_kv_ptrs_with_pp(
    self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
) -> Tuple[List[int], List[int], List[int], List[int], int]:
    """Split source/destination KV pointer lists into per-layer K and V views.

    The first half of ``src_kv_ptrs`` is treated as the K pointers and the
    remainder as the V pointers. ``dst_kv_ptrs`` is likewise divided at its
    midpoint; from each half, the window that starts at
    ``self.kv_args.prefill_start_layer`` and is as long as the source K half
    is selected.

    Args:
        src_kv_ptrs: Source pointers, K entries followed by V entries.
        dst_kv_ptrs: Destination pointers, K entries followed by V entries.

    Returns:
        ``(src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layer_count)``
        where ``layer_count`` is ``len(src_k_ptrs)``.
    """
    # pp is not supported on the decode side yet
    # NOTE(review): presumably this means dst_kv_ptrs spans every layer while
    # src_kv_ptrs only spans the current stage -- confirm against callers.
    first_layer = self.kv_args.prefill_start_layer
    half_src = len(src_kv_ptrs) // 2
    half_dst = len(dst_kv_ptrs) // 2

    src_k_ptrs, src_v_ptrs = src_kv_ptrs[:half_src], src_kv_ptrs[half_src:]

    # Same window in both destination halves; the V window is simply shifted
    # by the size of the destination K half.
    k_window = slice(first_layer, first_layer + half_src)
    v_window = slice(half_dst + first_layer, half_dst + first_layer + half_src)

    return (
        src_k_ptrs,
        src_v_ptrs,
        dst_kv_ptrs[k_window],
        dst_kv_ptrs[v_window],
        len(src_k_ptrs),
    )
|
||||||
|
|
||||||
|
def get_mla_kv_ptrs_with_pp(
    self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
) -> Tuple[List[int], List[int], int]:
    """Select the destination pointers that line up with the source layers.

    The source list is returned untouched. From ``dst_kv_ptrs`` the window
    beginning at ``self.kv_args.prefill_start_layer`` and spanning
    ``len(src_kv_ptrs)`` entries is taken.

    Args:
        src_kv_ptrs: Source pointers, one entry per layer.
        dst_kv_ptrs: Destination pointers to slice.

    Returns:
        ``(src_kv_ptrs, sliced_dst_kv_ptrs, layer_count)`` where
        ``layer_count`` is ``len(src_kv_ptrs)``.
    """
    # pp is not supported on the decode side yet
    first_layer = self.kv_args.prefill_start_layer
    layer_count = len(src_kv_ptrs)
    dst_window = dst_kv_ptrs[first_layer : first_layer + layer_count]
    return src_kv_ptrs, dst_window, layer_count
|
||||||
|
|
||||||
|
|
||||||
class CommonKVSender(BaseKVSender):
|
class CommonKVSender(BaseKVSender):
|
||||||
|
|
||||||
|
|||||||
@@ -264,12 +264,10 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
layers_params = None
|
layers_params = None
|
||||||
|
|
||||||
# pp is not supported on the decode side yet
|
# pp is not supported on the decode side yet
|
||||||
start_layer = self.kv_args.prefill_start_layer
|
|
||||||
end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
|
|
||||||
if self.is_mla_backend:
|
if self.is_mla_backend:
|
||||||
src_kv_ptrs = self.kv_args.kv_data_ptrs
|
src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
|
||||||
layers_per_pp_stage = len(src_kv_ptrs)
|
self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
||||||
dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
)
|
||||||
kv_item_len = self.kv_args.kv_item_lens[0]
|
kv_item_len = self.kv_args.kv_item_lens[0]
|
||||||
layers_params = [
|
layers_params = [
|
||||||
(
|
(
|
||||||
@@ -277,18 +275,12 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
dst_kv_ptrs[layer_id],
|
dst_kv_ptrs[layer_id],
|
||||||
kv_item_len,
|
kv_item_len,
|
||||||
)
|
)
|
||||||
for layer_id in range(layers_per_pp_stage)
|
for layer_id in range(layers_current_pp_stage)
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
||||||
dst_num_total_layers = num_kv_layers * self.pp_size
|
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
||||||
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
)
|
||||||
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
|
||||||
layers_per_pp_stage = len(src_k_ptrs)
|
|
||||||
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
|
||||||
dst_v_ptrs = dst_kv_ptrs[
|
|
||||||
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
|
|
||||||
]
|
|
||||||
kv_item_len = self.kv_args.kv_item_lens[0]
|
kv_item_len = self.kv_args.kv_item_lens[0]
|
||||||
layers_params = [
|
layers_params = [
|
||||||
(
|
(
|
||||||
@@ -296,14 +288,14 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
dst_k_ptrs[layer_id],
|
dst_k_ptrs[layer_id],
|
||||||
kv_item_len,
|
kv_item_len,
|
||||||
)
|
)
|
||||||
for layer_id in range(layers_per_pp_stage)
|
for layer_id in range(layers_current_pp_stage)
|
||||||
] + [
|
] + [
|
||||||
(
|
(
|
||||||
src_v_ptrs[layer_id],
|
src_v_ptrs[layer_id],
|
||||||
dst_v_ptrs[layer_id],
|
dst_v_ptrs[layer_id],
|
||||||
kv_item_len,
|
kv_item_len,
|
||||||
)
|
)
|
||||||
for layer_id in range(layers_per_pp_stage)
|
for layer_id in range(layers_current_pp_stage)
|
||||||
]
|
]
|
||||||
assert layers_params is not None
|
assert layers_params is not None
|
||||||
|
|
||||||
@@ -401,18 +393,9 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
num_heads_to_send = dst_heads_per_rank
|
num_heads_to_send = dst_heads_per_rank
|
||||||
dst_head_start_offset = 0
|
dst_head_start_offset = 0
|
||||||
|
|
||||||
# pp is not supported on the decode side yet
|
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
||||||
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
||||||
dst_num_total_layers = num_kv_layers * self.pp_size
|
)
|
||||||
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
|
||||||
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
|
||||||
layers_per_pp_stage = len(src_k_ptrs)
|
|
||||||
start_layer = self.pp_rank * layers_per_pp_stage
|
|
||||||
end_layer = start_layer + layers_per_pp_stage
|
|
||||||
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
|
||||||
dst_v_ptrs = dst_kv_ptrs[
|
|
||||||
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
|
|
||||||
]
|
|
||||||
|
|
||||||
# Calculate precise byte offset and length for the sub-slice within the token
|
# Calculate precise byte offset and length for the sub-slice within the token
|
||||||
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
|
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
|
||||||
@@ -438,7 +421,7 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
dst_head_slice_offset,
|
dst_head_slice_offset,
|
||||||
heads_bytes_per_token_to_send,
|
heads_bytes_per_token_to_send,
|
||||||
)
|
)
|
||||||
for layer_id in range(layers_per_pp_stage)
|
for layer_id in range(layers_current_pp_stage)
|
||||||
] + [
|
] + [
|
||||||
(
|
(
|
||||||
src_v_ptrs[layer_id],
|
src_v_ptrs[layer_id],
|
||||||
@@ -449,7 +432,7 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
dst_head_slice_offset,
|
dst_head_slice_offset,
|
||||||
heads_bytes_per_token_to_send,
|
heads_bytes_per_token_to_send,
|
||||||
)
|
)
|
||||||
for layer_id in range(layers_per_pp_stage)
|
for layer_id in range(layers_current_pp_stage)
|
||||||
]
|
]
|
||||||
|
|
||||||
def process_layer_tp_aware(layer_params):
|
def process_layer_tp_aware(layer_params):
|
||||||
|
|||||||
Reference in New Issue
Block a user