[PD] feat: mooncake use batch reg/dereg (#8910)

Co-authored-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
Teng Ma
2025-08-14 00:54:34 +08:00
committed by GitHub
parent a16923efab
commit 4a16a71c36
2 changed files with 39 additions and 8 deletions

View File

@@ -257,15 +257,17 @@ class MooncakeKVManager(BaseKVManager):
)
def register_buffer_to_engine(self):
for kv_data_ptr, kv_data_len in zip(
self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
):
self.engine.register(kv_data_ptr, kv_data_len)
# Batch register KV data buffers
if self.kv_args.kv_data_ptrs and self.kv_args.kv_data_lens:
self.engine.batch_register(
self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
)
for aux_data_ptr, aux_data_len in zip(
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
):
self.engine.register(aux_data_ptr, aux_data_len)
# Batch register auxiliary data buffers
if self.kv_args.aux_data_ptrs and self.kv_args.aux_data_lens:
self.engine.batch_register(
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
)
@cache
def _connect(self, endpoint: str, is_ipv6: bool = False):

View File

@@ -51,6 +51,35 @@ class MooncakeTransferEngine:
if ret_value != 0:
logger.debug("Mooncake memory deregistration %s failed.", ptr)
def batch_register(self, ptrs: List[int], lengths: List[int]) -> int:
"""Batch register multiple memory regions."""
try:
ret_value = self.engine.batch_register_memory(ptrs, lengths)
except Exception:
# Mark batch register as failed
ret_value = -1
if not hasattr(self.engine, "batch_register_memory"):
raise RuntimeError(
"Mooncake's batch register requires a newer version of mooncake-transfer-engine. "
"Please upgrade Mooncake."
)
if ret_value != 0:
logger.debug("Mooncake batch memory registration failed.")
return ret_value
def batch_deregister(self, ptrs: List[int]) -> int:
"""Batch deregister multiple memory regions."""
try:
ret_value = self.engine.batch_unregister_memory(ptrs)
except Exception:
# Mark batch deregister as failed
ret_value = -1
if ret_value != 0:
logger.debug("Mooncake batch memory deregistration failed.")
return ret_value
def initialize(
self,
hostname: str,