From e806f708c954020bda7d1cc98035a44fd6a4eb96 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Tue, 27 May 2025 12:47:38 -0700 Subject: [PATCH] [PD] Make bootstrap code common between NIXL and Mooncake (#6473) --- docs/backend/pd_disaggregation.md | 41 ++ .../srt/disaggregation/common/__init__.py | 1 + .../sglang/srt/disaggregation/common/conn.py | 401 +++++++++++++ .../srt/disaggregation/mooncake/conn.py | 22 +- python/sglang/srt/disaggregation/nixl/conn.py | 541 +++++------------- python/sglang/srt/disaggregation/utils.py | 19 +- 6 files changed, 596 insertions(+), 429 deletions(-) create mode 100644 python/sglang/srt/disaggregation/common/__init__.py create mode 100644 python/sglang/srt/disaggregation/common/conn.py diff --git a/docs/backend/pd_disaggregation.md b/docs/backend/pd_disaggregation.md index de95763a6..e77164372 100644 --- a/docs/backend/pd_disaggregation.md +++ b/docs/backend/pd_disaggregation.md @@ -47,3 +47,44 @@ $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --dis # decode 1 $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128 ``` + +## NIXL +### Requirements + +Install via pip. + +```bash +pip install nixl +``` + +Or build from source - may be required if you already have UCX installed. + +```bash +git clone https://github.com/ai-dynamo/nixl.git +cd nixl +pip install . 
--config-settings=setup-args="-Ducx_path=/path/to/ucx" ``` + + +### Usage + +### Llama Single Node + +```bash +$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend nixl +$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend nixl +$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000 +``` + +### DeepSeek Multi-Node + +```bash +# prefill 0 +$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-transfer-backend nixl --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8 +# prefill 1 +$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-transfer-backend nixl --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8 +# decode 0 +$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128 +# decode 1 +$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-transfer-backend nixl --disaggregation-mode decode --host 
${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128 +``` diff --git a/python/sglang/srt/disaggregation/common/__init__.py b/python/sglang/srt/disaggregation/common/__init__.py new file mode 100644 index 000000000..950db151f --- /dev/null +++ b/python/sglang/srt/disaggregation/common/__init__.py @@ -0,0 +1 @@ +from .conn import CommonKVBootstrapServer, CommonKVManager, CommonKVReceiver diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py new file mode 100644 index 000000000..4d66c18af --- /dev/null +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -0,0 +1,401 @@ +from __future__ import annotations + +import asyncio +import logging +import socket +import threading +from functools import cache +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import numpy.typing as npt +import requests +import zmq +from aiohttp import web + +from sglang.srt.disaggregation.base.conn import ( + BaseKVBootstrapServer, + BaseKVManager, + BaseKVReceiver, + BaseKVSender, + KVArgs, + KVPoll, +) +from sglang.srt.disaggregation.utils import DisaggregationMode +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote + +logger = logging.getLogger(__name__) + + +class CommonKVManager(BaseKVManager): + def __init__( + self, + args: KVArgs, + disaggregation_mode: DisaggregationMode, + server_args: ServerArgs, + is_mla_backend: Optional[bool] = False, + ): + self.kv_args = args + self.is_mla_backend = is_mla_backend + self.disaggregation_mode = disaggregation_mode + # for p/d multi node infer + self.bootstrap_port = server_args.disaggregation_bootstrap_port + self.dist_init_addr = server_args.dist_init_addr + self.tp_size = 
server_args.tp_size + self.dp_size = server_args.dp_size + self.enable_dp_attention = server_args.enable_dp_attention + if not server_args.enable_dp_attention and server_args.dp_size != 1: + raise ValueError( + "If dp_attention is not enabled, dp size must be 1 in disaggregation mode." + ) + + self.rank_port = get_free_port() + if self.disaggregation_mode == DisaggregationMode.PREFILL: + self._register_to_bootstrap() + elif self.disaggregation_mode == DisaggregationMode.DECODE: + self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {} + self.prefill_tp_size_table: Dict[str, int] = {} + self.prefill_dp_size_table: Dict[str, int] = {} + else: + raise ValueError( + f"Unsupported DisaggregationMode: {self.disaggregation_mode}" + ) + + def _register_to_bootstrap(self): + """Register KVSender to bootstrap server via HTTP POST.""" + if self.dist_init_addr: + ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0]) + else: + ip_address = get_ip() + + bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}" + url = f"http://{bootstrap_server_url}/route" + payload = { + "role": "Prefill", + "tp_size": self.tp_size, + "dp_size": self.dp_size, + "rank_ip": get_local_ip_by_remote(), + "rank_port": self.rank_port, + "engine_rank": self.kv_args.engine_rank, + } + + try: + response = requests.put(url, json=payload) + if response.status_code == 200: + logger.debug("Prefill successfully registered to bootstrap server.") + else: + logger.error( + f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}" + ) + except Exception as e: + logger.error(f"Prefill Failed to register to bootstrap server: {e}") + + @cache + def _connect(self, endpoint: str): + socket = zmq.Context().socket(zmq.PUSH) + socket.connect(endpoint) + return socket + + +class CommonKVReceiver(BaseKVReceiver): + _ctx = zmq.Context() + _socket_cache = {} + _socket_locks = {} + _global_lock = threading.Lock() + + def __init__( + self, + mgr: BaseKVManager, + 
bootstrap_addr: str, + bootstrap_room: Optional[int] = None, + ): + self.bootstrap_room = bootstrap_room + self.bootstrap_addr = bootstrap_addr + self.kv_mgr = mgr + + if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table: + self.prefill_tp_size, self.prefill_dp_size = ( + self._get_prefill_dp_size_from_server() + ) + if self.prefill_tp_size is None or self.prefill_dp_size is None: + logger.error( + f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}" + ) + else: + self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = ( + self.prefill_tp_size + ) + self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = ( + self.prefill_dp_size + ) + else: + self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[ + self.bootstrap_addr + ] + self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[ + self.bootstrap_addr + ] + + # Currently, we don't allow prefill instance and decode instance to + # have different TP sizes per DP rank, except for models using MLA. 
+ local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size + prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size + if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank: + self.target_tp_rank = ( + self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + ) + self.required_dst_info_num = 1 + self.target_tp_ranks = [self.target_tp_rank] + elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank: + assert ( + self.kv_mgr.is_mla_backend + ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models" + self.target_tp_rank = ( + self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank) + self.required_dst_info_num = ( + local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank + ) + self.target_tp_ranks = [self.target_tp_rank] + else: + assert ( + self.kv_mgr.is_mla_backend + ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models" + + # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models; + self.target_tp_ranks = [ + rank + for rank in range( + (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank) + * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank), + (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1) + * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank), + ) + ] + + # For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain + # multiple connections in the connection pool and have to send dummy requests to other prefill ranks, + # or the KVPoll will never be set correctly + self.target_tp_rank = self.target_tp_ranks[0] + self.required_dst_info_num = 1 + + self.target_dp_group = bootstrap_room % self.prefill_dp_size + + # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank + bootstrap_key = ( + 
f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}" + ) + + if bootstrap_key not in self.kv_mgr.connection_pool: + bootstrap_infos = [] + for target_tp_rank in self.target_tp_ranks: + bootstrap_info = self._get_bootstrap_info_from_server( + target_tp_rank, + self.target_dp_group, + ) + if bootstrap_info is not None: + # NOTE: only support MLA for now: select one prefill rank as real rank + bootstrap_info["is_dummy"] = not bool( + target_tp_rank == self.target_tp_rank + or self.target_tp_rank is None + ) + bootstrap_infos.append(bootstrap_info) + else: + logger.error( + f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}" + ) + self.bootstrap_infos = bootstrap_infos + + if len(self.bootstrap_infos) == 0: + logger.error( + f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}" + ) + else: + self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos + # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server + self._register_kv_args() + else: + self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key] + + assert len(self.bootstrap_infos) > 0 + + def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group): + """Fetch the bootstrap info from the bootstrap server.""" + try: + url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}" + response = requests.get(url) + if response.status_code == 200: + bootstrap_info = response.json() + return bootstrap_info + else: + logger.error( + f"Failed to get prefill server info: {response.status_code}, {response.text}" + ) + return None + except Exception as e: + logger.error(f"Error fetching prefill info from bootstrap: {e}") + return None + + def _get_prefill_dp_size_from_server(self) -> int: + """Fetch the prefill parallel info from the bootstrap server.""" + try: + url = 
f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}" + response = requests.get(url) + if response.status_code == 200: + prefill_parallel_info = response.json() + return int(prefill_parallel_info["prefill_tp_size"]), int( + prefill_parallel_info["prefill_dp_size"] + ) + else: + logger.error( + f"Failed to get prefill parallel info: {response.status_code}, {response.text}" + ) + return None + except Exception as e: + logger.error(f"Error fetching prefill parallel info from bootstrap: {e}") + return None + + @classmethod + def _connect(cls, endpoint: str): + with cls._global_lock: + if endpoint not in cls._socket_cache: + sock = cls._ctx.socket(zmq.PUSH) + sock.connect(endpoint) + cls._socket_cache[endpoint] = sock + cls._socket_locks[endpoint] = threading.Lock() + return cls._socket_cache[endpoint], cls._socket_locks[endpoint] + + def _register_kv_args(self): + pass + + def failure_exception(self): + raise Exception("Fake KVReceiver Exception") + + +class CommonKVBootstrapServer(BaseKVBootstrapServer): + def __init__(self, port: int): + self.port = port + self.app = web.Application() + self.store = dict() + self.lock = asyncio.Lock() + self._setup_routes() + self.tp_size = None + self.dp_size = None + self.tp_size_per_dp_rank = None + self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {} + + # Start bootstrap server + self.thread = threading.Thread(target=self._run_server, daemon=True) + self.run() + + def run(self): + self.thread.start() + + def _setup_routes(self): + self.app.router.add_route("*", "/route", self._handle_route) + + async def _handle_route(self, request: web.Request): + method = request.method + if method == "PUT": + return await self._handle_route_put(request) + elif method == "GET": + return await self._handle_route_get(request) + else: + return web.Response( + text="Method not allowed", status=405, content_type="application/json" + ) + + async def _handle_route_put(self, request: web.Request): + data 
= await request.json() + role = data["role"] + tp_size = data["tp_size"] + dp_size = data["dp_size"] + rank_ip = data["rank_ip"] + rank_port = int(data["rank_port"]) + engine_rank = int(data["engine_rank"]) + + if self.tp_size is None: + self.tp_size = tp_size + + if self.dp_size is None: + self.dp_size = dp_size + + tp_size_per_dp_rank = tp_size // dp_size + if self.tp_size_per_dp_rank == None: + self.tp_size_per_dp_rank = tp_size_per_dp_rank + + # Add lock to make sure thread-safe + if role == "Prefill": + dp_group = engine_rank // tp_size_per_dp_rank + tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank + + async with self.lock: + if dp_group not in self.prefill_port_table: + self.prefill_port_table[dp_group] = {} + + self.prefill_port_table[dp_group][tp_rank_in_dp_group] = { + "rank_ip": rank_ip, + "rank_port": rank_port, + } + logger.debug( + f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}" + ) + + return web.Response(text="OK", status=200) + + async def _handle_route_get(self, request: web.Request): + engine_rank = request.query.get("engine_rank") + target_dp_group = request.query.get("target_dp_group") + if not engine_rank or not target_dp_group: + return web.Response(text="Missing inputs for bootstrap server.", status=400) + + # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size + if int(engine_rank) == -1 and int(target_dp_group) == -1: + prefill_parallel_info = { + "prefill_tp_size": self.tp_size, + "prefill_dp_size": self.dp_size, + } + return web.json_response(prefill_parallel_info, status=200) + + # Find corresponding prefill info + async with self.lock: + bootstrap_info = self.prefill_port_table[int(target_dp_group)][ + int(engine_rank) + ] + + if bootstrap_info is not None: + return web.json_response(bootstrap_info, status=200) + else: + return web.Response(text="Bootstrap info not Found", status=404) + + def _run_server(self): + try: + # Event Loop + self._loop = 
asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + + self._runner = web.AppRunner(self.app) + self._loop.run_until_complete(self._runner.setup()) + + site = web.TCPSite(self._runner, port=self.port) + self._loop.run_until_complete(site.start()) + self._loop.run_forever() + except Exception as e: + logger.error(f"Server error: {str(e)}") + finally: + # Cleanup + self._loop.run_until_complete(self._runner.cleanup()) + self._loop.close() + + def close(self): + """Shutdown""" + if self._loop is not None and self._loop.is_running(): + self._loop.call_soon_threadsafe(self._loop.stop) + logger.info("Stopping server loop...") + + if self.thread.is_alive(): + self.thread.join(timeout=2) + logger.info("Server thread stopped") + + def poll(self) -> KVPoll: ... diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 3f5dc54ef..a03df0700 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -29,7 +29,10 @@ from sglang.srt.disaggregation.base.conn import ( KVPoll, ) from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine -from sglang.srt.disaggregation.utils import DisaggregationMode +from sglang.srt.disaggregation.utils import ( + DisaggregationMode, + group_concurrent_contiguous, +) from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( get_free_port, @@ -41,23 +44,6 @@ from sglang.srt.utils import ( logger = logging.getLogger(__name__) -def group_concurrent_contiguous( - src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64] -) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: - """Vectorised NumPy implementation.""" - if src_indices.size == 0: - return [], [] - - brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1 - src_groups = np.split(src_indices, brk) - dst_groups = np.split(dst_indices, brk) - - src_groups = 
[g.tolist() for g in src_groups] - dst_groups = [g.tolist() for g in dst_groups] - - return src_groups, dst_groups - - class KVTransferError(Exception): def __init__(self, bootstrap_room: int, failure_reason: str): super().__init__(failure_reason) diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index feff93216..928dd7530 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -18,40 +18,23 @@ import requests import zmq from aiohttp import web -from sglang.srt.disaggregation.base.conn import ( - BaseKVBootstrapServer, - BaseKVManager, - BaseKVReceiver, - BaseKVSender, - KVArgs, - KVPoll, +from sglang.srt.disaggregation.base.conn import BaseKVSender, KVArgs, KVPoll +from sglang.srt.disaggregation.common.conn import ( + CommonKVBootstrapServer, + CommonKVManager, + CommonKVReceiver, +) +from sglang.srt.disaggregation.utils import ( + DisaggregationMode, + group_concurrent_contiguous, ) -from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote +from sglang.srt.utils import get_local_ip_by_remote logger = logging.getLogger(__name__) NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]] - -def group_concurrent_contiguous( - src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64] -) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: - """Vectorised NumPy implementation.""" - if src_indices.size == 0: - return [], [] - - brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1 - src_groups = np.split(src_indices, brk) - dst_groups = np.split(dst_indices, brk) - - src_groups = [g.tolist() for g in src_groups] - dst_groups = [g.tolist() for g in dst_groups] - - return src_groups, dst_groups - - GUARD = "NixlMsgGuard".encode("ascii") @@ -61,11 +44,13 @@ class TransferInfo: 
endpoint: str dst_port: int agent_metadata: bytes + agent_name: str dst_kv_ptrs: list[int] dst_kv_indices: npt.NDArray[np.int64] dst_aux_ptrs: list[int] dst_aux_index: int dst_gpu_id: int + required_dst_info_num: int def is_dummy(self): return self.endpoint == "" @@ -79,11 +64,13 @@ class TransferInfo: endpoint="", dst_port=0, agent_metadata=b"", + agent_name="", dst_kv_ptrs=[], dst_kv_indices=np.array([], dtype=np.int64), dst_aux_ptrs=[], dst_aux_index=0, dst_gpu_id=0, + required_dst_info_num=0, ) else: return cls( @@ -91,11 +78,13 @@ class TransferInfo: endpoint=msg[1].decode("ascii"), dst_port=int(msg[2].decode("ascii")), agent_metadata=msg[3], - dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])), - dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64), - dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])), - dst_aux_index=int(msg[7].decode("ascii")), - dst_gpu_id=int(msg[8].decode("ascii")), + agent_name=msg[4].decode("ascii"), + dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])), + dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64), + dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])), + dst_aux_index=int(msg[8].decode("ascii")), + dst_gpu_id=int(msg[9].decode("ascii")), + required_dst_info_num=int(msg[10].decode("ascii")), ) @@ -116,7 +105,7 @@ class TransferStatus: return self.num_kvs_expected == len(self.received_kvs) and self.received_aux -class NixlKVManager(BaseKVManager): +class NixlKVManager(CommonKVManager): def __init__( self, args: KVArgs, @@ -124,6 +113,7 @@ class NixlKVManager(BaseKVManager): server_args: ServerArgs, is_mla_backend: Optional[bool] = False, ): + super().__init__(args, disaggregation_mode, server_args, is_mla_backend) try: from nixl._api import nixl_agent except ImportError as e: @@ -133,38 +123,15 @@ class NixlKVManager(BaseKVManager): "to run SGLang with NixlTransferEngine." 
) from e self.agent = nixl_agent(str(uuid.uuid4())) - self.kv_args = args - self.disaggregation_mode = disaggregation_mode - # for p/d multi node infer - self.bootstrap_port = server_args.disaggregation_bootstrap_port - self.dist_init_addr = server_args.dist_init_addr - self.tp_size = server_args.tp_size - - self.tp_rank = args.engine_rank - self.enable_dp_attention = server_args.enable_dp_attention - if self.enable_dp_attention: - assert ( - server_args.dp_size > 1 - ), "If dp_attention is enabled, dp size must be greater than 1 in disaggregation mode." - self.dp_size = server_args.dp_size - self.tp_size_of_dp = server_args.tp_size // server_args.dp_size - self.attn_tp_rank = args.engine_rank % self.tp_size_of_dp - self.dp_rank = args.engine_rank // self.tp_size_of_dp - - self.rank_port = None self.server_socket = zmq.Context().socket(zmq.PULL) self.register_buffer_to_engine() - self.rank_port = get_free_port() if self.disaggregation_mode == DisaggregationMode.PREFILL: + self.request_status = {} self.transfer_infos: Dict[int, TransferInfo] = {} - self.condition = threading.Condition() - self.peer_names: Dict[int, str] = {} + self.peer_names: Dict[str, str] = {} self._start_bootstrap_thread() - self._register_to_bootstrap() elif self.disaggregation_mode == DisaggregationMode.DECODE: - # bootstrap key -> (remote_engine_rank -> possible remote source info) - self.prefill_peer_infos: Dict[str, list[Dict[int, NixlEngineInfo]]] = {} self.transfer_statuses: Dict[int, TransferStatus] = defaultdict( TransferStatus ) @@ -173,6 +140,18 @@ class NixlKVManager(BaseKVManager): f"Unsupported DisaggregationMode: {self.disaggregation_mode}" ) + def check_status(self, bootstrap_room: int): + return self.request_status[bootstrap_room] + + def update_status(self, bootstrap_room: int, status: KVPoll): + if bootstrap_room not in self.request_status: + self.request_status[bootstrap_room] = status + else: + # NOTE: The prefill engine could recv bootstrapping first + 
self.request_status[bootstrap_room] = max( + self.request_status[bootstrap_room], status + ) + def register_buffer_to_engine(self): kv_addrs = [] for kv_data_ptr, kv_data_len in zip( @@ -193,16 +172,10 @@ class NixlKVManager(BaseKVManager): if not self.aux_descs: raise Exception("NIXL memory registration failed for aux tensors") - @cache - def _connect(self, endpoint: str): - socket = zmq.Context().socket(zmq.PUSH) - socket.connect(endpoint) - return socket - - def _add_remote(self, room: int, agent_metadata: bytes): - if room not in self.peer_names: - self.peer_names[room] = self.agent.add_remote_agent(agent_metadata) - return self.peer_names[room] + def _add_remote(self, agent_name: str, agent_metadata: bytes): + if agent_name not in self.peer_names: + self.peer_names[agent_name] = self.agent.add_remote_agent(agent_metadata) + return self.peer_names[agent_name] def send_kvcache( self, @@ -300,40 +273,38 @@ class NixlKVManager(BaseKVManager): assert self.disaggregation_mode == DisaggregationMode.PREFILL assert not is_last or (is_last and aux_index is not None) - # Wait for transfer info to be populated by bootstrap thread. 
- with self.condition: - self.condition.wait_for(lambda: bootstrap_room in self.transfer_infos) - req = self.transfer_infos[bootstrap_room] - assert bootstrap_room == req.room + reqs_to_be_processed = self.transfer_infos[bootstrap_room].values() + handles = [] + for req in reqs_to_be_processed: + assert bootstrap_room == req.room + if req.is_dummy(): + return [] - if req.is_dummy(): - return [] + peer_name = self._add_remote(req.agent_name, req.agent_metadata) + chunked_dst_kv_indice = req.dst_kv_indices[index_slice] + assert len(chunked_dst_kv_indice) == len(kv_indices) - peer_name = self._add_remote(bootstrap_room, req.agent_metadata) - chunked_dst_kv_indice = req.dst_kv_indices[index_slice] - assert len(chunked_dst_kv_indice) == len(kv_indices) - - notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))]) - kv_xfer_handle = self.send_kvcache( - peer_name, - kv_indices, - req.dst_kv_ptrs, - chunked_dst_kv_indice, - req.dst_gpu_id, - notif, - ) - handles = [kv_xfer_handle] - # Only the last chunk we need to send the aux data. - if is_last: - assert aux_index is not None - aux_xfer_handle = self.send_aux( + notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))]) + kv_xfer_handle = self.send_kvcache( peer_name, - aux_index, - req.dst_aux_ptrs, - req.dst_aux_index, - str(req.room) + "_aux", + kv_indices, + req.dst_kv_ptrs, + chunked_dst_kv_indice, + req.dst_gpu_id, + notif, ) - handles.append(aux_xfer_handle) + handles.append(kv_xfer_handle) + # Only the last chunk we need to send the aux data. 
+ if is_last: + assert aux_index is not None + aux_xfer_handle = self.send_aux( + peer_name, + aux_index, + req.dst_aux_ptrs, + req.dst_aux_index, + str(req.room) + "_aux", + ) + handles.append(aux_xfer_handle) return handles def update_transfer_status(self): @@ -348,7 +319,7 @@ class NixlKVManager(BaseKVManager): room = int(components[0]) if components[1] == "kv": chunk_id = int(components[2]) - is_last = bool(components[3]) + is_last = bool(int(components[3])) self.transfer_statuses[room].received_kvs.add(chunk_id) if is_last: self.transfer_statuses[room].num_kvs_expected = chunk_id + 1 @@ -360,34 +331,6 @@ class NixlKVManager(BaseKVManager): return False return self.transfer_statuses[room].is_done() - def _register_to_bootstrap(self): - """Register KVSender to bootstrap server via HTTP POST.""" - if self.dist_init_addr: - ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0]) - else: - ip_address = get_ip() - - bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}" - url = f"http://{bootstrap_server_url}/route" - payload = { - "role": "Prefill", - "rank_ip": get_local_ip_by_remote(), - "rank_port": self.rank_port, - "engine_rank": self.kv_args.engine_rank, - "agent_name": self.agent.name, - } - - try: - response = requests.put(url, json=payload) - if response.status_code == 200: - logger.debug("Prefill successfully registered to bootstrap server.") - else: - logger.error( - f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}" - ) - except Exception as e: - logger.error(f"Prefill Failed to register to bootstrap server: {e}") - def _start_bootstrap_thread(self): self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}") @@ -405,10 +348,19 @@ class NixlKVManager(BaseKVManager): room = waiting_req_bytes[0].decode("ascii") if room == "None": continue + required_dst_info_num = int(waiting_req_bytes[10].decode("ascii")) room = int(room) - with self.condition: - self.transfer_infos[room] = 
TransferInfo.from_zmq(waiting_req_bytes) - self.condition.notify_all() + agent_name = waiting_req_bytes[4].decode("ascii") + if room not in self.transfer_infos: + self.transfer_infos[room] = {} + self.transfer_infos[room][agent_name] = TransferInfo.from_zmq( + waiting_req_bytes + ) + + logger.debug(f"got info {room=} {agent_name=} {required_dst_info_num=}") + if len(self.transfer_infos[room]) == required_dst_info_num: + logger.debug(f"{room=} is bootstrapped") + self.update_status(room, KVPoll.WaitingForInput) threading.Thread(target=bootstrap_thread).start() @@ -423,6 +375,9 @@ class NixlKVSender(BaseKVSender): self.xfer_handles = [] self.has_sent = False self.chunk_id = 0 + self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping) + # inner state + self.curr_idx = 0 def init(self, num_kv_indices: int, aux_index: Optional[int] = None): self.num_kv_indices = num_kv_indices @@ -431,9 +386,11 @@ class NixlKVSender(BaseKVSender): def send( self, kv_indices: npt.NDArray[np.int64], - index_slice: slice, - is_last: bool, ): + index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices)) + self.curr_idx += len(kv_indices) + is_last = self.curr_idx == self.num_kv_indices + new_xfer_handles = self.kv_mgr.add_transfer_request( self.bootstrap_room, kv_indices, @@ -449,7 +406,7 @@ class NixlKVSender(BaseKVSender): def poll(self) -> KVPoll: if not self.has_sent: - return KVPoll.WaitingForInput # type: ignore + return self.kv_mgr.check_status(self.bootstrap_room) states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles] if all([x == "DONE" for x in states]): return KVPoll.Success # type: ignore @@ -461,128 +418,40 @@ class NixlKVSender(BaseKVSender): raise Exception("Fake KVSender Exception") -class NixlKVReceiver(BaseKVReceiver): - +class NixlKVReceiver(CommonKVReceiver): def __init__( self, mgr: NixlKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, ): - self.bootstrap_room = bootstrap_room - self.bootstrap_addr = 
bootstrap_addr - self.kv_mgr = mgr self.started_transfer = False - - # NOTE: key distinguished by bootstrap_addr and engine_rank - bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}" - - if bootstrap_key not in self.kv_mgr.prefill_peer_infos: - self.bootstrap_info = self._get_bootstrap_info_from_server( - self.kv_mgr.kv_args.engine_rank - ) - if self.bootstrap_info is None: - logger.error( - f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}" - ) - else: - self.kv_mgr.prefill_peer_infos[bootstrap_key] = self.bootstrap_info - else: - self.bootstrap_info = self.kv_mgr.prefill_peer_infos[bootstrap_key] - assert self.bootstrap_info is not None - - # return a list of remotes in a dict, [(remote_engine_rank -> NixlEngineInfo), ...] - # In each dict, there are multiple possible remotes named "equal sources". - # We only need to select one to split the traffic. i.e. we totally select len(list) remotes. - def _get_bootstrap_info_from_server( - self, engine_rank - ) -> Optional[List[Dict[int, NixlEngineInfo]]]: - """Fetch the bootstrap info from the bootstrap server.""" - try: - if self.kv_mgr.enable_dp_attention: - url = f"http://{self.bootstrap_addr}/route" - response = requests.get(url) - if response.status_code != 200: - logger.error( - f"Failed to get prefill server info: {response.status_code}, {response.text}" - ) - return None - - bootstrap_info = response.json() - assert isinstance(bootstrap_info, dict) - bootstrap_info = {int(k): v for k, v in bootstrap_info.items()} - - # split out who need to send to this rank. - # currently for dpsk mla model, those ranks share the same latent cache. 
- # pick one as the real source - - prefill_tp_size = len(bootstrap_info.keys()) - - assert ( - prefill_tp_size >= self.kv_mgr.tp_size_of_dp - ), f"Only support Prefill TP size >= Decode TP size of DP, now we have {prefill_tp_size} vs {self.kv_mgr.tp_size_of_dp}" - - num_remote_tp_rank_we_managed = ( - prefill_tp_size // self.kv_mgr.tp_size_of_dp - ) - - # We handle [num * self.attn_tp_rank, num * self.attn_tp_rank + num) - remote_tp_ranks = list(range(0, prefill_tp_size)) - # split it into tp_size_of_dp parts and get our part - remote_tp_ranks_grouped = [ - remote_tp_ranks[i : i + num_remote_tp_rank_we_managed] - for i in range(0, prefill_tp_size, self.kv_mgr.tp_size_of_dp) - ] - managed_ranks = remote_tp_ranks_grouped[self.kv_mgr.attn_tp_rank] - - assert len(managed_ranks) == num_remote_tp_rank_we_managed - - logger.debug( - f"Rank {self.kv_mgr.kv_args.engine_rank} source can be {managed_ranks}" - ) - - return [ - { - rk: bootstrap_info[rk] - for rk in bootstrap_info.keys() - if rk in managed_ranks - } - ] - else: - url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}" - response = requests.get(url) - if response.status_code == 200: - bootstrap_info = response.json() - return [{engine_rank: bootstrap_info}] - else: - logger.error( - f"Failed to get prefill server info: {response.status_code}, {response.text}" - ) - return None - except Exception as e: - logger.error(f"Error fetching prefill info from bootstrap: {e}") - return None - - @cache - def _connect(self, endpoint: str): - socket = zmq.Context().socket(zmq.PUSH) - socket.connect(endpoint) - return socket + super().__init__(mgr, bootstrap_addr, bootstrap_room) def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None): - - assert self.bootstrap_info is not None - assert self.bootstrap_room is not None - - for equal_sources in self.bootstrap_info: - remote_rank = list(equal_sources.keys())[ - self.bootstrap_room % len(equal_sources) - ] - self.prefill_server_url = 
f"{equal_sources[remote_rank]['rank_ip']}:{equal_sources[remote_rank]['rank_port']}" - logger.debug( - f"Fetched bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}, source: {remote_rank}, all: {list(equal_sources.keys())}" + for bootstrap_info in self.bootstrap_infos: + self.prefill_server_url = ( + f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}" ) + logger.debug( + f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}" + ) + is_dummy = bootstrap_info["is_dummy"] + # TODO: just send "" for indices for dummy + if is_dummy: + # TODO: need to set success?? + sock, lock = self._connect("tcp://" + self.prefill_server_url) + with lock: + sock.send_multipart( + [ + GUARD, + str(self.bootstrap_room).encode("ascii"), + ] + ) + continue + + # TODO: send_kv_args earlier packed_kv_data_ptrs = b"".join( struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs ) @@ -593,30 +462,22 @@ class NixlKVReceiver(BaseKVReceiver): logger.debug( f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}" ) - self._connect("tcp://" + self.prefill_server_url).send_multipart( - [ - GUARD, - str(self.bootstrap_room).encode("ascii"), - get_local_ip_by_remote().encode("ascii"), - str(self.kv_mgr.rank_port).encode("ascii"), - self.kv_mgr.agent.get_agent_metadata(), - packed_kv_data_ptrs, - kv_indices.tobytes(), - packed_aux_data_ptrs, - str(aux_index).encode("ascii"), - str(self.kv_mgr.kv_args.gpu_id).encode("ascii"), - ] - ) - - for dummy_rank in equal_sources.keys(): - if dummy_rank == remote_rank: - continue - dummy_info = equal_sources[dummy_rank] - dummy_url = f"{dummy_info['rank_ip']}:{dummy_info['rank_port']}" - self._connect("tcp://" + dummy_url).send_multipart( + sock, lock = self._connect("tcp://" + self.prefill_server_url) + with lock: + sock.send_multipart( [ GUARD, str(self.bootstrap_room).encode("ascii"), + get_local_ip_by_remote().encode("ascii"), + 
str(self.kv_mgr.rank_port).encode("ascii"), + self.kv_mgr.agent.get_agent_metadata(), + self.kv_mgr.agent.name.encode("ascii"), + packed_kv_data_ptrs, + kv_indices.tobytes(), + packed_aux_data_ptrs, + str(aux_index).encode("ascii"), + str(self.kv_mgr.kv_args.gpu_id).encode("ascii"), + str(self.required_dst_info_num).encode("ascii"), ] ) @@ -632,152 +493,12 @@ class NixlKVReceiver(BaseKVReceiver): return KVPoll.Success # type: ignore return KVPoll.WaitingForInput # type: ignore + def _register_kv_args(self): + pass + def failure_exception(self): raise Exception("Fake KVReceiver Exception") -class NixlKVBootstrapServer(BaseKVBootstrapServer): - def __init__(self, port: int): - logger.debug(f"NixlKVBootstrapServer started on port {port}") - self.port = port - self.app = web.Application() - self.store = dict() - self.lock = asyncio.Lock() - self._setup_routes() - self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {} - - # Start bootstrap server - self.thread = threading.Thread(target=self._run_server, daemon=True) - self.run() - - def run(self): - self.thread.start() - - def _setup_routes(self): - self.app.router.add_route("*", "/metadata", self._handle_metadata) - self.app.router.add_route("*", "/route", self._handle_route) - - async def _handle_metadata(self, request: web.Request): - key = request.query.get("key", "") - - if request.method == "GET": - return await self._handle_metadata_get(key) - elif request.method == "PUT": - return await self._handle_metadata_put(key, request) - elif request.method == "DELETE": - return await self._handle_metadata_delete(key) - return web.Response( - text="Method not allowed", status=405, content_type="application/json" - ) - - async def _handle_metadata_get(self, key): - async with self.lock: - value = self.store.get(key) - if value is None: - return web.Response( - text="metadata not found", status=404, content_type="application/json" - ) - return web.Response(body=value, status=200, 
content_type="application/json") - - async def _handle_metadata_put(self, key, request): - data = await request.read() - async with self.lock: - self.store[key] = data - return web.Response( - text="metadata updated", status=200, content_type="application/json" - ) - - async def _handle_metadata_delete(self, key): - async with self.lock: - if key not in self.store: - return web.Response( - text="metadata not found", - status=404, - content_type="application/json", - ) - del self.store[key] - return web.Response( - text="metadata deleted", status=200, content_type="application/json" - ) - - async def _handle_route(self, request: web.Request): - method = request.method - if method == "PUT": - return await self._handle_route_put(request) - elif method == "GET": - return await self._handle_route_get(request) - else: - return web.Response( - text="Method not allowed", status=405, content_type="application/json" - ) - - async def _handle_route_put(self, request: web.Request): - data = await request.json() - role = data["role"] - rank_ip = data["rank_ip"] - rank_port = int(data["rank_port"]) - engine_rank = int(data["engine_rank"]) - agent_name = data["agent_name"] - - if role == "Prefill": - async with self.lock: - self.prefill_port_table[engine_rank] = { - "rank_ip": rank_ip, - "rank_port": rank_port, - "agent_name": agent_name, - } - logger.info( - f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port} and name: {agent_name}" - ) - - return web.Response(text="OK", status=200) - - async def _handle_route_get(self, request: web.Request): - engine_rank = request.query.get("engine_rank") - if not engine_rank: - logger.debug( - f"No engine_rank specified, return all {len(self.prefill_port_table)} engine infos as a dict" - ) - # Return a dict of all engine_rank - async with self.lock: - bootstrap_info = self.prefill_port_table - return web.json_response(bootstrap_info, status=200) - - # Find corresponding prefill info - async with 
self.lock: - bootstrap_info = self.prefill_port_table.get(int(engine_rank)) - if bootstrap_info is not None: - return web.json_response(bootstrap_info, status=200) - else: - return web.Response(text="Not Found", status=404) - - def _run_server(self): - try: - # Event Loop - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) - - self._runner = web.AppRunner(self.app) - self._loop.run_until_complete(self._runner.setup()) - - site = web.TCPSite(self._runner, port=self.port) - self._loop.run_until_complete(site.start()) - self._loop.run_forever() - except Exception as e: - logger.error(f"Server error: {str(e)}") - finally: - # Cleanup - self._loop.run_until_complete(self._runner.cleanup()) - self._loop.close() - - def close(self): - """Shutdown""" - if self._loop is not None and self._loop.is_running(): - self._loop.call_soon_threadsafe(self._loop.stop) - logger.info("Stopping server loop...") - - if self.thread.is_alive(): - self.thread.join(timeout=2) - logger.info("Server thread stopped") - - def poll(self) -> KVPoll: ... 
+class NixlKVBootstrapServer(CommonKVBootstrapServer): + pass diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index 74923cd89..8841d5f1a 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -13,7 +13,7 @@ import requests import torch import torch.distributed as dist -from sglang.srt.utils import get_ip +from sglang.srt.utils import get_ip, get_local_ip_by_remote if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req @@ -279,3 +279,20 @@ class MetadataBuffers: ] = torch.tensor( req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu" ) + + +def group_concurrent_contiguous( + src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64] +) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: + """Vectorised NumPy implementation.""" + if src_indices.size == 0: + return [], [] + + brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1 + src_groups = np.split(src_indices, brk) + dst_groups = np.split(dst_indices, brk) + + src_groups = [g.tolist() for g in src_groups] + dst_groups = [g.tolist() for g in dst_groups] + + return src_groups, dst_groups