diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 29e861e9f..92e182dfd 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -251,17 +251,19 @@ class MooncakeKVManager(BaseKVManager): # Worker function for processing a single layer def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int: + src_addr_list = [] + dst_addr_list = [] + length_list = [] for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks): src_addr = src_ptr + int(prefill_index[0]) * item_len dst_addr = dst_ptr + int(decode_index[0]) * item_len length = item_len * len(prefill_index) - - status = self.engine.transfer_sync( - mooncake_session_id, src_addr, dst_addr, length - ) - if status != 0: - return status - return 0 + src_addr_list.append(src_addr) + dst_addr_list.append(dst_addr) + length_list.append(length) + return self.engine.batch_transfer_sync( + mooncake_session_id, src_addr_list, dst_addr_list, length_list + ) futures = [ executor.submit( diff --git a/python/sglang/srt/disaggregation/mooncake/transfer_engine.py b/python/sglang/srt/disaggregation/mooncake/transfer_engine.py index 5643af70b..966f7152c 100644 --- a/python/sglang/srt/disaggregation/mooncake/transfer_engine.py +++ b/python/sglang/srt/disaggregation/mooncake/transfer_engine.py @@ -1,7 +1,7 @@ import json import logging from dataclasses import dataclass -from typing import Optional +from typing import List, Optional logger = logging.getLogger(__name__) @@ -90,5 +90,29 @@ class MooncakeTransferEngine: return ret + def batch_transfer_sync( + self, + session_id: str, + buffers: List[int], + peer_buffer_addresses: List[int], + lengths: List[int], + ) -> int: + """Synchronously transfer data to the specified address.""" + try: + ret = self.engine.batch_transfer_sync_write( + session_id, buffers, peer_buffer_addresses, lengths + ) + except Exception: + ret = -1 + + if ret < 0: + logger.debug( + "Failed to batch transfer data. Buffers: %s, Session: %s, Peer addresses: %s", + buffers, + session_id, + peer_buffer_addresses, + ) + return ret + def get_session_id(self): return self.session_id