[PD] Support non-MLA models PD different TP with DP attention (#7931)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -321,67 +321,60 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
This may introduce performance overhead (increased TTFT) for long sequences.
|
This may introduce performance overhead (increased TTFT) for long sequences.
|
||||||
"""
|
"""
|
||||||
# Extract configuration
|
# Extract configuration
|
||||||
local_tp_rank = self.kv_args.engine_rank
|
|
||||||
local_tp_size = self.tp_size // self.dp_size
|
local_tp_size = self.tp_size // self.dp_size
|
||||||
|
local_tp_rank_in_group = self.kv_args.engine_rank % local_tp_size
|
||||||
|
src_kv_item_len = self.kv_args.kv_item_lens[0]
|
||||||
|
dst_tp_rank_in_group = dst_tp_rank % dst_tp_size
|
||||||
num_kv_heads = self.kv_args.kv_head_num
|
num_kv_heads = self.kv_args.kv_head_num
|
||||||
num_layers = len(self.kv_args.kv_data_ptrs)
|
num_layers = len(self.kv_args.kv_data_ptrs)
|
||||||
page_size = self.kv_args.page_size
|
page_size = self.kv_args.page_size
|
||||||
|
|
||||||
# Calculate head distribution
|
# Calculate head distribution
|
||||||
heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size
|
src_heads_per_rank = num_kv_heads
|
||||||
heads_per_prefill_rank = num_kv_heads
|
dst_heads_per_rank = num_kv_heads * local_tp_size // dst_tp_size
|
||||||
decode_global_head_start = dst_tp_rank * heads_per_decode_rank
|
bytes_per_head_slice_to_send = (
|
||||||
prefill_global_head_start = local_tp_rank * heads_per_prefill_rank
|
dst_kv_item_len // page_size // dst_heads_per_rank
|
||||||
bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size
|
)
|
||||||
|
|
||||||
decode_rank_item_lens = [dst_kv_item_len for _ in range(num_layers)]
|
|
||||||
|
|
||||||
# Determine slicing parameters based on TP configuration
|
# Determine slicing parameters based on TP configuration
|
||||||
if local_tp_size > dst_tp_size:
|
if local_tp_size > dst_tp_size:
|
||||||
src_head_offset = 0
|
# Send KVCache from multiple prefill instances to 1 decode instance
|
||||||
num_heads_to_send = heads_per_prefill_rank
|
src_head_start_offset = 0
|
||||||
dst_head_offset = prefill_global_head_start - decode_global_head_start
|
num_heads_to_send = src_heads_per_rank
|
||||||
|
dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
|
||||||
else:
|
else:
|
||||||
src_head_offset = decode_global_head_start - prefill_global_head_start
|
# Send KVCache from 1 prefill instance to multiple decode instances
|
||||||
num_heads_to_send = heads_per_decode_rank
|
src_head_start_offset = dst_tp_rank_in_group * dst_heads_per_rank
|
||||||
dst_head_offset = 0
|
num_heads_to_send = dst_heads_per_rank
|
||||||
|
dst_head_start_offset = 0
|
||||||
|
|
||||||
layer_transfer_params = []
|
layers_params = []
|
||||||
for layer_id in range(num_layers):
|
for layer_id in range(num_layers):
|
||||||
item_len_of_prefill_rank_page = self.kv_args.kv_item_lens[layer_id]
|
# Calculate precise byte offset and length for the sub-slice within the token
|
||||||
|
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
|
||||||
|
dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
|
||||||
|
heads_bytes_per_token_to_send = (
|
||||||
|
num_heads_to_send * bytes_per_head_slice_to_send
|
||||||
|
)
|
||||||
|
|
||||||
# Page stride on the target dst decode rank for its slice pages
|
# Sanity check: The data sub-slice to be sent should fit into the dst buffer.
|
||||||
item_len_of_decode_rank_page = decode_rank_item_lens[layer_id]
|
# This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size)
|
||||||
|
if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size):
|
||||||
if item_len_of_prefill_rank_page == 0 or num_kv_heads == 0:
|
|
||||||
logger.error(
|
|
||||||
f"Invalid item_len_of_prefill_rank_page or num_kv_heads for layer {layer_id}"
|
|
||||||
)
|
|
||||||
return -1
|
|
||||||
|
|
||||||
# Calculate precise byte offset and length for the sub-slice within the prefill page data
|
|
||||||
src_slice_offset = src_head_offset * bytes_per_head
|
|
||||||
dst_slice_offset = dst_head_offset * bytes_per_head
|
|
||||||
slice_lens_per_page = num_heads_to_send * bytes_per_head
|
|
||||||
|
|
||||||
# Sanity check: The data sub-slice to be sent should fit into the decode instance's page.
|
|
||||||
# This means slice_lens_per_page <= item_len_of_decode_rank_page
|
|
||||||
if slice_lens_per_page > item_len_of_decode_rank_page:
|
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{mooncake_session_id}] Layer {layer_id}: "
|
f"[{mooncake_session_id}] Layer {layer_id}: "
|
||||||
f"slice size ({slice_lens_per_page}) exceeds "
|
f"slice size ({heads_bytes_per_token_to_send}) exceeds "
|
||||||
f"target page size ({item_len_of_decode_rank_page})"
|
f"target token slot size ({dst_kv_item_len // page_size})"
|
||||||
)
|
)
|
||||||
return -1
|
return -1
|
||||||
layer_transfer_params.append(
|
layers_params.append(
|
||||||
(
|
(
|
||||||
self.kv_args.kv_data_ptrs[layer_id],
|
self.kv_args.kv_data_ptrs[layer_id],
|
||||||
dst_kv_ptrs[layer_id],
|
dst_kv_ptrs[layer_id],
|
||||||
item_len_of_prefill_rank_page,
|
src_kv_item_len,
|
||||||
item_len_of_decode_rank_page,
|
dst_kv_item_len,
|
||||||
src_slice_offset,
|
src_head_slice_offset,
|
||||||
dst_slice_offset,
|
dst_head_slice_offset,
|
||||||
slice_lens_per_page,
|
heads_bytes_per_token_to_send,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -391,9 +384,9 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
dst_ptr,
|
dst_ptr,
|
||||||
src_item_len,
|
src_item_len,
|
||||||
dst_item_len,
|
dst_item_len,
|
||||||
src_offset,
|
src_head_slice_offset,
|
||||||
dst_offset,
|
dst_head_slice_offset,
|
||||||
slice_lens_per_page,
|
heads_bytes_per_token_to_send,
|
||||||
) = layer_params
|
) = layer_params
|
||||||
src_addr_list = []
|
src_addr_list = []
|
||||||
dst_addr_list = []
|
dst_addr_list = []
|
||||||
@@ -424,17 +417,12 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Calculate final src and dst addresses by applying head-slice offsets
|
# Calculate final src and dst addresses by applying head-slice offsets
|
||||||
src_slice_addr = src_token_slot_start_addr + src_offset
|
src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
|
||||||
dst_slice_addr = dst_token_slot_start_addr + dst_offset
|
dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset
|
||||||
|
|
||||||
src_addr_list.append(src_slice_addr)
|
src_addr_list.append(src_slice_addr)
|
||||||
dst_addr_list.append(dst_slice_addr)
|
dst_addr_list.append(dst_slice_addr)
|
||||||
length_list.append(slice_lens_per_page)
|
length_list.append(heads_bytes_per_token_to_send)
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"SYNC: sid={mooncake_session_id}, "
|
|
||||||
f"src={src_slice_addr}, dst={dst_slice_addr}, len={slice_lens_per_page}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.engine.batch_transfer_sync(
|
return self.engine.batch_transfer_sync(
|
||||||
mooncake_session_id, src_addr_list, dst_addr_list, length_list
|
mooncake_session_id, src_addr_list, dst_addr_list, length_list
|
||||||
@@ -445,7 +433,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
process_layer_tp_aware,
|
process_layer_tp_aware,
|
||||||
layer_params,
|
layer_params,
|
||||||
)
|
)
|
||||||
for layer_params in layer_transfer_params
|
for layer_params in layers_params
|
||||||
]
|
]
|
||||||
|
|
||||||
for future in concurrent.futures.as_completed(futures):
|
for future in concurrent.futures.as_completed(futures):
|
||||||
|
|||||||
Reference in New Issue
Block a user