From a16923efaba36c4c813f710cdbccaab1da42e51f Mon Sep 17 00:00:00 2001
From: Francis <38564764+ssssnow@users.noreply.github.com>
Date: Thu, 14 Aug 2025 00:54:14 +0800
Subject: [PATCH] [PD] optimize kv cache transfer directly using batch
 transfer (#9149)

Co-authored-by: Shangming Cai
---
 .../srt/disaggregation/mooncake/conn.py | 50 ++++++++++++-------
 1 file changed, 33 insertions(+), 17 deletions(-)

diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py
index 25188c6a8..43c462683 100644
--- a/python/sglang/srt/disaggregation/mooncake/conn.py
+++ b/python/sglang/srt/disaggregation/mooncake/conn.py
@@ -356,33 +356,49 @@ class MooncakeKVManager(BaseKVManager):
         ]
         assert layers_params is not None
 
-        # Worker function for processing a single layer
-        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
+        def set_transfer_blocks(
+            src_ptr: int, dst_ptr: int, item_len: int
+        ) -> List[Tuple[int, int, int]]:
             transfer_blocks = []
             for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                 src_addr = src_ptr + int(prefill_index[0]) * item_len
                 dst_addr = dst_ptr + int(decode_index[0]) * item_len
                 length = item_len * len(prefill_index)
                 transfer_blocks.append((src_addr, dst_addr, length))
+            return transfer_blocks
 
+        # Worker function for processing a single layer
+        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
+            transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)
             return self._transfer_data(mooncake_session_id, transfer_blocks)
 
-        futures = [
-            executor.submit(
-                process_layer,
-                src_ptr,
-                dst_ptr,
-                item_len,
-            )
-            for (src_ptr, dst_ptr, item_len) in layers_params
-        ]
+        # Worker function for processing all layers in a batch
+        def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:
+            transfer_blocks = []
+            for src_ptr, dst_ptr, item_len in layers_params:
+                transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))
+            return self._transfer_data(mooncake_session_id, transfer_blocks)
 
-        for future in concurrent.futures.as_completed(futures):
-            status = future.result()
-            if status != 0:
-                for f in futures:
-                    f.cancel()
-                return status
+        if self.enable_custom_mem_pool:
+            futures = [
+                executor.submit(
+                    process_layer,
+                    src_ptr,
+                    dst_ptr,
+                    item_len,
+                )
+                for (src_ptr, dst_ptr, item_len) in layers_params
+            ]
+            for future in concurrent.futures.as_completed(futures):
+                status = future.result()
+                if status != 0:
+                    for f in futures:
+                        f.cancel()
+                    return status
+        else:
+            # Combining all layers' params in one batch transfer is more efficient
+            # compared to using multiple threads
+            return process_layers(layers_params)
 
         return 0
 
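Why the batched path wins, in a minimal standalone sketch: the per-layer path submits one transfer-engine call per KV layer from a thread pool, so every layer pays the engine's fixed submission overhead, while the batched path coalesces every layer's (src_addr, dst_addr, length) triples into a single call. The sketch below is not part of the patch; FakeEngine and the batch_transfer_sync signature shown here are hypothetical stand-ins for the real Mooncake engine behind self._transfer_data.

from typing import List, Tuple

# (src_addr, dst_addr, length) -- same shape as the tuples built by
# set_transfer_blocks in the patch above.
TransferBlock = Tuple[int, int, int]


class FakeEngine:
    """Hypothetical stand-in for the Mooncake transfer engine: it only
    counts how many times a transfer is submitted."""

    def __init__(self) -> None:
        self.calls = 0

    def batch_transfer_sync(self, session_id: str, blocks: List[TransferBlock]) -> int:
        self.calls += 1  # each submission pays a fixed overhead
        return 0  # 0 == success, matching the status convention in the patch


def transfer_per_layer(
    engine: FakeEngine, session_id: str, layers: List[List[TransferBlock]]
) -> int:
    # Old behavior: one submission per layer, as process_layer does
    # when run from the thread pool.
    for blocks in layers:
        status = engine.batch_transfer_sync(session_id, blocks)
        if status != 0:
            return status
    return 0


def transfer_batched(
    engine: FakeEngine, session_id: str, layers: List[List[TransferBlock]]
) -> int:
    # New behavior: flatten every layer's blocks into one submission,
    # as process_layers does.
    all_blocks: List[TransferBlock] = []
    for blocks in layers:
        all_blocks.extend(blocks)
    return engine.batch_transfer_sync(session_id, all_blocks)


if __name__ == "__main__":
    # 32 layers, one contiguous block each (addresses are made up).
    layers = [[(0x10000 + i * 0x1000, 0x80000 + i * 0x1000, 4096)] for i in range(32)]
    per_layer, batched = FakeEngine(), FakeEngine()
    assert transfer_per_layer(per_layer, "session-0", layers) == 0
    assert transfer_batched(batched, "session-0", layers) == 0
    print(f"per-layer submissions: {per_layer.calls}, batched: {batched.calls}")
    # -> per-layer submissions: 32, batched: 1

Under this model the saving is purely in submission count. Note that the patch keeps the threaded per-layer path when self.enable_custom_mem_pool is set and batches otherwise, matching the comment in the diff that one batch transfer is more efficient than multiple threads.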