[PD] Use batch transfer for rdma transport and add notes for mnnvl usage (#8595)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -37,6 +37,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
format_tcp_address,
|
||||
get_bool_env_var,
|
||||
get_free_port,
|
||||
get_int_env_var,
|
||||
get_ip,
|
||||
@@ -198,6 +199,10 @@ class MooncakeKVManager(BaseKVManager):
|
||||
self.bootstrap_timeout = get_int_env_var(
|
||||
"SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 300
|
||||
)
|
||||
|
||||
self.enable_custom_mem_pool = get_bool_env_var(
|
||||
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
|
||||
)
|
||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
self.heartbeat_failures = {}
|
||||
self.session_pool = defaultdict(requests.Session)
|
||||
@@ -258,6 +263,26 @@ class MooncakeKVManager(BaseKVManager):
|
||||
socket.connect(endpoint)
|
||||
return socket
|
||||
|
||||
def _transfer_data(self, mooncake_session_id, transfer_blocks):
|
||||
if not transfer_blocks:
|
||||
return 0
|
||||
|
||||
# TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free
|
||||
if self.enable_custom_mem_pool:
|
||||
# batch_transfer_sync has a higher chance to trigger an accuracy drop for MNNVL, fallback to transfer_sync temporarily
|
||||
for src_addr, dst_addr, length in transfer_blocks:
|
||||
status = self.engine.transfer_sync(
|
||||
mooncake_session_id, src_addr, dst_addr, length
|
||||
)
|
||||
if status != 0:
|
||||
return status
|
||||
return 0
|
||||
else:
|
||||
src_addrs, dst_addrs, lengths = zip(*transfer_blocks)
|
||||
return self.engine.batch_transfer_sync(
|
||||
mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
|
||||
)
|
||||
|
||||
def send_kvcache(
|
||||
self,
|
||||
mooncake_session_id: str,
|
||||
@@ -283,17 +308,14 @@ class MooncakeKVManager(BaseKVManager):
|
||||
|
||||
# Worker function for processing a single layer
|
||||
def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
|
||||
transfer_blocks = []
|
||||
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
|
||||
src_addr = src_ptr + int(prefill_index[0]) * item_len
|
||||
dst_addr = dst_ptr + int(decode_index[0]) * item_len
|
||||
length = item_len * len(prefill_index)
|
||||
transfer_blocks.append((src_addr, dst_addr, length))
|
||||
|
||||
status = self.engine.transfer_sync(
|
||||
mooncake_session_id, src_addr, dst_addr, length
|
||||
)
|
||||
if status != 0:
|
||||
return status
|
||||
return 0
|
||||
return self._transfer_data(mooncake_session_id, transfer_blocks)
|
||||
|
||||
futures = [
|
||||
executor.submit(
|
||||
@@ -465,21 +487,17 @@ class MooncakeKVManager(BaseKVManager):
|
||||
dst_aux_ptrs: list[int],
|
||||
dst_aux_index: int,
|
||||
):
|
||||
src_addr_list = []
|
||||
dst_addr_list = []
|
||||
length_list = []
|
||||
transfer_blocks = []
|
||||
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
|
||||
prefill_aux_item_lens = self.kv_args.aux_item_lens
|
||||
|
||||
for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
|
||||
length = prefill_aux_item_lens[i]
|
||||
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
|
||||
dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
|
||||
src_addr_list.append(src_addr)
|
||||
dst_addr_list.append(dst_addr)
|
||||
length_list.append(length)
|
||||
return self.engine.batch_transfer_sync(
|
||||
mooncake_session_id, src_addr_list, dst_addr_list, length_list
|
||||
)
|
||||
transfer_blocks.append((src_addr, dst_addr, length))
|
||||
|
||||
return self._transfer_data(mooncake_session_id, transfer_blocks)
|
||||
|
||||
def sync_status_to_decode_endpoint(
|
||||
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
|
||||
|
||||
Reference in New Issue
Block a user