From 4a16a71c36e5a43b831e4999537e02bdb6ea78fe Mon Sep 17 00:00:00 2001 From: Teng Ma Date: Thu, 14 Aug 2025 00:54:34 +0800 Subject: [PATCH] [PD] feat: mooncake use batch reg/dereg (#8910) Co-authored-by: Shangming Cai --- .../srt/disaggregation/mooncake/conn.py | 18 +++++++----- .../mooncake/transfer_engine.py | 29 +++++++++++++++++++ 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 43c462683..9e35078e7 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -257,15 +257,17 @@ class MooncakeKVManager(BaseKVManager): ) def register_buffer_to_engine(self): - for kv_data_ptr, kv_data_len in zip( - self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens - ): - self.engine.register(kv_data_ptr, kv_data_len) + # Batch register KV data buffers + if self.kv_args.kv_data_ptrs and self.kv_args.kv_data_lens: + self.engine.batch_register( + self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens + ) - for aux_data_ptr, aux_data_len in zip( - self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens - ): - self.engine.register(aux_data_ptr, aux_data_len) + # Batch register auxiliary data buffers + if self.kv_args.aux_data_ptrs and self.kv_args.aux_data_lens: + self.engine.batch_register( + self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens + ) @cache def _connect(self, endpoint: str, is_ipv6: bool = False): diff --git a/python/sglang/srt/disaggregation/mooncake/transfer_engine.py b/python/sglang/srt/disaggregation/mooncake/transfer_engine.py index 5baee5397..54657bb46 100644 --- a/python/sglang/srt/disaggregation/mooncake/transfer_engine.py +++ b/python/sglang/srt/disaggregation/mooncake/transfer_engine.py @@ -51,6 +51,35 @@ class MooncakeTransferEngine: if ret_value != 0: logger.debug("Mooncake memory deregistration %s failed.", ptr) + def batch_register(self, ptrs: List[int], lengths: List[int]) -> int: + """Batch register multiple memory regions.""" + try: + ret_value = self.engine.batch_register_memory(ptrs, lengths) + except Exception: + # Mark batch register as failed + ret_value = -1 + if not hasattr(self.engine, "batch_register_memory"): + raise RuntimeError( + "Mooncake's batch register requires a newer version of mooncake-transfer-engine. " + "Please upgrade Mooncake." + ) + + if ret_value != 0: + logger.debug("Mooncake batch memory registration failed.") + return ret_value + + def batch_deregister(self, ptrs: List[int]) -> int: + """Batch deregister multiple memory regions.""" + try: + ret_value = self.engine.batch_unregister_memory(ptrs) + except Exception: + # Mark batch deregister as failed + ret_value = -1 + + if ret_value != 0: + logger.debug("Mooncake batch memory deregistration failed.") + return ret_value + def initialize( self, hostname: str,