[PD] Use batch transfer for RDMA transport and add notes for MNNVL usage (#8595)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -37,6 +37,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode
|
|||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
format_tcp_address,
|
format_tcp_address,
|
||||||
|
get_bool_env_var,
|
||||||
get_free_port,
|
get_free_port,
|
||||||
get_int_env_var,
|
get_int_env_var,
|
||||||
get_ip,
|
get_ip,
|
||||||
@@ -198,6 +199,10 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
self.bootstrap_timeout = get_int_env_var(
|
self.bootstrap_timeout = get_int_env_var(
|
||||||
"SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 300
|
"SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 300
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.enable_custom_mem_pool = get_bool_env_var(
|
||||||
|
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
|
||||||
|
)
|
||||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
self.heartbeat_failures = {}
|
self.heartbeat_failures = {}
|
||||||
self.session_pool = defaultdict(requests.Session)
|
self.session_pool = defaultdict(requests.Session)
|
||||||
@@ -258,6 +263,26 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
socket.connect(endpoint)
|
socket.connect(endpoint)
|
||||||
return socket
|
return socket
|
||||||
|
|
||||||
|
def _transfer_data(self, mooncake_session_id, transfer_blocks):
|
||||||
|
if not transfer_blocks:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free
|
||||||
|
if self.enable_custom_mem_pool:
|
||||||
|
# batch_transfer_sync has a higher chance to trigger an accuracy drop for MNNVL, fallback to transfer_sync temporarily
|
||||||
|
for src_addr, dst_addr, length in transfer_blocks:
|
||||||
|
status = self.engine.transfer_sync(
|
||||||
|
mooncake_session_id, src_addr, dst_addr, length
|
||||||
|
)
|
||||||
|
if status != 0:
|
||||||
|
return status
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
src_addrs, dst_addrs, lengths = zip(*transfer_blocks)
|
||||||
|
return self.engine.batch_transfer_sync(
|
||||||
|
mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
|
||||||
|
)
|
||||||
|
|
||||||
def send_kvcache(
|
def send_kvcache(
|
||||||
self,
|
self,
|
||||||
mooncake_session_id: str,
|
mooncake_session_id: str,
|
||||||
@@ -283,17 +308,14 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
|
|
||||||
# Worker function for processing a single layer
|
# Worker function for processing a single layer
|
||||||
def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
|
def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
|
||||||
|
transfer_blocks = []
|
||||||
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
|
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
|
||||||
src_addr = src_ptr + int(prefill_index[0]) * item_len
|
src_addr = src_ptr + int(prefill_index[0]) * item_len
|
||||||
dst_addr = dst_ptr + int(decode_index[0]) * item_len
|
dst_addr = dst_ptr + int(decode_index[0]) * item_len
|
||||||
length = item_len * len(prefill_index)
|
length = item_len * len(prefill_index)
|
||||||
|
transfer_blocks.append((src_addr, dst_addr, length))
|
||||||
|
|
||||||
status = self.engine.transfer_sync(
|
return self._transfer_data(mooncake_session_id, transfer_blocks)
|
||||||
mooncake_session_id, src_addr, dst_addr, length
|
|
||||||
)
|
|
||||||
if status != 0:
|
|
||||||
return status
|
|
||||||
return 0
|
|
||||||
|
|
||||||
futures = [
|
futures = [
|
||||||
executor.submit(
|
executor.submit(
|
||||||
@@ -465,21 +487,17 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
dst_aux_ptrs: list[int],
|
dst_aux_ptrs: list[int],
|
||||||
dst_aux_index: int,
|
dst_aux_index: int,
|
||||||
):
|
):
|
||||||
src_addr_list = []
|
transfer_blocks = []
|
||||||
dst_addr_list = []
|
|
||||||
length_list = []
|
|
||||||
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
|
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
|
||||||
prefill_aux_item_lens = self.kv_args.aux_item_lens
|
prefill_aux_item_lens = self.kv_args.aux_item_lens
|
||||||
|
|
||||||
for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
|
for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
|
||||||
length = prefill_aux_item_lens[i]
|
length = prefill_aux_item_lens[i]
|
||||||
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
|
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
|
||||||
dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
|
dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
|
||||||
src_addr_list.append(src_addr)
|
transfer_blocks.append((src_addr, dst_addr, length))
|
||||||
dst_addr_list.append(dst_addr)
|
|
||||||
length_list.append(length)
|
return self._transfer_data(mooncake_session_id, transfer_blocks)
|
||||||
return self.engine.batch_transfer_sync(
|
|
||||||
mooncake_session_id, src_addr_list, dst_addr_list, length_list
|
|
||||||
)
|
|
||||||
|
|
||||||
def sync_status_to_decode_endpoint(
|
def sync_status_to_decode_endpoint(
|
||||||
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
|
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
|
||||||
|
|||||||
Reference in New Issue
Block a user