[P/D] Support ipv6 in P/D scenario (#7858)
Co-authored-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
committed by
GitHub
parent
9045cc1eb8
commit
1b9cea5ade
@@ -23,7 +23,14 @@ from sglang.srt.disaggregation.base.conn import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
|
from sglang.srt.utils import (
|
||||||
|
format_tcp_address,
|
||||||
|
get_free_port,
|
||||||
|
get_ip,
|
||||||
|
get_local_ip_by_remote,
|
||||||
|
is_valid_ipv6_address,
|
||||||
|
maybe_wrap_ipv6_address,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -65,11 +72,18 @@ class CommonKVManager(BaseKVManager):
|
|||||||
def _register_to_bootstrap(self):
|
def _register_to_bootstrap(self):
|
||||||
"""Register KVSender to bootstrap server via HTTP POST."""
|
"""Register KVSender to bootstrap server via HTTP POST."""
|
||||||
if self.dist_init_addr:
|
if self.dist_init_addr:
|
||||||
ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
|
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
|
||||||
|
if self.dist_init_addr.endswith("]"):
|
||||||
|
host = self.dist_init_addr
|
||||||
|
else:
|
||||||
|
host, _ = self.dist_init_addr.rsplit(":", 1)
|
||||||
|
else:
|
||||||
|
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
|
||||||
else:
|
else:
|
||||||
ip_address = get_ip()
|
host = get_ip()
|
||||||
|
host = maybe_wrap_ipv6_address(host)
|
||||||
|
|
||||||
bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
|
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
|
||||||
url = f"http://{bootstrap_server_url}/route"
|
url = f"http://{bootstrap_server_url}/route"
|
||||||
payload = {
|
payload = {
|
||||||
"role": "Prefill",
|
"role": "Prefill",
|
||||||
@@ -92,8 +106,10 @@ class CommonKVManager(BaseKVManager):
|
|||||||
logger.error(f"Prefill Failed to register to bootstrap server: {e}")
|
logger.error(f"Prefill Failed to register to bootstrap server: {e}")
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def _connect(self, endpoint: str):
|
def _connect(self, endpoint: str, is_ipv6: bool = False):
|
||||||
socket = zmq.Context().socket(zmq.PUSH)
|
socket = zmq.Context().socket(zmq.PUSH)
|
||||||
|
if is_ipv6:
|
||||||
|
socket.setsockopt(zmq.IPV6, 1)
|
||||||
socket.connect(endpoint)
|
socket.connect(endpoint)
|
||||||
return socket
|
return socket
|
||||||
|
|
||||||
@@ -263,15 +279,27 @@ class CommonKVReceiver(BaseKVReceiver):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _connect(cls, endpoint: str):
|
def _connect(cls, endpoint: str, is_ipv6: bool = False):
|
||||||
with cls._global_lock:
|
with cls._global_lock:
|
||||||
if endpoint not in cls._socket_cache:
|
if endpoint not in cls._socket_cache:
|
||||||
sock = cls._ctx.socket(zmq.PUSH)
|
sock = cls._ctx.socket(zmq.PUSH)
|
||||||
|
if is_ipv6:
|
||||||
|
sock.setsockopt(zmq.IPV6, 1)
|
||||||
sock.connect(endpoint)
|
sock.connect(endpoint)
|
||||||
cls._socket_cache[endpoint] = sock
|
cls._socket_cache[endpoint] = sock
|
||||||
cls._socket_locks[endpoint] = threading.Lock()
|
cls._socket_locks[endpoint] = threading.Lock()
|
||||||
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
|
||||||
|
ip_address = bootstrap_info["rank_ip"]
|
||||||
|
port = bootstrap_info["rank_port"]
|
||||||
|
is_ipv6_address = is_valid_ipv6_address(ip_address)
|
||||||
|
sock, lock = cls._connect(
|
||||||
|
format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
|
||||||
|
)
|
||||||
|
return sock, lock
|
||||||
|
|
||||||
def _register_kv_args(self):
|
def _register_kv_args(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from fastapi import FastAPI, HTTPException
|
|||||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||||
|
|
||||||
from sglang.srt.disaggregation.utils import PDRegistryRequest
|
from sglang.srt.disaggregation.utils import PDRegistryRequest
|
||||||
|
from sglang.srt.utils import maybe_wrap_ipv6_address
|
||||||
|
|
||||||
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
|
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
|
||||||
1024 * 64
|
1024 * 64
|
||||||
@@ -271,7 +272,7 @@ async def handle_generate_request(request_data: dict):
|
|||||||
|
|
||||||
# Parse and transform prefill_server for bootstrap data
|
# Parse and transform prefill_server for bootstrap data
|
||||||
parsed_url = urllib.parse.urlparse(prefill_server)
|
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||||
hostname = parsed_url.hostname
|
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
||||||
modified_request = request_data.copy()
|
modified_request = request_data.copy()
|
||||||
|
|
||||||
batch_size = _get_request_batch_size(modified_request)
|
batch_size = _get_request_batch_size(modified_request)
|
||||||
@@ -309,7 +310,7 @@ async def _forward_to_backend(request_data: dict, endpoint_name: str):
|
|||||||
|
|
||||||
# Parse and transform prefill_server for bootstrap data
|
# Parse and transform prefill_server for bootstrap data
|
||||||
parsed_url = urllib.parse.urlparse(prefill_server)
|
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||||
hostname = parsed_url.hostname
|
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
||||||
modified_request = request_data.copy()
|
modified_request = request_data.copy()
|
||||||
modified_request.update(
|
modified_request.update(
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -35,7 +35,15 @@ from sglang.srt.disaggregation.common.utils import (
|
|||||||
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
||||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import get_free_port, get_int_env_var, get_ip, get_local_ip_auto
|
from sglang.srt.utils import (
|
||||||
|
format_tcp_address,
|
||||||
|
get_free_port,
|
||||||
|
get_int_env_var,
|
||||||
|
get_ip,
|
||||||
|
get_local_ip_auto,
|
||||||
|
is_valid_ipv6_address,
|
||||||
|
maybe_wrap_ipv6_address,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -148,6 +156,9 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
self.request_status: Dict[int, KVPoll] = {}
|
self.request_status: Dict[int, KVPoll] = {}
|
||||||
self.rank_port = None
|
self.rank_port = None
|
||||||
self.server_socket = zmq.Context().socket(zmq.PULL)
|
self.server_socket = zmq.Context().socket(zmq.PULL)
|
||||||
|
if is_valid_ipv6_address(self.local_ip):
|
||||||
|
self.server_socket.setsockopt(zmq.IPV6, 1)
|
||||||
|
|
||||||
self.register_buffer_to_engine()
|
self.register_buffer_to_engine()
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
|
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
|
||||||
@@ -240,8 +251,10 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
self.engine.register(aux_data_ptr, aux_data_len)
|
self.engine.register(aux_data_ptr, aux_data_len)
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def _connect(self, endpoint: str):
|
def _connect(self, endpoint: str, is_ipv6: bool = False):
|
||||||
socket = zmq.Context().socket(zmq.PUSH)
|
socket = zmq.Context().socket(zmq.PUSH)
|
||||||
|
if is_ipv6:
|
||||||
|
socket.setsockopt(zmq.IPV6, 1)
|
||||||
socket.connect(endpoint)
|
socket.connect(endpoint)
|
||||||
return socket
|
return socket
|
||||||
|
|
||||||
@@ -471,9 +484,9 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
def sync_status_to_decode_endpoint(
|
def sync_status_to_decode_endpoint(
|
||||||
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
|
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
|
||||||
):
|
):
|
||||||
if ":" in remote:
|
self._connect(
|
||||||
remote = remote.split(":")[0]
|
format_tcp_address(remote, dst_port), is_ipv6=is_valid_ipv6_address(remote)
|
||||||
self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
|
).send_multipart(
|
||||||
[
|
[
|
||||||
str(room).encode("ascii"),
|
str(room).encode("ascii"),
|
||||||
str(status).encode("ascii"),
|
str(status).encode("ascii"),
|
||||||
@@ -616,9 +629,12 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
|
f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _bind_server_socket(self):
|
||||||
|
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
||||||
|
|
||||||
def start_prefill_thread(self):
|
def start_prefill_thread(self):
|
||||||
self.rank_port = get_free_port()
|
self.rank_port = get_free_port()
|
||||||
self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
|
self._bind_server_socket()
|
||||||
|
|
||||||
def bootstrap_thread():
|
def bootstrap_thread():
|
||||||
"""This thread recvs pre-alloc notification from the decode engine"""
|
"""This thread recvs pre-alloc notification from the decode engine"""
|
||||||
@@ -657,7 +673,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
|
|
||||||
def start_decode_thread(self):
|
def start_decode_thread(self):
|
||||||
self.rank_port = get_free_port()
|
self.rank_port = get_free_port()
|
||||||
self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
|
self._bind_server_socket()
|
||||||
|
|
||||||
def decode_thread():
|
def decode_thread():
|
||||||
while True:
|
while True:
|
||||||
@@ -776,7 +792,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
# requests with the same dst_sessions will be added into the same
|
# requests with the same dst_sessions will be added into the same
|
||||||
# queue, which enables early abort with failed sessions.
|
# queue, which enables early abort with failed sessions.
|
||||||
dst_infos = self.transfer_infos[bootstrap_room].keys()
|
dst_infos = self.transfer_infos[bootstrap_room].keys()
|
||||||
session_port_sum = sum(int(session.split(":")[1]) for session in dst_infos)
|
session_port_sum = sum(int(session.rsplit(":", 1)[1]) for session in dst_infos)
|
||||||
shard_idx = session_port_sum % len(self.transfer_queues)
|
shard_idx = session_port_sum % len(self.transfer_queues)
|
||||||
|
|
||||||
self.transfer_queues[shard_idx].put(
|
self.transfer_queues[shard_idx].put(
|
||||||
@@ -814,11 +830,18 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
def _register_to_bootstrap(self):
|
def _register_to_bootstrap(self):
|
||||||
"""Register KVSender to bootstrap server via HTTP POST."""
|
"""Register KVSender to bootstrap server via HTTP POST."""
|
||||||
if self.dist_init_addr:
|
if self.dist_init_addr:
|
||||||
ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
|
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
|
||||||
|
if self.dist_init_addr.endswith("]"):
|
||||||
|
host = self.dist_init_addr
|
||||||
|
else:
|
||||||
|
host, _ = self.dist_init_addr.rsplit(":", 1)
|
||||||
|
else:
|
||||||
|
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
|
||||||
else:
|
else:
|
||||||
ip_address = get_ip()
|
host = get_ip()
|
||||||
|
host = maybe_wrap_ipv6_address(host)
|
||||||
|
|
||||||
bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
|
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
|
||||||
url = f"http://{bootstrap_server_url}/route"
|
url = f"http://{bootstrap_server_url}/route"
|
||||||
payload = {
|
payload = {
|
||||||
"role": "Prefill",
|
"role": "Prefill",
|
||||||
@@ -1163,9 +1186,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
|
|
||||||
def _register_kv_args(self):
|
def _register_kv_args(self):
|
||||||
for bootstrap_info in self.bootstrap_infos:
|
for bootstrap_info in self.bootstrap_infos:
|
||||||
self.prefill_server_url = (
|
|
||||||
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
|
||||||
)
|
|
||||||
packed_kv_data_ptrs = b"".join(
|
packed_kv_data_ptrs = b"".join(
|
||||||
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
||||||
)
|
)
|
||||||
@@ -1179,7 +1199,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
dst_tp_size = str(tp_size).encode("ascii")
|
dst_tp_size = str(tp_size).encode("ascii")
|
||||||
dst_kv_item_len = str(kv_item_len).encode("ascii")
|
dst_kv_item_len = str(kv_item_len).encode("ascii")
|
||||||
|
|
||||||
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
||||||
with lock:
|
with lock:
|
||||||
sock.send_multipart(
|
sock.send_multipart(
|
||||||
[
|
[
|
||||||
@@ -1196,23 +1216,32 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _connect(cls, endpoint: str):
|
def _connect(cls, endpoint: str, is_ipv6: bool = False):
|
||||||
with cls._global_lock:
|
with cls._global_lock:
|
||||||
if endpoint not in cls._socket_cache:
|
if endpoint not in cls._socket_cache:
|
||||||
sock = cls._ctx.socket(zmq.PUSH)
|
sock = cls._ctx.socket(zmq.PUSH)
|
||||||
|
if is_ipv6:
|
||||||
|
sock.setsockopt(zmq.IPV6, 1)
|
||||||
sock.connect(endpoint)
|
sock.connect(endpoint)
|
||||||
cls._socket_cache[endpoint] = sock
|
cls._socket_cache[endpoint] = sock
|
||||||
cls._socket_locks[endpoint] = threading.Lock()
|
cls._socket_locks[endpoint] = threading.Lock()
|
||||||
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
|
||||||
|
ip_address = bootstrap_info["rank_ip"]
|
||||||
|
port = bootstrap_info["rank_port"]
|
||||||
|
is_ipv6_address = is_valid_ipv6_address(ip_address)
|
||||||
|
sock, lock = cls._connect(
|
||||||
|
format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
|
||||||
|
)
|
||||||
|
return sock, lock
|
||||||
|
|
||||||
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
||||||
for bootstrap_info in self.bootstrap_infos:
|
for bootstrap_info in self.bootstrap_infos:
|
||||||
self.prefill_server_url = (
|
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
||||||
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
|
||||||
)
|
|
||||||
is_dummy = bootstrap_info["is_dummy"]
|
is_dummy = bootstrap_info["is_dummy"]
|
||||||
|
|
||||||
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
|
||||||
with lock:
|
with lock:
|
||||||
sock.send_multipart(
|
sock.send_multipart(
|
||||||
[
|
[
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from sglang.srt.utils import get_bool_env_var, get_free_port
|
from sglang.srt.utils import get_bool_env_var, get_free_port, maybe_wrap_ipv6_address
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -27,7 +27,9 @@ class MooncakeTransferEngine:
|
|||||||
hostname=self.hostname,
|
hostname=self.hostname,
|
||||||
device_name=self.ib_device,
|
device_name=self.ib_device,
|
||||||
)
|
)
|
||||||
self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
|
self.session_id = (
|
||||||
|
f"{maybe_wrap_ipv6_address(self.hostname)}:{self.engine.get_rpc_port()}"
|
||||||
|
)
|
||||||
|
|
||||||
def register(self, ptr, length):
|
def register(self, ptr, length):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -27,7 +27,11 @@ from sglang.srt.disaggregation.common.conn import (
|
|||||||
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
|
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
|
||||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import get_local_ip_by_remote
|
from sglang.srt.utils import (
|
||||||
|
format_tcp_address,
|
||||||
|
get_local_ip_auto,
|
||||||
|
is_valid_ipv6_address,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -124,7 +128,10 @@ class NixlKVManager(CommonKVManager):
|
|||||||
"to run SGLang with NixlTransferEngine."
|
"to run SGLang with NixlTransferEngine."
|
||||||
) from e
|
) from e
|
||||||
self.agent = nixl_agent(str(uuid.uuid4()))
|
self.agent = nixl_agent(str(uuid.uuid4()))
|
||||||
|
self.local_ip = get_local_ip_auto()
|
||||||
self.server_socket = zmq.Context().socket(zmq.PULL)
|
self.server_socket = zmq.Context().socket(zmq.PULL)
|
||||||
|
if is_valid_ipv6_address(self.local_ip):
|
||||||
|
self.server_socket.setsockopt(zmq.IPV6, 1)
|
||||||
self.register_buffer_to_engine()
|
self.register_buffer_to_engine()
|
||||||
|
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
@@ -337,8 +344,11 @@ class NixlKVManager(CommonKVManager):
|
|||||||
return False
|
return False
|
||||||
return self.transfer_statuses[room].is_done()
|
return self.transfer_statuses[room].is_done()
|
||||||
|
|
||||||
|
def _bind_server_socket(self):
|
||||||
|
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
||||||
|
|
||||||
def _start_bootstrap_thread(self):
|
def _start_bootstrap_thread(self):
|
||||||
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
|
self._bind_server_socket()
|
||||||
|
|
||||||
def bootstrap_thread():
|
def bootstrap_thread():
|
||||||
"""This thread recvs transfer info from the decode engine"""
|
"""This thread recvs transfer info from the decode engine"""
|
||||||
@@ -452,23 +462,20 @@ class NixlKVReceiver(CommonKVReceiver):
|
|||||||
|
|
||||||
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
||||||
for bootstrap_info in self.bootstrap_infos:
|
for bootstrap_info in self.bootstrap_infos:
|
||||||
self.prefill_server_url = (
|
|
||||||
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
|
||||||
)
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
||||||
)
|
)
|
||||||
|
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
||||||
is_dummy = bootstrap_info["is_dummy"]
|
is_dummy = bootstrap_info["is_dummy"]
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room} {is_dummy=}"
|
f"Sending to prefill server with bootstrap room {self.bootstrap_room} {is_dummy=}"
|
||||||
)
|
)
|
||||||
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
|
||||||
with lock:
|
with lock:
|
||||||
sock.send_multipart(
|
sock.send_multipart(
|
||||||
[
|
[
|
||||||
GUARD,
|
GUARD,
|
||||||
str(self.bootstrap_room).encode("ascii"),
|
str(self.bootstrap_room).encode("ascii"),
|
||||||
get_local_ip_by_remote().encode("ascii"),
|
self.kv_mgr.local_ip.encode("ascii"),
|
||||||
str(self.kv_mgr.rank_port).encode("ascii"),
|
str(self.kv_mgr.rank_port).encode("ascii"),
|
||||||
self.kv_mgr.agent.name.encode("ascii"),
|
self.kv_mgr.agent.name.encode("ascii"),
|
||||||
kv_indices.tobytes() if not is_dummy else b"",
|
kv_indices.tobytes() if not is_dummy else b"",
|
||||||
@@ -494,9 +501,7 @@ class NixlKVReceiver(CommonKVReceiver):
|
|||||||
|
|
||||||
def _register_kv_args(self):
|
def _register_kv_args(self):
|
||||||
for bootstrap_info in self.bootstrap_infos:
|
for bootstrap_info in self.bootstrap_infos:
|
||||||
self.prefill_server_url = (
|
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
||||||
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
|
||||||
)
|
|
||||||
packed_kv_data_ptrs = b"".join(
|
packed_kv_data_ptrs = b"".join(
|
||||||
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
||||||
)
|
)
|
||||||
@@ -504,13 +509,12 @@ class NixlKVReceiver(CommonKVReceiver):
|
|||||||
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
|
||||||
)
|
)
|
||||||
|
|
||||||
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
|
||||||
with lock:
|
with lock:
|
||||||
sock.send_multipart(
|
sock.send_multipart(
|
||||||
[
|
[
|
||||||
GUARD,
|
GUARD,
|
||||||
"None".encode("ascii"),
|
"None".encode("ascii"),
|
||||||
get_local_ip_by_remote().encode("ascii"),
|
self.kv_mgr.local_ip.encode("ascii"),
|
||||||
str(self.kv_mgr.rank_port).encode("ascii"),
|
str(self.kv_mgr.rank_port).encode("ascii"),
|
||||||
self.kv_mgr.agent.name.encode("ascii"),
|
self.kv_mgr.agent.name.encode("ascii"),
|
||||||
self.kv_mgr.agent.get_agent_metadata(),
|
self.kv_mgr.agent.get_agent_metadata(),
|
||||||
|
|||||||
@@ -16,7 +16,12 @@ from torch.distributed import ProcessGroup
|
|||||||
from zmq import IPV6 # type: ignore
|
from zmq import IPV6 # type: ignore
|
||||||
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
||||||
|
|
||||||
from sglang.srt.utils import get_ip, get_open_port, is_valid_ipv6_address
|
from sglang.srt.utils import (
|
||||||
|
format_tcp_address,
|
||||||
|
get_ip,
|
||||||
|
get_open_port,
|
||||||
|
is_valid_ipv6_address,
|
||||||
|
)
|
||||||
|
|
||||||
# SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60
|
# SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60
|
||||||
SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
|
SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
|
||||||
@@ -225,9 +230,9 @@ class MessageQueue:
|
|||||||
remote_subscribe_port = get_open_port()
|
remote_subscribe_port = get_open_port()
|
||||||
if is_valid_ipv6_address(connect_ip):
|
if is_valid_ipv6_address(connect_ip):
|
||||||
self.remote_socket.setsockopt(IPV6, 1)
|
self.remote_socket.setsockopt(IPV6, 1)
|
||||||
connect_ip = f"[{connect_ip}]"
|
self.remote_socket.bind(
|
||||||
socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
|
format_tcp_address(connect_ip, remote_subscribe_port)
|
||||||
self.remote_socket.bind(socket_addr)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
remote_subscribe_port = None
|
remote_subscribe_port = None
|
||||||
@@ -288,7 +293,9 @@ class MessageQueue:
|
|||||||
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
|
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
|
||||||
if is_valid_ipv6_address(handle.connect_ip):
|
if is_valid_ipv6_address(handle.connect_ip):
|
||||||
self.remote_socket.setsockopt(IPV6, 1)
|
self.remote_socket.setsockopt(IPV6, 1)
|
||||||
socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}"
|
socket_addr = format_tcp_address(
|
||||||
|
handle.connect_ip, handle.remote_subscribe_port
|
||||||
|
)
|
||||||
logger.debug("Connecting to %s", socket_addr)
|
logger.debug("Connecting to %s", socket_addr)
|
||||||
self.remote_socket.connect(socket_addr)
|
self.remote_socket.connect(socket_addr)
|
||||||
|
|
||||||
|
|||||||
@@ -2065,6 +2065,16 @@ def is_valid_ipv6_address(address: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_wrap_ipv6_address(address: str) -> str:
|
||||||
|
if is_valid_ipv6_address(address):
|
||||||
|
return f"[{address}]"
|
||||||
|
return address
|
||||||
|
|
||||||
|
|
||||||
|
def format_tcp_address(ip: str, port: int) -> str:
|
||||||
|
return f"tcp://{maybe_wrap_ipv6_address(ip)}:{port}"
|
||||||
|
|
||||||
|
|
||||||
def configure_ipv6(dist_init_addr):
|
def configure_ipv6(dist_init_addr):
|
||||||
addr = dist_init_addr
|
addr = dist_init_addr
|
||||||
end = addr.find("]")
|
end = addr.find("]")
|
||||||
|
|||||||
Reference in New Issue
Block a user