diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index a9e9bf2c5..e345d9519 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -321,67 +321,60 @@ class MooncakeKVManager(BaseKVManager): This may introduce performance overhead (increased TTFT) for long sequences. """ # Extract configuration - local_tp_rank = self.kv_args.engine_rank local_tp_size = self.tp_size // self.dp_size + local_tp_rank_in_group = self.kv_args.engine_rank % local_tp_size + src_kv_item_len = self.kv_args.kv_item_lens[0] + dst_tp_rank_in_group = dst_tp_rank % dst_tp_size num_kv_heads = self.kv_args.kv_head_num num_layers = len(self.kv_args.kv_data_ptrs) page_size = self.kv_args.page_size # Calculate head distribution - heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size - heads_per_prefill_rank = num_kv_heads - decode_global_head_start = dst_tp_rank * heads_per_decode_rank - prefill_global_head_start = local_tp_rank * heads_per_prefill_rank - bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size - - decode_rank_item_lens = [dst_kv_item_len for _ in range(num_layers)] + src_heads_per_rank = num_kv_heads + dst_heads_per_rank = num_kv_heads * local_tp_size // dst_tp_size + bytes_per_head_slice_to_send = ( + dst_kv_item_len // page_size // dst_heads_per_rank + ) # Determine slicing parameters based on TP configuration if local_tp_size > dst_tp_size: - src_head_offset = 0 - num_heads_to_send = heads_per_prefill_rank - dst_head_offset = prefill_global_head_start - decode_global_head_start + # Send KVCache from multiple prefill instances to 1 decode instance + src_head_start_offset = 0 + num_heads_to_send = src_heads_per_rank + dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank else: - src_head_offset = decode_global_head_start - prefill_global_head_start - num_heads_to_send = heads_per_decode_rank - dst_head_offset = 0 + # Send KVCache from 1 prefill instance to multiple decode instances + src_head_start_offset = dst_tp_rank_in_group * dst_heads_per_rank + num_heads_to_send = dst_heads_per_rank + dst_head_start_offset = 0 - layer_transfer_params = [] + layers_params = [] for layer_id in range(num_layers): - item_len_of_prefill_rank_page = self.kv_args.kv_item_lens[layer_id] + # Calculate precise byte offset and length for the sub-slice within the token + src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send + dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send + heads_bytes_per_token_to_send = ( + num_heads_to_send * bytes_per_head_slice_to_send + ) - # Page stride on the target dst decode rank for its slice pages - item_len_of_decode_rank_page = decode_rank_item_lens[layer_id] - - if item_len_of_prefill_rank_page == 0 or num_kv_heads == 0: - logger.error( - f"Invalid item_len_of_prefill_rank_page or num_kv_heads for layer {layer_id}" - ) - return -1 - - # Calculate precise byte offset and length for the sub-slice within the prefill page data - src_slice_offset = src_head_offset * bytes_per_head - dst_slice_offset = dst_head_offset * bytes_per_head - slice_lens_per_page = num_heads_to_send * bytes_per_head - - # Sanity check: The data sub-slice to be sent should fit into the decode instance's page. - # This means slice_lens_per_page <= item_len_of_decode_rank_page - if slice_lens_per_page > item_len_of_decode_rank_page: + # Sanity check: The data sub-slice to be sent should fit into the dst buffer. + # This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size) + if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size): logger.error( f"[{mooncake_session_id}] Layer {layer_id}: " - f"slice size ({slice_lens_per_page}) exceeds " - f"target page size ({item_len_of_decode_rank_page})" + f"slice size ({heads_bytes_per_token_to_send}) exceeds " + f"target token slot size ({dst_kv_item_len // page_size})" ) return -1 - layer_transfer_params.append( + layers_params.append( ( self.kv_args.kv_data_ptrs[layer_id], dst_kv_ptrs[layer_id], - item_len_of_prefill_rank_page, - item_len_of_decode_rank_page, - src_slice_offset, - dst_slice_offset, - slice_lens_per_page, + src_kv_item_len, + dst_kv_item_len, + src_head_slice_offset, + dst_head_slice_offset, + heads_bytes_per_token_to_send, ) ) @@ -391,9 +384,9 @@ class MooncakeKVManager(BaseKVManager): dst_ptr, src_item_len, dst_item_len, - src_offset, - dst_offset, - slice_lens_per_page, + src_head_slice_offset, + dst_head_slice_offset, + heads_bytes_per_token_to_send, ) = layer_params src_addr_list = [] dst_addr_list = [] @@ -424,17 +417,12 @@ class MooncakeKVManager(BaseKVManager): ) # Calculate final src and dst addresses by applying head-slice offsets - src_slice_addr = src_token_slot_start_addr + src_offset - dst_slice_addr = dst_token_slot_start_addr + dst_offset + src_slice_addr = src_token_slot_start_addr + src_head_slice_offset + dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset src_addr_list.append(src_slice_addr) dst_addr_list.append(dst_slice_addr) - length_list.append(slice_lens_per_page) - - logger.debug( - f"SYNC: sid={mooncake_session_id}, " - f"src={src_slice_addr}, dst={dst_slice_addr}, len={slice_lens_per_page}" - ) + length_list.append(heads_bytes_per_token_to_send) return self.engine.batch_transfer_sync( mooncake_session_id, src_addr_list, dst_addr_list, length_list @@ -445,7 +433,7 @@ class MooncakeKVManager(BaseKVManager): process_layer_tp_aware, layer_params, ) - for layer_params in layer_transfer_params + for layer_params in layers_params ] for future in concurrent.futures.as_completed(futures):