[P/D] Support ipv6 in P/D scenario (#7858)
Co-authored-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
committed by
GitHub
parent
9045cc1eb8
commit
1b9cea5ade
@@ -23,7 +23,14 @@ from sglang.srt.disaggregation.base.conn import (
|
||||
)
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
|
||||
from sglang.srt.utils import (
|
||||
format_tcp_address,
|
||||
get_free_port,
|
||||
get_ip,
|
||||
get_local_ip_by_remote,
|
||||
is_valid_ipv6_address,
|
||||
maybe_wrap_ipv6_address,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -65,11 +72,18 @@ class CommonKVManager(BaseKVManager):
|
||||
def _register_to_bootstrap(self):
|
||||
"""Register KVSender to bootstrap server via HTTP POST."""
|
||||
if self.dist_init_addr:
|
||||
ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
|
||||
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
|
||||
if self.dist_init_addr.endswith("]"):
|
||||
host = self.dist_init_addr
|
||||
else:
|
||||
host, _ = self.dist_init_addr.rsplit(":", 1)
|
||||
else:
|
||||
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
|
||||
else:
|
||||
ip_address = get_ip()
|
||||
host = get_ip()
|
||||
host = maybe_wrap_ipv6_address(host)
|
||||
|
||||
bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
|
||||
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
|
||||
url = f"http://{bootstrap_server_url}/route"
|
||||
payload = {
|
||||
"role": "Prefill",
|
||||
@@ -92,8 +106,10 @@ class CommonKVManager(BaseKVManager):
|
||||
logger.error(f"Prefill Failed to register to bootstrap server: {e}")
|
||||
|
||||
@cache
|
||||
def _connect(self, endpoint: str):
|
||||
def _connect(self, endpoint: str, is_ipv6: bool = False):
|
||||
socket = zmq.Context().socket(zmq.PUSH)
|
||||
if is_ipv6:
|
||||
socket.setsockopt(zmq.IPV6, 1)
|
||||
socket.connect(endpoint)
|
||||
return socket
|
||||
|
||||
@@ -263,15 +279,27 @@ class CommonKVReceiver(BaseKVReceiver):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _connect(cls, endpoint: str):
|
||||
def _connect(cls, endpoint: str, is_ipv6: bool = False):
|
||||
with cls._global_lock:
|
||||
if endpoint not in cls._socket_cache:
|
||||
sock = cls._ctx.socket(zmq.PUSH)
|
||||
if is_ipv6:
|
||||
sock.setsockopt(zmq.IPV6, 1)
|
||||
sock.connect(endpoint)
|
||||
cls._socket_cache[endpoint] = sock
|
||||
cls._socket_locks[endpoint] = threading.Lock()
|
||||
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
||||
|
||||
@classmethod
|
||||
def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
|
||||
ip_address = bootstrap_info["rank_ip"]
|
||||
port = bootstrap_info["rank_port"]
|
||||
is_ipv6_address = is_valid_ipv6_address(ip_address)
|
||||
sock, lock = cls._connect(
|
||||
format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
|
||||
)
|
||||
return sock, lock
|
||||
|
||||
def _register_kv_args(self):
|
||||
pass
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
|
||||
from sglang.srt.disaggregation.utils import PDRegistryRequest
|
||||
from sglang.srt.utils import maybe_wrap_ipv6_address
|
||||
|
||||
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
|
||||
1024 * 64
|
||||
@@ -271,7 +272,7 @@ async def handle_generate_request(request_data: dict):
|
||||
|
||||
# Parse and transform prefill_server for bootstrap data
|
||||
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||
hostname = parsed_url.hostname
|
||||
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
||||
modified_request = request_data.copy()
|
||||
|
||||
batch_size = _get_request_batch_size(modified_request)
|
||||
@@ -309,7 +310,7 @@ async def _forward_to_backend(request_data: dict, endpoint_name: str):
|
||||
|
||||
# Parse and transform prefill_server for bootstrap data
|
||||
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||
hostname = parsed_url.hostname
|
||||
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
||||
modified_request = request_data.copy()
|
||||
modified_request.update(
|
||||
{
|
||||
|
||||
@@ -35,7 +35,15 @@ from sglang.srt.disaggregation.common.utils import (
|
||||
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import get_free_port, get_int_env_var, get_ip, get_local_ip_auto
|
||||
from sglang.srt.utils import (
|
||||
format_tcp_address,
|
||||
get_free_port,
|
||||
get_int_env_var,
|
||||
get_ip,
|
||||
get_local_ip_auto,
|
||||
is_valid_ipv6_address,
|
||||
maybe_wrap_ipv6_address,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -148,6 +156,9 @@ class MooncakeKVManager(BaseKVManager):
|
||||
self.request_status: Dict[int, KVPoll] = {}
|
||||
self.rank_port = None
|
||||
self.server_socket = zmq.Context().socket(zmq.PULL)
|
||||
if is_valid_ipv6_address(self.local_ip):
|
||||
self.server_socket.setsockopt(zmq.IPV6, 1)
|
||||
|
||||
self.register_buffer_to_engine()
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
|
||||
@@ -240,8 +251,10 @@ class MooncakeKVManager(BaseKVManager):
|
||||
self.engine.register(aux_data_ptr, aux_data_len)
|
||||
|
||||
@cache
|
||||
def _connect(self, endpoint: str):
|
||||
def _connect(self, endpoint: str, is_ipv6: bool = False):
|
||||
socket = zmq.Context().socket(zmq.PUSH)
|
||||
if is_ipv6:
|
||||
socket.setsockopt(zmq.IPV6, 1)
|
||||
socket.connect(endpoint)
|
||||
return socket
|
||||
|
||||
@@ -471,9 +484,9 @@ class MooncakeKVManager(BaseKVManager):
|
||||
def sync_status_to_decode_endpoint(
|
||||
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
|
||||
):
|
||||
if ":" in remote:
|
||||
remote = remote.split(":")[0]
|
||||
self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
|
||||
self._connect(
|
||||
format_tcp_address(remote, dst_port), is_ipv6=is_valid_ipv6_address(remote)
|
||||
).send_multipart(
|
||||
[
|
||||
str(room).encode("ascii"),
|
||||
str(status).encode("ascii"),
|
||||
@@ -616,9 +629,12 @@ class MooncakeKVManager(BaseKVManager):
|
||||
f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
|
||||
)
|
||||
|
||||
def _bind_server_socket(self):
|
||||
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
||||
|
||||
def start_prefill_thread(self):
|
||||
self.rank_port = get_free_port()
|
||||
self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
|
||||
self._bind_server_socket()
|
||||
|
||||
def bootstrap_thread():
|
||||
"""This thread recvs pre-alloc notification from the decode engine"""
|
||||
@@ -657,7 +673,7 @@ class MooncakeKVManager(BaseKVManager):
|
||||
|
||||
def start_decode_thread(self):
|
||||
self.rank_port = get_free_port()
|
||||
self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
|
||||
self._bind_server_socket()
|
||||
|
||||
def decode_thread():
|
||||
while True:
|
||||
@@ -776,7 +792,7 @@ class MooncakeKVManager(BaseKVManager):
|
||||
# requests with the same dst_sessions will be added into the same
|
||||
# queue, which enables early abort with failed sessions.
|
||||
dst_infos = self.transfer_infos[bootstrap_room].keys()
|
||||
session_port_sum = sum(int(session.split(":")[1]) for session in dst_infos)
|
||||
session_port_sum = sum(int(session.rsplit(":", 1)[1]) for session in dst_infos)
|
||||
shard_idx = session_port_sum % len(self.transfer_queues)
|
||||
|
||||
self.transfer_queues[shard_idx].put(
|
||||
@@ -814,11 +830,18 @@ class MooncakeKVManager(BaseKVManager):
|
||||
def _register_to_bootstrap(self):
|
||||
"""Register KVSender to bootstrap server via HTTP POST."""
|
||||
if self.dist_init_addr:
|
||||
ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
|
||||
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
|
||||
if self.dist_init_addr.endswith("]"):
|
||||
host = self.dist_init_addr
|
||||
else:
|
||||
host, _ = self.dist_init_addr.rsplit(":", 1)
|
||||
else:
|
||||
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
|
||||
else:
|
||||
ip_address = get_ip()
|
||||
host = get_ip()
|
||||
host = maybe_wrap_ipv6_address(host)
|
||||
|
||||
bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
|
||||
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
|
||||
url = f"http://{bootstrap_server_url}/route"
|
||||
payload = {
|
||||
"role": "Prefill",
|
||||
@@ -1163,9 +1186,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
|
||||
def _register_kv_args(self):
|
||||
for bootstrap_info in self.bootstrap_infos:
|
||||
self.prefill_server_url = (
|
||||
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
||||
)
|
||||
packed_kv_data_ptrs = b"".join(
|
||||
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
||||
)
|
||||
@@ -1179,7 +1199,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
dst_tp_size = str(tp_size).encode("ascii")
|
||||
dst_kv_item_len = str(kv_item_len).encode("ascii")
|
||||
|
||||
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
||||
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
||||
with lock:
|
||||
sock.send_multipart(
|
||||
[
|
||||
@@ -1196,23 +1216,32 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _connect(cls, endpoint: str):
|
||||
def _connect(cls, endpoint: str, is_ipv6: bool = False):
|
||||
with cls._global_lock:
|
||||
if endpoint not in cls._socket_cache:
|
||||
sock = cls._ctx.socket(zmq.PUSH)
|
||||
if is_ipv6:
|
||||
sock.setsockopt(zmq.IPV6, 1)
|
||||
sock.connect(endpoint)
|
||||
cls._socket_cache[endpoint] = sock
|
||||
cls._socket_locks[endpoint] = threading.Lock()
|
||||
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
||||
|
||||
@classmethod
|
||||
def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
|
||||
ip_address = bootstrap_info["rank_ip"]
|
||||
port = bootstrap_info["rank_port"]
|
||||
is_ipv6_address = is_valid_ipv6_address(ip_address)
|
||||
sock, lock = cls._connect(
|
||||
format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
|
||||
)
|
||||
return sock, lock
|
||||
|
||||
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
||||
for bootstrap_info in self.bootstrap_infos:
|
||||
self.prefill_server_url = (
|
||||
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
||||
)
|
||||
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
||||
is_dummy = bootstrap_info["is_dummy"]
|
||||
|
||||
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
||||
with lock:
|
||||
sock.send_multipart(
|
||||
[
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var, get_free_port
|
||||
from sglang.srt.utils import get_bool_env_var, get_free_port, maybe_wrap_ipv6_address
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -27,7 +27,9 @@ class MooncakeTransferEngine:
|
||||
hostname=self.hostname,
|
||||
device_name=self.ib_device,
|
||||
)
|
||||
self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
|
||||
self.session_id = (
|
||||
f"{maybe_wrap_ipv6_address(self.hostname)}:{self.engine.get_rpc_port()}"
|
||||
)
|
||||
|
||||
def register(self, ptr, length):
|
||||
try:
|
||||
|
||||
@@ -27,7 +27,11 @@ from sglang.srt.disaggregation.common.conn import (
|
||||
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import get_local_ip_by_remote
|
||||
from sglang.srt.utils import (
|
||||
format_tcp_address,
|
||||
get_local_ip_auto,
|
||||
is_valid_ipv6_address,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -124,7 +128,10 @@ class NixlKVManager(CommonKVManager):
|
||||
"to run SGLang with NixlTransferEngine."
|
||||
) from e
|
||||
self.agent = nixl_agent(str(uuid.uuid4()))
|
||||
self.local_ip = get_local_ip_auto()
|
||||
self.server_socket = zmq.Context().socket(zmq.PULL)
|
||||
if is_valid_ipv6_address(self.local_ip):
|
||||
self.server_socket.setsockopt(zmq.IPV6, 1)
|
||||
self.register_buffer_to_engine()
|
||||
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
@@ -337,8 +344,11 @@ class NixlKVManager(CommonKVManager):
|
||||
return False
|
||||
return self.transfer_statuses[room].is_done()
|
||||
|
||||
def _bind_server_socket(self):
|
||||
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
||||
|
||||
def _start_bootstrap_thread(self):
|
||||
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
|
||||
self._bind_server_socket()
|
||||
|
||||
def bootstrap_thread():
|
||||
"""This thread recvs transfer info from the decode engine"""
|
||||
@@ -452,23 +462,20 @@ class NixlKVReceiver(CommonKVReceiver):
|
||||
|
||||
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
||||
for bootstrap_info in self.bootstrap_infos:
|
||||
self.prefill_server_url = (
|
||||
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
||||
)
|
||||
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
||||
is_dummy = bootstrap_info["is_dummy"]
|
||||
logger.debug(
|
||||
f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room} {is_dummy=}"
|
||||
f"Sending to prefill server with bootstrap room {self.bootstrap_room} {is_dummy=}"
|
||||
)
|
||||
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
||||
with lock:
|
||||
sock.send_multipart(
|
||||
[
|
||||
GUARD,
|
||||
str(self.bootstrap_room).encode("ascii"),
|
||||
get_local_ip_by_remote().encode("ascii"),
|
||||
self.kv_mgr.local_ip.encode("ascii"),
|
||||
str(self.kv_mgr.rank_port).encode("ascii"),
|
||||
self.kv_mgr.agent.name.encode("ascii"),
|
||||
kv_indices.tobytes() if not is_dummy else b"",
|
||||
@@ -494,9 +501,7 @@ class NixlKVReceiver(CommonKVReceiver):
|
||||
|
||||
def _register_kv_args(self):
|
||||
for bootstrap_info in self.bootstrap_infos:
|
||||
self.prefill_server_url = (
|
||||
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
||||
)
|
||||
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
||||
packed_kv_data_ptrs = b"".join(
|
||||
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
||||
)
|
||||
@@ -504,13 +509,12 @@ class NixlKVReceiver(CommonKVReceiver):
|
||||
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
|
||||
)
|
||||
|
||||
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
||||
with lock:
|
||||
sock.send_multipart(
|
||||
[
|
||||
GUARD,
|
||||
"None".encode("ascii"),
|
||||
get_local_ip_by_remote().encode("ascii"),
|
||||
self.kv_mgr.local_ip.encode("ascii"),
|
||||
str(self.kv_mgr.rank_port).encode("ascii"),
|
||||
self.kv_mgr.agent.name.encode("ascii"),
|
||||
self.kv_mgr.agent.get_agent_metadata(),
|
||||
|
||||
@@ -16,7 +16,12 @@ from torch.distributed import ProcessGroup
|
||||
from zmq import IPV6 # type: ignore
|
||||
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
||||
|
||||
from sglang.srt.utils import get_ip, get_open_port, is_valid_ipv6_address
|
||||
from sglang.srt.utils import (
|
||||
format_tcp_address,
|
||||
get_ip,
|
||||
get_open_port,
|
||||
is_valid_ipv6_address,
|
||||
)
|
||||
|
||||
# SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60
|
||||
SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
|
||||
@@ -225,9 +230,9 @@ class MessageQueue:
|
||||
remote_subscribe_port = get_open_port()
|
||||
if is_valid_ipv6_address(connect_ip):
|
||||
self.remote_socket.setsockopt(IPV6, 1)
|
||||
connect_ip = f"[{connect_ip}]"
|
||||
socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
|
||||
self.remote_socket.bind(socket_addr)
|
||||
self.remote_socket.bind(
|
||||
format_tcp_address(connect_ip, remote_subscribe_port)
|
||||
)
|
||||
|
||||
else:
|
||||
remote_subscribe_port = None
|
||||
@@ -288,7 +293,9 @@ class MessageQueue:
|
||||
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
|
||||
if is_valid_ipv6_address(handle.connect_ip):
|
||||
self.remote_socket.setsockopt(IPV6, 1)
|
||||
socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}"
|
||||
socket_addr = format_tcp_address(
|
||||
handle.connect_ip, handle.remote_subscribe_port
|
||||
)
|
||||
logger.debug("Connecting to %s", socket_addr)
|
||||
self.remote_socket.connect(socket_addr)
|
||||
|
||||
|
||||
@@ -2065,6 +2065,16 @@ def is_valid_ipv6_address(address: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def maybe_wrap_ipv6_address(address: str) -> str:
|
||||
if is_valid_ipv6_address(address):
|
||||
return f"[{address}]"
|
||||
return address
|
||||
|
||||
|
||||
def format_tcp_address(ip: str, port: int) -> str:
|
||||
return f"tcp://{maybe_wrap_ipv6_address(ip)}:{port}"
|
||||
|
||||
|
||||
def configure_ipv6(dist_init_addr):
|
||||
addr = dist_init_addr
|
||||
end = addr.find("]")
|
||||
|
||||
Reference in New Issue
Block a user