diff --git a/docs/backend/pd_disaggregation.md b/docs/backend/pd_disaggregation.md index 833f0b3f9..c0f8e34d1 100644 --- a/docs/backend/pd_disaggregation.md +++ b/docs/backend/pd_disaggregation.md @@ -56,7 +56,7 @@ PD Disaggregation with Mooncake supports the following environment variables for |:--------:|:-----------:|:--------: | **`SGLANG_DISAGGREGATION_THREAD_POOL_SIZE`** | Controls the total number of worker threads for KVCache transfer operations per TP rank | A dynamic value calculated by `int(0.75 * os.cpu_count()) // 8)`, which is limited to be larger than 4 and less than 12 to ensure efficiency and prevent thread race conditions | | **`SGLANG_DISAGGREGATION_QUEUE_SIZE`** | Sets the number of parallel transfer queues. KVCache transfer requests from multiple decode instances will be sharded into these queues so that they can share the threads and the transfer bandwidth at the same time. If it is set to `1`, then we transfer requests one by one according to fcfs strategy | `4` | -| **`SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT`** | Timeout (seconds) for receiving destination KV indices during request initialization | `30` | +| **`SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT`** | Timeout (seconds) for receiving destination KV indices during request initialization | `120` | #### Decode Server Configuration | Variable | Description | Default | diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 3c05de5e0..8300a71fb 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -187,7 +187,7 @@ class MooncakeKVManager(BaseKVManager): ).start() self.bootstrap_time_out = get_int_env_var( - "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 30 + "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 120 ) elif self.disaggregation_mode == DisaggregationMode.DECODE: self.heartbeat_failures = {} @@ -195,8 +195,8 @@ class MooncakeKVManager(BaseKVManager): self.session_pool_lock = threading.Lock() self.addr_to_rooms_tracker = defaultdict(set) self.connection_lock = threading.Lock() - self.required_prefill_info_num_map: Dict[int, int] = {} - self.decode_kv_arrive_state: Dict[int, Set[int]] = defaultdict(set) + self.required_prefill_response_num_table: Dict[int, int] = {} + self.prefill_response_tracker: Dict[int, Set[int]] = defaultdict(set) # Heartbeat interval should be at least 2 seconds self.heartbeat_interval = max( float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0 @@ -311,22 +311,23 @@ class MooncakeKVManager(BaseKVManager): each page to ensure correctness for any page_size and head-slicing configuration. This may introduce performance overhead (increased TTFT) for long sequences. """ - # rank/kv_head config + # Extract configuration local_tp_rank = self.kv_args.engine_rank local_tp_size = self.tp_size // self.dp_size num_kv_heads = self.kv_args.kv_head_num num_layers = len(self.kv_args.kv_data_ptrs) page_size = self.kv_args.page_size + # Calculate head distribution heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size heads_per_prefill_rank = num_kv_heads decode_global_head_start = dst_tp_rank * heads_per_decode_rank prefill_global_head_start = local_tp_rank * heads_per_prefill_rank bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size - # decode config decode_rank_item_lens = [dst_kv_item_len for _ in range(num_layers)] + # Determine slicing parameters based on TP configuration if local_tp_size > dst_tp_size: src_head_offset = 0 num_heads_to_send = heads_per_prefill_rank @@ -340,7 +341,7 @@ class MooncakeKVManager(BaseKVManager): for layer_id in range(num_layers): item_len_of_prefill_rank_page = self.kv_args.kv_item_lens[layer_id] - # Page stride on the target Decode rank for its slice pages + # Page stride on the target dst decode rank for its slice pages item_len_of_decode_rank_page = decode_rank_item_lens[layer_id] if item_len_of_prefill_rank_page == 0 or num_kv_heads == 0: @@ -349,12 +350,12 @@ class MooncakeKVManager(BaseKVManager): ) return -1 - # Calculate precise byte offset and length for the sub-slice within Prefill page data + # Calculate precise byte offset and length for the sub-slice within the prefill page data src_slice_offset = src_head_offset * bytes_per_head dst_slice_offset = dst_head_offset * bytes_per_head slice_lens_per_page = num_heads_to_send * bytes_per_head - # Sanity check: The data sub-slice we intend to send should fit into D_n's page. + # Sanity check: The data sub-slice to be sent should fit into the decode instance's page. # This means slice_lens_per_page <= item_len_of_decode_rank_page if slice_lens_per_page > item_len_of_decode_rank_page: logger.error( @@ -365,15 +366,13 @@ class MooncakeKVManager(BaseKVManager): return -1 layer_transfer_params.append( ( - self.kv_args.kv_data_ptrs[layer_id], # Prefill base ptr (all heads) - dst_kv_ptrs[ - layer_id - ], # Decode base ptr (for its slice for this layer) - item_len_of_prefill_rank_page, # Prefill page size (all heads)2048 - item_len_of_decode_rank_page, # Decode page stride (for its slice page) 1024 - src_slice_offset, # Offset to slice data in Prefill page - dst_slice_offset, # Offset to slice data in Decode page - slice_lens_per_page, # Length of slice data per page (actual data to send) + self.kv_args.kv_data_ptrs[layer_id], + dst_kv_ptrs[layer_id], + item_len_of_prefill_rank_page, + item_len_of_decode_rank_page, + src_slice_offset, + dst_slice_offset, + slice_lens_per_page, ) ) @@ -399,13 +398,13 @@ class MooncakeKVManager(BaseKVManager): prefill_page_idx = int(prefill_kv_indices[i]) decode_page_idx = int(dst_kv_indices[i]) - # Get the starting memory address for the current source and destination pages + # Get the starting addresses for the current src and dst pages src_page_start_addr = src_ptr + prefill_page_idx * src_item_len dst_page_start_addr = dst_ptr + decode_page_idx * dst_item_len # Iterate through each valid token slot within the current page for token_slot_in_page in range(page_size): - # Calculate start address of the current token slot + # Calculate the start address of the current token slot src_token_slot_start_addr = ( src_page_start_addr + token_slot_in_page * bytes_per_token_on_prefill @@ -415,7 +414,7 @@ class MooncakeKVManager(BaseKVManager): + token_slot_in_page * bytes_per_token_on_decode ) - # Calculate final source and destination addresses by applying head-slice offsets + # Calculate final src and dst addresses by applying head-slice offsets src_slice_addr = src_token_slot_start_addr + src_offset dst_slice_addr = dst_token_slot_start_addr + dst_offset @@ -585,9 +584,7 @@ class MooncakeKVManager(BaseKVManager): ret = self.send_aux( req.mooncake_session_id, kv_chunk.prefill_aux_index, - self.decode_kv_args_table[ - req.mooncake_session_id - ].dst_aux_ptrs, + target_rank_registration_info.dst_aux_ptrs, req.dst_aux_index, ) polls.append(True if ret == 0 else False) @@ -675,19 +672,19 @@ class MooncakeKVManager(BaseKVManager): prefill_rank = int(prefill_rank.decode("ascii")) if status == KVPoll.Success: - # record arrived prefill_rank - self.decode_kv_arrive_state[bootstrap_room].add(prefill_rank) - expected_prefill_num = self.required_prefill_info_num_map[ - bootstrap_room - ] - arrived_prefill_num = len( - self.decode_kv_arrive_state[bootstrap_room] - ) - if ( - self.is_mla_backend - or arrived_prefill_num == expected_prefill_num - ): - self.update_status(bootstrap_room, KVPoll.Success) + if bootstrap_room in self.request_status: + self.prefill_response_tracker[bootstrap_room].add(prefill_rank) + expected_response_num = ( + self.required_prefill_response_num_table[bootstrap_room] + ) + arrived_response_num = len( + self.prefill_response_tracker[bootstrap_room] + ) + if ( + self.is_mla_backend + or arrived_response_num == expected_response_num + ): + self.update_status(bootstrap_room, KVPoll.Success) elif status == KVPoll.Failed: self.record_failure( bootstrap_room, @@ -900,14 +897,13 @@ class MooncakeKVSender(BaseKVSender): self.aux_index = None self.bootstrap_server_url = bootstrap_addr self.conclude_state = None - self.init_time = None + self.init_time = time.time() # inner state self.curr_idx = 0 def init(self, num_kv_indices: int, aux_index: Optional[int] = None): self.num_kv_indices = num_kv_indices self.aux_index = aux_index - self.init_time = time.time() def send( self, @@ -1031,7 +1027,7 @@ class MooncakeKVReceiver(BaseKVReceiver): self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank ) self.required_dst_info_num = 1 - self.required_prefill_info_num = 1 + self.required_prefill_response_num = 1 self.target_tp_ranks = [self.target_tp_rank] elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank: if not self.kv_mgr.is_mla_backend: @@ -1044,7 +1040,7 @@ class MooncakeKVReceiver(BaseKVReceiver): self.required_dst_info_num = ( local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank ) - self.required_prefill_info_num = 1 + self.required_prefill_response_num = 1 self.target_tp_ranks = [self.target_tp_rank] else: if not self.kv_mgr.is_mla_backend: @@ -1067,7 +1063,7 @@ class MooncakeKVReceiver(BaseKVReceiver): # or the KVPoll will never be set correctly self.target_tp_rank = self.target_tp_ranks[0] self.required_dst_info_num = 1 - self.required_prefill_info_num = ( + self.required_prefill_response_num = ( prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank ) @@ -1077,8 +1073,8 @@ class MooncakeKVReceiver(BaseKVReceiver): else: self.target_dp_group = bootstrap_room % self.prefill_dp_size - self.kv_mgr.required_prefill_info_num_map[self.bootstrap_room] = ( - self.required_prefill_info_num + self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = ( + self.required_prefill_response_num ) # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank bootstrap_key = ( @@ -1094,13 +1090,13 @@ class MooncakeKVReceiver(BaseKVReceiver): ) if bootstrap_info is not None: if self.kv_mgr.is_mla_backend: - # MLA :select one prefill rank as real rank + # For MLA: target_tp_rank is the selected real rank, others are dummy ranks bootstrap_info["is_dummy"] = not bool( target_tp_rank == self.target_tp_rank or self.target_tp_rank is None ) else: - # no-MLA:select all prefill ranks + # For non-MLA: all target_tp_ranks are selected real ranks bootstrap_info["is_dummy"] = False logger.debug( f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank}" @@ -1240,8 +1236,12 @@ class MooncakeKVReceiver(BaseKVReceiver): def clear(self) -> None: if self.bootstrap_room in self.kv_mgr.request_status: self.kv_mgr.request_status.pop(self.bootstrap_room) - self.kv_mgr.required_prefill_info_num_map.pop(self.bootstrap_room) - self.kv_mgr.decode_kv_arrive_state.pop(self.bootstrap_room) + + if self.bootstrap_room in self.kv_mgr.required_prefill_response_num_table: + self.kv_mgr.required_prefill_response_num_table.pop(self.bootstrap_room) + + if self.bootstrap_room in self.kv_mgr.prefill_response_tracker: + self.kv_mgr.prefill_response_tracker.pop(self.bootstrap_room) def failure_exception(self): # Explicitly set the status to failure since this request has failed in another rank diff --git a/python/sglang/srt/disaggregation/mooncake/transfer_engine.py b/python/sglang/srt/disaggregation/mooncake/transfer_engine.py index 966f7152c..d72aec63f 100644 --- a/python/sglang/srt/disaggregation/mooncake/transfer_engine.py +++ b/python/sglang/srt/disaggregation/mooncake/transfer_engine.py @@ -97,13 +97,19 @@ class MooncakeTransferEngine: peer_buffer_addresses: List[int], lengths: List[int], ) -> int: - """Synchronously transfer data to the specified address.""" + """Synchronously transfer data to the specified addresses in batches.""" try: ret = self.engine.batch_transfer_sync_write( session_id, buffers, peer_buffer_addresses, lengths ) except Exception: ret = -1 + # Inform user to upgrade mooncake-transfer-engine >= 0.3.4.post2 + if not hasattr(self.engine, "batch_transfer_sync_write"): + raise RuntimeError( + "Mooncake's batch transfer requires mooncake-transfer-engine >= 0.3.4.post2. " + "Please upgrade Mooncake by 'pip install mooncake-transfer-engine --upgrade'" + ) if ret < 0: logger.debug(