From 1b9cea5ade6fe34bfafcff24b177e1ae6f5cb14f Mon Sep 17 00:00:00 2001 From: Stepan Kargaltsev Date: Fri, 25 Jul 2025 18:53:30 +0300 Subject: [PATCH] [P/D] Support ipv6 in P/D scenario (#7858) Co-authored-by: Shangming Cai --- .../sglang/srt/disaggregation/common/conn.py | 40 +++++++++-- python/sglang/srt/disaggregation/mini_lb.py | 5 +- .../srt/disaggregation/mooncake/conn.py | 69 +++++++++++++------ .../mooncake/transfer_engine.py | 6 +- python/sglang/srt/disaggregation/nixl/conn.py | 30 ++++---- .../device_communicators/shm_broadcast.py | 17 +++-- python/sglang/srt/utils.py | 10 +++ 7 files changed, 129 insertions(+), 48 deletions(-) diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index e6a6ad445..da6cc7217 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -23,7 +23,14 @@ from sglang.srt.disaggregation.base.conn import ( ) from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote +from sglang.srt.utils import ( + format_tcp_address, + get_free_port, + get_ip, + get_local_ip_by_remote, + is_valid_ipv6_address, + maybe_wrap_ipv6_address, +) logger = logging.getLogger(__name__) @@ -65,11 +72,18 @@ class CommonKVManager(BaseKVManager): def _register_to_bootstrap(self): """Register KVSender to bootstrap server via HTTP POST.""" if self.dist_init_addr: - ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0]) + if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6] + if self.dist_init_addr.endswith("]"): + host = self.dist_init_addr + else: + host, _ = self.dist_init_addr.rsplit(":", 1) + else: + host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0]) else: - ip_address = get_ip() + host = get_ip() + host = maybe_wrap_ipv6_address(host) - bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}" + bootstrap_server_url = f"{host}:{self.bootstrap_port}" url = f"http://{bootstrap_server_url}/route" payload = { "role": "Prefill", @@ -92,8 +106,10 @@ class CommonKVManager(BaseKVManager): logger.error(f"Prefill Failed to register to bootstrap server: {e}") @cache - def _connect(self, endpoint: str): + def _connect(self, endpoint: str, is_ipv6: bool = False): socket = zmq.Context().socket(zmq.PUSH) + if is_ipv6: + socket.setsockopt(zmq.IPV6, 1) socket.connect(endpoint) return socket @@ -263,15 +279,27 @@ class CommonKVReceiver(BaseKVReceiver): return None @classmethod - def _connect(cls, endpoint: str): + def _connect(cls, endpoint: str, is_ipv6: bool = False): with cls._global_lock: if endpoint not in cls._socket_cache: sock = cls._ctx.socket(zmq.PUSH) + if is_ipv6: + sock.setsockopt(zmq.IPV6, 1) sock.connect(endpoint) cls._socket_cache[endpoint] = sock cls._socket_locks[endpoint] = threading.Lock() return cls._socket_cache[endpoint], cls._socket_locks[endpoint] + @classmethod + def _connect_to_bootstrap_server(cls, bootstrap_info: dict): + ip_address = bootstrap_info["rank_ip"] + port = bootstrap_info["rank_port"] + is_ipv6_address = is_valid_ipv6_address(ip_address) + sock, lock = cls._connect( + format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address + ) + return sock, lock + def _register_kv_args(self): pass diff --git a/python/sglang/srt/disaggregation/mini_lb.py b/python/sglang/srt/disaggregation/mini_lb.py index d91598e4f..a80407bca 100644 --- a/python/sglang/srt/disaggregation/mini_lb.py +++ b/python/sglang/srt/disaggregation/mini_lb.py @@ -17,6 +17,7 @@ from fastapi import FastAPI, HTTPException from fastapi.responses import ORJSONResponse, Response, StreamingResponse from sglang.srt.disaggregation.utils import PDRegistryRequest +from sglang.srt.utils import maybe_wrap_ipv6_address AIOHTTP_STREAM_READ_CHUNK_SIZE = ( 1024 * 64 @@ -271,7 +272,7 @@ async def handle_generate_request(request_data: dict): # Parse and transform prefill_server for bootstrap data parsed_url = urllib.parse.urlparse(prefill_server) - hostname = parsed_url.hostname + hostname = maybe_wrap_ipv6_address(parsed_url.hostname) modified_request = request_data.copy() batch_size = _get_request_batch_size(modified_request) @@ -309,7 +310,7 @@ async def _forward_to_backend(request_data: dict, endpoint_name: str): # Parse and transform prefill_server for bootstrap data parsed_url = urllib.parse.urlparse(prefill_server) - hostname = parsed_url.hostname + hostname = maybe_wrap_ipv6_address(parsed_url.hostname) modified_request = request_data.copy() modified_request.update( { diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index e345d9519..c5baa6988 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -35,7 +35,15 @@ from sglang.srt.disaggregation.common.utils import ( from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import get_free_port, get_int_env_var, get_ip, get_local_ip_auto +from sglang.srt.utils import ( + format_tcp_address, + get_free_port, + get_int_env_var, + get_ip, + get_local_ip_auto, + is_valid_ipv6_address, + maybe_wrap_ipv6_address, +) logger = logging.getLogger(__name__) @@ -148,6 +156,9 @@ class MooncakeKVManager(BaseKVManager): self.request_status: Dict[int, KVPoll] = {} self.rank_port = None self.server_socket = zmq.Context().socket(zmq.PULL) + if is_valid_ipv6_address(self.local_ip): + self.server_socket.setsockopt(zmq.IPV6, 1) + self.register_buffer_to_engine() if self.disaggregation_mode == DisaggregationMode.PREFILL: self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {} @@ -240,8 +251,10 @@ class MooncakeKVManager(BaseKVManager): self.engine.register(aux_data_ptr, aux_data_len) @cache - def _connect(self, endpoint: str): + def _connect(self, endpoint: str, is_ipv6: bool = False): socket = zmq.Context().socket(zmq.PUSH) + if is_ipv6: + socket.setsockopt(zmq.IPV6, 1) socket.connect(endpoint) return socket @@ -471,9 +484,9 @@ class MooncakeKVManager(BaseKVManager): def sync_status_to_decode_endpoint( self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int ): - if ":" in remote: - remote = remote.split(":")[0] - self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart( + self._connect( + format_tcp_address(remote, dst_port), is_ipv6=is_valid_ipv6_address(remote) + ).send_multipart( [ str(room).encode("ascii"), str(status).encode("ascii"), @@ -616,9 +629,12 @@ class MooncakeKVManager(BaseKVManager): f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead." ) + def _bind_server_socket(self): + self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port)) + def start_prefill_thread(self): self.rank_port = get_free_port() - self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}") + self._bind_server_socket() def bootstrap_thread(): """This thread recvs pre-alloc notification from the decode engine""" @@ -657,7 +673,7 @@ class MooncakeKVManager(BaseKVManager): def start_decode_thread(self): self.rank_port = get_free_port() - self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}") + self._bind_server_socket() def decode_thread(): while True: @@ -776,7 +792,7 @@ class MooncakeKVManager(BaseKVManager): # requests with the same dst_sessions will be added into the same # queue, which enables early abort with failed sessions. dst_infos = self.transfer_infos[bootstrap_room].keys() - session_port_sum = sum(int(session.split(":")[1]) for session in dst_infos) + session_port_sum = sum(int(session.rsplit(":", 1)[1]) for session in dst_infos) shard_idx = session_port_sum % len(self.transfer_queues) self.transfer_queues[shard_idx].put( @@ -814,11 +830,18 @@ class MooncakeKVManager(BaseKVManager): def _register_to_bootstrap(self): """Register KVSender to bootstrap server via HTTP POST.""" if self.dist_init_addr: - ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0]) + if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6] + if self.dist_init_addr.endswith("]"): + host = self.dist_init_addr + else: + host, _ = self.dist_init_addr.rsplit(":", 1) + else: + host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0]) else: - ip_address = get_ip() + host = get_ip() + host = maybe_wrap_ipv6_address(host) - bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}" + bootstrap_server_url = f"{host}:{self.bootstrap_port}" url = f"http://{bootstrap_server_url}/route" payload = { "role": "Prefill", @@ -1163,9 +1186,6 @@ class MooncakeKVReceiver(BaseKVReceiver): def _register_kv_args(self): for bootstrap_info in self.bootstrap_infos: - self.prefill_server_url = ( - f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}" - ) packed_kv_data_ptrs = b"".join( struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs ) @@ -1179,7 +1199,7 @@ class MooncakeKVReceiver(BaseKVReceiver): dst_tp_size = str(tp_size).encode("ascii") dst_kv_item_len = str(kv_item_len).encode("ascii") - sock, lock = self._connect("tcp://" + self.prefill_server_url) + sock, lock = self._connect_to_bootstrap_server(bootstrap_info) with lock: sock.send_multipart( [ @@ -1196,23 +1216,32 @@ class MooncakeKVReceiver(BaseKVReceiver): ) @classmethod - def _connect(cls, endpoint: str): + def _connect(cls, endpoint: str, is_ipv6: bool = False): with cls._global_lock: if endpoint not in cls._socket_cache: sock = cls._ctx.socket(zmq.PUSH) + if is_ipv6: + sock.setsockopt(zmq.IPV6, 1) sock.connect(endpoint) cls._socket_cache[endpoint] = sock cls._socket_locks[endpoint] = threading.Lock() return cls._socket_cache[endpoint], cls._socket_locks[endpoint] + @classmethod + def _connect_to_bootstrap_server(cls, bootstrap_info: dict): + ip_address = bootstrap_info["rank_ip"] + port = bootstrap_info["rank_port"] + is_ipv6_address = is_valid_ipv6_address(ip_address) + sock, lock = cls._connect( + format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address + ) + return sock, lock + def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None): for bootstrap_info in self.bootstrap_infos: - self.prefill_server_url = ( - f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}" - ) + sock, lock = self._connect_to_bootstrap_server(bootstrap_info) is_dummy = bootstrap_info["is_dummy"] - sock, lock = self._connect("tcp://" + self.prefill_server_url) with lock: sock.send_multipart( [ diff --git a/python/sglang/srt/disaggregation/mooncake/transfer_engine.py b/python/sglang/srt/disaggregation/mooncake/transfer_engine.py index 8c7ea0108..5baee5397 100644 --- a/python/sglang/srt/disaggregation/mooncake/transfer_engine.py +++ b/python/sglang/srt/disaggregation/mooncake/transfer_engine.py @@ -1,7 +1,7 @@ import logging from typing import List, Optional -from sglang.srt.utils import get_bool_env_var, get_free_port +from sglang.srt.utils import get_bool_env_var, get_free_port, maybe_wrap_ipv6_address logger = logging.getLogger(__name__) @@ -27,7 +27,9 @@ class MooncakeTransferEngine: hostname=self.hostname, device_name=self.ib_device, ) - self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}" + self.session_id = ( + f"{maybe_wrap_ipv6_address(self.hostname)}:{self.engine.get_rpc_port()}" + ) def register(self, ptr, length): try: diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 73f32c0a6..7a75d79b7 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -27,7 +27,11 @@ from sglang.srt.disaggregation.common.conn import ( from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import get_local_ip_by_remote +from sglang.srt.utils import ( + format_tcp_address, + get_local_ip_auto, + is_valid_ipv6_address, +) logger = logging.getLogger(__name__) @@ -124,7 +128,10 @@ class NixlKVManager(CommonKVManager): "to run SGLang with NixlTransferEngine." ) from e self.agent = nixl_agent(str(uuid.uuid4())) + self.local_ip = get_local_ip_auto() self.server_socket = zmq.Context().socket(zmq.PULL) + if is_valid_ipv6_address(self.local_ip): + self.server_socket.setsockopt(zmq.IPV6, 1) self.register_buffer_to_engine() if self.disaggregation_mode == DisaggregationMode.PREFILL: @@ -337,8 +344,11 @@ class NixlKVManager(CommonKVManager): return False return self.transfer_statuses[room].is_done() + def _bind_server_socket(self): + self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port)) + def _start_bootstrap_thread(self): - self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}") + self._bind_server_socket() def bootstrap_thread(): """This thread recvs transfer info from the decode engine""" @@ -452,23 +462,20 @@ class NixlKVReceiver(CommonKVReceiver): def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None): for bootstrap_info in self.bootstrap_infos: - self.prefill_server_url = ( - f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}" - ) logger.debug( f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}" ) + sock, lock = self._connect_to_bootstrap_server(bootstrap_info) is_dummy = bootstrap_info["is_dummy"] logger.debug( - f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room} {is_dummy=}" + f"Sending to prefill server with bootstrap room {self.bootstrap_room} {is_dummy=}" ) - sock, lock = self._connect("tcp://" + self.prefill_server_url) with lock: sock.send_multipart( [ GUARD, str(self.bootstrap_room).encode("ascii"), - get_local_ip_by_remote().encode("ascii"), + self.kv_mgr.local_ip.encode("ascii"), str(self.kv_mgr.rank_port).encode("ascii"), self.kv_mgr.agent.name.encode("ascii"), kv_indices.tobytes() if not is_dummy else b"", @@ -494,9 +501,7 @@ class NixlKVReceiver(CommonKVReceiver): def _register_kv_args(self): for bootstrap_info in self.bootstrap_infos: - self.prefill_server_url = ( - f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}" - ) + sock, lock = self._connect_to_bootstrap_server(bootstrap_info) packed_kv_data_ptrs = b"".join( struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs ) @@ -504,13 +509,12 @@ class NixlKVReceiver(CommonKVReceiver): struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs ) - sock, lock = self._connect("tcp://" + self.prefill_server_url) with lock: sock.send_multipart( [ GUARD, "None".encode("ascii"), - get_local_ip_by_remote().encode("ascii"), + self.kv_mgr.local_ip.encode("ascii"), str(self.kv_mgr.rank_port).encode("ascii"), self.kv_mgr.agent.name.encode("ascii"), self.kv_mgr.agent.get_agent_metadata(), diff --git a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py index 4e5c55a99..e5b59e7cc 100644 --- a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py +++ b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py @@ -16,7 +16,12 @@ from torch.distributed import ProcessGroup from zmq import IPV6 # type: ignore from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore -from sglang.srt.utils import get_ip, get_open_port, is_valid_ipv6_address +from sglang.srt.utils import ( + format_tcp_address, + get_ip, + get_open_port, + is_valid_ipv6_address, +) # SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60 SGLANG_RINGBUFFER_WARNING_INTERVAL = int( @@ -225,9 +230,9 @@ class MessageQueue: remote_subscribe_port = get_open_port() if is_valid_ipv6_address(connect_ip): self.remote_socket.setsockopt(IPV6, 1) - connect_ip = f"[{connect_ip}]" - socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}" - self.remote_socket.bind(socket_addr) + self.remote_socket.bind( + format_tcp_address(connect_ip, remote_subscribe_port) + ) else: remote_subscribe_port = None @@ -288,7 +293,9 @@ class MessageQueue: self.remote_socket.setsockopt_string(SUBSCRIBE, "") if is_valid_ipv6_address(handle.connect_ip): self.remote_socket.setsockopt(IPV6, 1) - socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}" + socket_addr = format_tcp_address( + handle.connect_ip, handle.remote_subscribe_port + ) logger.debug("Connecting to %s", socket_addr) self.remote_socket.connect(socket_addr) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 01e54392a..52a1e20b8 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -2065,6 +2065,16 @@ def is_valid_ipv6_address(address: str) -> bool: return False +def maybe_wrap_ipv6_address(address: str) -> str: + if is_valid_ipv6_address(address): + return f"[{address}]" + return address + + +def format_tcp_address(ip: str, port: int) -> str: + return f"tcp://{maybe_wrap_ipv6_address(ip)}:{port}" + + def configure_ipv6(dist_init_addr): addr = dist_init_addr end = addr.find("]")