Files
sglang/python/sglang/srt/disaggregation/common/conn.py
Stepan Kargaltsev 1b9cea5ade [P/D] Support ipv6 in P/D scenario (#7858)
Co-authored-by: Shangming Cai <caishangming@linux.alibaba.com>
2025-07-25 08:53:30 -07:00

436 lines
17 KiB
Python

from __future__ import annotations
import asyncio
import logging
import socket
import threading
from functools import cache
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import numpy.typing as npt
import requests
import zmq
from aiohttp import web
from sglang.srt.disaggregation.base.conn import (
BaseKVBootstrapServer,
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
format_tcp_address,
get_free_port,
get_ip,
get_local_ip_by_remote,
is_valid_ipv6_address,
maybe_wrap_ipv6_address,
)
logger = logging.getLogger(__name__)
class CommonKVManager(BaseKVManager):
def __init__(
self,
args: KVArgs,
disaggregation_mode: DisaggregationMode,
server_args: ServerArgs,
is_mla_backend: Optional[bool] = False,
):
self.kv_args = args
self.is_mla_backend = is_mla_backend
self.disaggregation_mode = disaggregation_mode
# for p/d multi node infer
self.bootstrap_port = server_args.disaggregation_bootstrap_port
self.dist_init_addr = server_args.dist_init_addr
self.tp_size = server_args.tp_size
self.dp_size = server_args.dp_size
self.enable_dp_attention = server_args.enable_dp_attention
if not server_args.enable_dp_attention and server_args.dp_size != 1:
raise ValueError(
"If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
)
self.rank_port = get_free_port()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self._register_to_bootstrap()
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
self.prefill_tp_size_table: Dict[str, int] = {}
self.prefill_dp_size_table: Dict[str, int] = {}
else:
raise ValueError(
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
)
def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr:
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
if self.dist_init_addr.endswith("]"):
host = self.dist_init_addr
else:
host, _ = self.dist_init_addr.rsplit(":", 1)
else:
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
else:
host = get_ip()
host = maybe_wrap_ipv6_address(host)
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
url = f"http://{bootstrap_server_url}/route"
payload = {
"role": "Prefill",
"tp_size": self.tp_size,
"dp_size": self.dp_size,
"rank_ip": get_local_ip_by_remote(),
"rank_port": self.rank_port,
"engine_rank": self.kv_args.engine_rank,
}
try:
response = requests.put(url, json=payload)
if response.status_code == 200:
logger.debug("Prefill successfully registered to bootstrap server.")
else:
logger.error(
f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
)
except Exception as e:
logger.error(f"Prefill Failed to register to bootstrap server: {e}")
@cache
def _connect(self, endpoint: str, is_ipv6: bool = False):
socket = zmq.Context().socket(zmq.PUSH)
if is_ipv6:
socket.setsockopt(zmq.IPV6, 1)
socket.connect(endpoint)
return socket
class CommonKVReceiver(BaseKVReceiver):
_ctx = zmq.Context()
_socket_cache = {}
_socket_locks = {}
_global_lock = threading.Lock()
def __init__(
self,
mgr: BaseKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
):
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
self.kv_mgr = mgr
self.data_parallel_rank = data_parallel_rank
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
self.prefill_tp_size, self.prefill_dp_size = (
self._get_prefill_dp_size_from_server()
)
if self.prefill_tp_size is None or self.prefill_dp_size is None:
logger.error(
f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
)
else:
self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
self.prefill_tp_size
)
self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
self.prefill_dp_size
)
else:
self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
self.bootstrap_addr
]
self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
self.bootstrap_addr
]
# Currently, we don't allow prefill instance and decode instance to
# have different TP sizes per DP rank, except for models using MLA.
local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
self.target_tp_rank = (
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
)
self.required_dst_info_num = 1
self.target_tp_ranks = [self.target_tp_rank]
elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
assert (
self.kv_mgr.is_mla_backend
), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
self.target_tp_rank = (
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
self.required_dst_info_num = (
local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
)
self.target_tp_ranks = [self.target_tp_rank]
else:
assert (
self.kv_mgr.is_mla_backend
), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
# For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
self.target_tp_ranks = [
rank
for rank in range(
(self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
* (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
(self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
* (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
)
]
# For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
# multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
# or the KVPoll will never be set correctly
self.target_tp_rank = self.target_tp_ranks[0]
self.required_dst_info_num = 1
if self.data_parallel_rank is not None:
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
self.target_dp_group = self.data_parallel_rank
else:
self.target_dp_group = bootstrap_room % self.prefill_dp_size
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
bootstrap_key = (
f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
)
if bootstrap_key not in self.kv_mgr.connection_pool:
bootstrap_infos = []
for target_tp_rank in self.target_tp_ranks:
bootstrap_info = self._get_bootstrap_info_from_server(
target_tp_rank,
self.target_dp_group,
)
if bootstrap_info is not None:
# NOTE: only support MLA for now: select one prefill rank as real rank
bootstrap_info["is_dummy"] = not bool(
target_tp_rank == self.target_tp_rank
or self.target_tp_rank is None
)
bootstrap_infos.append(bootstrap_info)
else:
logger.error(
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
)
self.bootstrap_infos = bootstrap_infos
if len(self.bootstrap_infos) == 0:
logger.error(
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
)
else:
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
self._register_kv_args()
else:
self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
assert len(self.bootstrap_infos) > 0
def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
"""Fetch the bootstrap info from the bootstrap server."""
try:
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
response = requests.get(url)
if response.status_code == 200:
bootstrap_info = response.json()
return bootstrap_info
else:
logger.error(
f"Failed to get prefill server info: {response.status_code}, {response.text}"
)
return None
except Exception as e:
logger.error(f"Error fetching prefill info from bootstrap: {e}")
return None
def _get_prefill_dp_size_from_server(self) -> int:
"""Fetch the prefill parallel info from the bootstrap server."""
try:
url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
response = requests.get(url)
if response.status_code == 200:
prefill_parallel_info = response.json()
return int(prefill_parallel_info["prefill_tp_size"]), int(
prefill_parallel_info["prefill_dp_size"]
)
else:
logger.error(
f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
)
return None
except Exception as e:
logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
return None
@classmethod
def _connect(cls, endpoint: str, is_ipv6: bool = False):
with cls._global_lock:
if endpoint not in cls._socket_cache:
sock = cls._ctx.socket(zmq.PUSH)
if is_ipv6:
sock.setsockopt(zmq.IPV6, 1)
sock.connect(endpoint)
cls._socket_cache[endpoint] = sock
cls._socket_locks[endpoint] = threading.Lock()
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
@classmethod
def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
ip_address = bootstrap_info["rank_ip"]
port = bootstrap_info["rank_port"]
is_ipv6_address = is_valid_ipv6_address(ip_address)
sock, lock = cls._connect(
format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
)
return sock, lock
def _register_kv_args(self):
pass
def failure_exception(self):
raise Exception("Fake KVReceiver Exception")
class CommonKVBootstrapServer(BaseKVBootstrapServer):
def __init__(self, port: int):
self.port = port
self.app = web.Application()
self.store = dict()
self.lock = asyncio.Lock()
self._setup_routes()
self.tp_size = None
self.dp_size = None
self.tp_size_per_dp_rank = None
self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
# Start bootstrap server
self.thread = threading.Thread(target=self._run_server, daemon=True)
self.run()
def run(self):
self.thread.start()
def _setup_routes(self):
self.app.router.add_route("*", "/route", self._handle_route)
async def _handle_route(self, request: web.Request):
method = request.method
if method == "PUT":
return await self._handle_route_put(request)
elif method == "GET":
return await self._handle_route_get(request)
else:
return web.Response(
text="Method not allowed", status=405, content_type="application/json"
)
async def _handle_route_put(self, request: web.Request):
data = await request.json()
role = data["role"]
tp_size = data["tp_size"]
dp_size = data["dp_size"]
rank_ip = data["rank_ip"]
rank_port = int(data["rank_port"])
engine_rank = int(data["engine_rank"])
if self.tp_size is None:
self.tp_size = tp_size
if self.dp_size is None:
self.dp_size = dp_size
tp_size_per_dp_rank = tp_size // dp_size
if self.tp_size_per_dp_rank == None:
self.tp_size_per_dp_rank = tp_size_per_dp_rank
# Add lock to make sure thread-safe
if role == "Prefill":
dp_group = engine_rank // tp_size_per_dp_rank
tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank
async with self.lock:
if dp_group not in self.prefill_port_table:
self.prefill_port_table[dp_group] = {}
self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
"rank_ip": rank_ip,
"rank_port": rank_port,
}
logger.debug(
f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
)
return web.Response(text="OK", status=200)
async def _handle_route_get(self, request: web.Request):
engine_rank = request.query.get("engine_rank")
target_dp_group = request.query.get("target_dp_group")
if not engine_rank or not target_dp_group:
return web.Response(text="Missing inputs for bootstrap server.", status=400)
# Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
if int(engine_rank) == -1 and int(target_dp_group) == -1:
prefill_parallel_info = {
"prefill_tp_size": self.tp_size,
"prefill_dp_size": self.dp_size,
}
return web.json_response(prefill_parallel_info, status=200)
# Find corresponding prefill info
async with self.lock:
bootstrap_info = self.prefill_port_table[int(target_dp_group)][
int(engine_rank)
]
if bootstrap_info is not None:
return web.json_response(bootstrap_info, status=200)
else:
return web.Response(text="Bootstrap info not Found", status=404)
def _run_server(self):
try:
# Event Loop
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._runner = web.AppRunner(self.app)
self._loop.run_until_complete(self._runner.setup())
site = web.TCPSite(self._runner, port=self.port)
self._loop.run_until_complete(site.start())
self._loop.run_forever()
except Exception as e:
logger.error(f"Server error: {str(e)}")
finally:
# Cleanup
self._loop.run_until_complete(self._runner.cleanup())
self._loop.close()
def close(self):
"""Shutdown"""
if self._loop is not None and self._loop.is_running():
self._loop.call_soon_threadsafe(self._loop.stop)
logger.info("Stopping server loop...")
if self.thread.is_alive():
self.thread.join(timeout=2)
logger.info("Server thread stopped")
def poll(self) -> KVPoll: ...