From 016fd2512788f029d6d9a7c866c12dadb34cde9c Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Thu, 31 Jul 2025 21:29:34 +0800 Subject: [PATCH] [PD] Use batch transfer for rdma transport and add notes for mnnvl usage (#8595) Signed-off-by: Shangming Cai --- .../srt/disaggregation/mooncake/conn.py | 48 +++++++++++++------ 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index bb0b47471..d366b2791 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -37,6 +37,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( format_tcp_address, + get_bool_env_var, get_free_port, get_int_env_var, get_ip, @@ -198,6 +199,10 @@ class MooncakeKVManager(BaseKVManager): self.bootstrap_timeout = get_int_env_var( "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 300 ) + + self.enable_custom_mem_pool = get_bool_env_var( + "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false" + ) elif self.disaggregation_mode == DisaggregationMode.DECODE: self.heartbeat_failures = {} self.session_pool = defaultdict(requests.Session) @@ -258,6 +263,26 @@ class MooncakeKVManager(BaseKVManager): socket.connect(endpoint) return socket + def _transfer_data(self, mooncake_session_id, transfer_blocks): + if not transfer_blocks: + return 0 + + # TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free + if self.enable_custom_mem_pool: + # batch_transfer_sync has a higher chance to trigger an accuracy drop for MNNVL, fallback to transfer_sync temporarily + for src_addr, dst_addr, length in transfer_blocks: + status = self.engine.transfer_sync( + mooncake_session_id, src_addr, dst_addr, length + ) + if status != 0: + return status + return 0 + else: + src_addrs, dst_addrs, lengths = zip(*transfer_blocks) + return self.engine.batch_transfer_sync( + mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths) + ) + def send_kvcache( self, mooncake_session_id: str, @@ -283,17 +308,14 @@ class MooncakeKVManager(BaseKVManager): # Worker function for processing a single layer def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int: + transfer_blocks = [] for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks): src_addr = src_ptr + int(prefill_index[0]) * item_len dst_addr = dst_ptr + int(decode_index[0]) * item_len length = item_len * len(prefill_index) + transfer_blocks.append((src_addr, dst_addr, length)) - status = self.engine.transfer_sync( - mooncake_session_id, src_addr, dst_addr, length - ) - if status != 0: - return status - return 0 + return self._transfer_data(mooncake_session_id, transfer_blocks) futures = [ executor.submit( @@ -465,21 +487,17 @@ class MooncakeKVManager(BaseKVManager): dst_aux_ptrs: list[int], dst_aux_index: int, ): - src_addr_list = [] - dst_addr_list = [] - length_list = [] + transfer_blocks = [] prefill_aux_ptrs = self.kv_args.aux_data_ptrs prefill_aux_item_lens = self.kv_args.aux_item_lens + for i, dst_aux_ptr in enumerate(dst_aux_ptrs): length = prefill_aux_item_lens[i] src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index dst_addr = dst_aux_ptrs[i] + length * dst_aux_index - src_addr_list.append(src_addr) - dst_addr_list.append(dst_addr) - length_list.append(length) - return self.engine.batch_transfer_sync( - mooncake_session_id, src_addr_list, dst_addr_list, length_list - ) + transfer_blocks.append((src_addr, dst_addr, length)) + + return self._transfer_data(mooncake_session_id, transfer_blocks) def sync_status_to_decode_endpoint( self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int