From 4dce1cc60843eb1fa42c85f9c6818726c497e84a Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Mon, 21 Apr 2025 10:36:12 -0700
Subject: [PATCH] [PD] Add NIXL transfer backend (#5477)

---
 .../srt/disaggregation/nixl/__init__.py       |   1 +
 python/sglang/srt/disaggregation/nixl/conn.py | 622 ++++++++++++++++++
 python/sglang/srt/disaggregation/utils.py     |  16 +
 python/sglang/srt/server_args.py              |   1 +
 4 files changed, 640 insertions(+)
 create mode 100644 python/sglang/srt/disaggregation/nixl/__init__.py
 create mode 100644 python/sglang/srt/disaggregation/nixl/conn.py

diff --git a/python/sglang/srt/disaggregation/nixl/__init__.py b/python/sglang/srt/disaggregation/nixl/__init__.py
new file mode 100644
index 000000000..d9456e0fd
--- /dev/null
+++ b/python/sglang/srt/disaggregation/nixl/__init__.py
@@ -0,0 +1 @@
+from .conn import NixlKVBootstrapServer, NixlKVManager, NixlKVReceiver, NixlKVSender
diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py
new file mode 100644
index 000000000..080babb44
--- /dev/null
+++ b/python/sglang/srt/disaggregation/nixl/conn.py
@@ -0,0 +1,622 @@
+from __future__ import annotations
+
+import asyncio
+import dataclasses
+import logging
+import queue
+import socket
+import struct
+import threading
+import uuid
+from collections import defaultdict
+from functools import cache
+from typing import Dict, List, Optional, Set, Tuple, Union
+
+import numpy as np
+import numpy.typing as npt
+import requests
+import zmq
+from aiohttp import web
+
+from sglang.srt.disaggregation.base.conn import (
+    BaseKVBootstrapServer,
+    BaseKVManager,
+    BaseKVReceiver,
+    BaseKVSender,
+    KVArgs,
+    KVPoll,
+)
+from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
+
+logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class TransferInfo:
+    room: int
+    endpoint: str
+    dst_port: int
+    
agent_metadata: bytes + dst_kv_ptrs: list[int] + dst_kv_indices: npt.NDArray[np.int64] + dst_aux_ptrs: list[int] + dst_aux_index: int + dst_gpu_id: int + + @classmethod + def from_zmq(cls, msg: List[bytes]): + return cls( + room=int(msg[0].decode("ascii")), + endpoint=msg[1].decode("ascii"), + dst_port=int(msg[2].decode("ascii")), + agent_metadata=msg[3], + dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])), + dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64), + dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])), + dst_aux_index=int(msg[7].decode("ascii")), + dst_gpu_id=int(msg[8].decode("ascii")), + ) + + +@dataclasses.dataclass +class TransferStatus: + """Used by KV Receiver to know when a transfer is done.""" + + # KV chunk IDs that have been received. + received_kvs: Set[int] = dataclasses.field(default_factory=set) + # Number of kv chunks to expect, will know this after last chunk is received. + num_kvs_expected: Optional[int] = None + # Whether aux data has been received. + received_aux: bool = False + + def is_done(self): + if self.num_kvs_expected is None: + return False + return self.num_kvs_expected == len(self.received_kvs) and self.received_aux + + +class NixlKVManager(BaseKVManager): + def __init__( + self, + args: KVArgs, + disaggregation_mode: DisaggregationMode, + server_args: ServerArgs, + ): + try: + from nixl._api import nixl_agent + except ImportError as e: + raise ImportError( + "Please install NIXL by following the instructions at " + "https://github.com/ai-dynamo/nixl/blob/main/README.md " + "to run SGLang with NixlTransferEngine." 
+ ) from e + self.agent = nixl_agent(str(uuid.uuid4())) + self.kv_args = args + self.disaggregation_mode = disaggregation_mode + # for p/d multi node infer + self.bootstrap_port = server_args.disaggregation_bootstrap_port + self.dist_init_addr = server_args.dist_init_addr + self.rank_port = None + self.server_socket = zmq.Context().socket(zmq.PULL) + self.register_buffer_to_engine() + + self.rank_port = get_free_port() + if self.disaggregation_mode == DisaggregationMode.PREFILL: + self.transfer_infos: Dict[int, TransferInfo] = {} + self.condition = threading.Condition() + self.peer_names: Dict[int, str] = {} + self._start_bootstrap_thread() + self._register_to_bootstrap() + elif self.disaggregation_mode == DisaggregationMode.DECODE: + self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {} + self.transfer_statuses: Dict[int, TransferStatus] = defaultdict( + TransferStatus + ) + else: + raise ValueError( + f"Unsupported DisaggregationMode: {self.disaggregation_mode}" + ) + + def register_buffer_to_engine(self): + kv_addrs = [] + for kv_data_ptr, kv_data_len in zip( + self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens + ): + kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, "")) + self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True) + if not self.kv_descs: + raise Exception("NIXL memory registration failed for kv tensors") + aux_addrs = [] + for aux_data_ptr, aux_data_len in zip( + self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens + ): + aux_addrs.append((aux_data_ptr, aux_data_len, 0, "")) + self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True) + if not self.aux_descs: + raise Exception("NIXL memory registration failed for aux tensors") + + @cache + def _connect(self, endpoint: str): + socket = zmq.Context().socket(zmq.PUSH) + socket.connect(endpoint) + return socket + + def _add_remote(self, room: int, agent_metadata: bytes): + if room not in self.peer_names: + self.peer_names[room] = 
self.agent.add_remote_agent(agent_metadata) + return self.peer_names[room] + + def send_kvcache( + self, + peer_name: str, + prefill_kv_indices: npt.NDArray[np.int64], + dst_kv_ptrs: list[int], + dst_kv_indices: npt.NDArray[np.int64], + dst_gpu_id: int, + notif: str, + ): + # Make descs + num_layers = len(self.kv_args.kv_data_ptrs) + src_addrs = [] + dst_addrs = [] + for layer_id in range(num_layers): + src_ptr = self.kv_args.kv_data_ptrs[layer_id] + dst_ptr = dst_kv_ptrs[layer_id] + item_len = self.kv_args.kv_item_lens[layer_id] + + for prefill_index, decode_index in zip(prefill_kv_indices, dst_kv_indices): + src_addr = src_ptr + int(prefill_index) * item_len + dst_addr = dst_ptr + int(decode_index) * item_len + length = item_len + src_addrs.append((src_addr, length, self.kv_args.gpu_id)) + dst_addrs.append((dst_addr, length, dst_gpu_id)) + src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True) + dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True) + # Transfer data + xfer_handle = self.agent.initialize_xfer( + "WRITE", + src_descs, + dst_descs, + peer_name, + notif.encode("ascii"), + ) + if not xfer_handle: + raise Exception("KVSender failed to create transfer") + state = self.agent.transfer(xfer_handle) + if state == "ERR": + raise Exception("KVSender failed to post transfer") + return xfer_handle + + def send_aux( + self, + peer_name: str, + prefill_aux_index: int, + dst_aux_ptrs: list[int], + dst_aux_index: int, + notif: str, + ): + # Make descs + aux_item_len = self.kv_args.aux_item_lens[0] + prefill_aux_addr = ( + self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len + ) + decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len + src_addrs = [(prefill_aux_addr, aux_item_len, 0)] + dst_addrs = [(decode_aux_addr, aux_item_len, 0)] + src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=True) + dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=True) + # Transfer data + 
xfer_handle = self.agent.initialize_xfer( + "WRITE", + src_descs, + dst_descs, + peer_name, + notif.encode("ascii"), + ) + if not xfer_handle: + raise Exception("KVSender failed to create transfer") + state = self.agent.transfer(xfer_handle) + if state == "ERR": + raise Exception("KVSender failed to post transfer") + return xfer_handle + + def add_transfer_request( + self, + bootstrap_room: int, + kv_indices: npt.NDArray[np.int64], + index_slice: slice, + is_last: bool, + chunk_id: int, + aux_index: Optional[int] = None, + ): + assert self.disaggregation_mode == DisaggregationMode.PREFILL + assert not is_last or (is_last and aux_index is not None) + + # Wait for transfer info to be populated by bootstrap thread. + with self.condition: + self.condition.wait_for(lambda: bootstrap_room in self.transfer_infos) + req = self.transfer_infos[bootstrap_room] + assert bootstrap_room == req.room + + peer_name = self._add_remote(bootstrap_room, req.agent_metadata) + chunked_dst_kv_indice = req.dst_kv_indices[index_slice] + assert len(chunked_dst_kv_indice) == len(kv_indices) + + notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))]) + kv_xfer_handle = self.send_kvcache( + peer_name, + kv_indices, + req.dst_kv_ptrs, + chunked_dst_kv_indice, + req.dst_gpu_id, + notif, + ) + handles = [kv_xfer_handle] + # Only the last chunk we need to send the aux data. + if is_last: + aux_xfer_handle = self.send_aux( + peer_name, + aux_index, + req.dst_aux_ptrs, + req.dst_aux_index, + str(req.room) + "_aux", + ) + handles.append(aux_xfer_handle) + return handles + + def update_transfer_status(self): + # Process notifications from received transfers. + notif_map = self.agent.get_new_notifs() + for peer_name, messages in notif_map.items(): + # We could also check that self.bootstrap_info['agent_name'] matches + # the message sender. But the bootstrap room alone should be + # sufficient to map the status. 
+            for msg in messages:
+                components = msg.decode("ascii").split("_")
+                room = int(components[0])
+                if components[1] == "kv":
+                    chunk_id = int(components[2])
+                    is_last = bool(int(components[3]))
+                    self.transfer_statuses[room].received_kvs.add(chunk_id)
+                    if is_last:
+                        self.transfer_statuses[room].num_kvs_expected = chunk_id + 1
+                elif components[1] == "aux":
+                    self.transfer_statuses[room].received_aux = True
+
+    def check_transfer_done(self, room: int):
+        if room not in self.transfer_statuses:
+            return False
+        return self.transfer_statuses[room].is_done()
+
+    def _register_to_bootstrap(self):
+        """Register KVSender to bootstrap server via HTTP POST."""
+        if self.dist_init_addr:
+            ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
+        else:
+            ip_address = get_ip()
+
+        bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
+        url = f"http://{bootstrap_server_url}/route"
+        payload = {
+            "role": "Prefill",
+            "rank_ip": get_local_ip_by_remote(),
+            "rank_port": self.rank_port,
+            "engine_rank": self.kv_args.engine_rank,
+            "agent_name": self.agent.name,
+        }
+
+        try:
+            response = requests.put(url, json=payload)
+            if response.status_code == 200:
+                logger.debug("Prefill successfully registered to bootstrap server.")
+            else:
+                logger.error(
+                    f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
+                )
+        except Exception as e:
+            logger.error(f"Prefill Failed to register to bootstrap server: {e}")
+
+    def _start_bootstrap_thread(self):
+        self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
+
+        def bootstrap_thread():
+            """This thread recvs transfer info from the decode engine"""
+            while True:
+                waiting_req_bytes = self.server_socket.recv_multipart()
+                room = waiting_req_bytes[0].decode("ascii")
+                if room == "None":
+                    continue
+                room = int(room)
+                with self.condition:
+                    self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
+                    self.condition.notify_all()
+
+        
threading.Thread(target=bootstrap_thread, daemon=True).start()
+
+
+class NixlKVSender(BaseKVSender):
+
+    def __init__(self, mgr: NixlKVManager, bootstrap_addr: str, bootstrap_room: int):
+        self.kv_mgr = mgr
+        self.bootstrap_room = bootstrap_room
+        self.aux_index = None
+        self.bootstrap_server_url = bootstrap_addr
+        self.xfer_handles = []
+        self.has_sent = False
+        self.chunk_id = 0
+
+    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
+        self.num_kv_indices = num_kv_indices
+        self.aux_index = aux_index
+
+    def send(
+        self,
+        kv_indices: npt.NDArray[np.int64],
+        index_slice: slice,
+        is_last: bool,
+    ):
+        new_xfer_handles = self.kv_mgr.add_transfer_request(
+            self.bootstrap_room,
+            kv_indices,
+            index_slice,
+            is_last,
+            self.chunk_id,
+            self.aux_index,
+        )
+        self.xfer_handles.extend(new_xfer_handles)
+        self.chunk_id += 1
+        if is_last:
+            self.has_sent = True
+
+    def poll(self) -> KVPoll:
+        if not self.has_sent:
+            return KVPoll.WaitingForInput
+
+        states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles]
+        if all([x == "DONE" for x in states]):
+            return KVPoll.Success
+        if any([x == "ERR" for x in states]):
+            raise Exception("KVSender transfer encountered an error.")
+        return KVPoll.WaitingForInput
+
+    def failure_exception(self):
+        raise Exception("Fake KVSender Exception")
+
+
+class NixlKVReceiver(BaseKVReceiver):
+
+    def __init__(
+        self,
+        mgr: NixlKVManager,
+        bootstrap_addr: str,
+        bootstrap_room: Optional[int] = None,
+    ):
+        self.bootstrap_room = bootstrap_room
+        self.bootstrap_addr = bootstrap_addr
+        self.kv_mgr = mgr
+        self.started_transfer = False
+
+        # NOTE: key distinguished by bootstrap_addr and engine_rank
+        bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
+
+        if bootstrap_key not in self.kv_mgr.connection_pool:
+            self.bootstrap_info = self._get_bootstrap_info_from_server(
+                self.kv_mgr.kv_args.engine_rank
+            )
+            if self.bootstrap_info is None:
+                logger.error(
+                    f"Could not fetch bootstrap info for 
engine rank: {self.kv_mgr.kv_args.engine_rank}" + ) + else: + self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info + else: + self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key] + + assert self.bootstrap_info is not None + + def _get_bootstrap_info_from_server(self, engine_rank): + """Fetch the bootstrap info from the bootstrap server.""" + try: + url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}" + response = requests.get(url) + if response.status_code == 200: + bootstrap_info = response.json() + return bootstrap_info + else: + logger.error( + f"Failed to get prefill server info: {response.status_code}, {response.text}" + ) + return None + except Exception as e: + logger.error(f"Error fetching prefill info from bootstrap: {e}") + return None + + @cache + def _connect(self, endpoint: str): + socket = zmq.Context().socket(zmq.PUSH) + socket.connect(endpoint) + return socket + + def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None): + self.prefill_server_url = ( + f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}" + ) + logger.debug( + f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}" + ) + + packed_kv_data_ptrs = b"".join( + struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs + ) + packed_aux_data_ptrs = b"".join( + struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs + ) + self._connect("tcp://" + self.prefill_server_url).send_multipart( + [ + str(self.bootstrap_room).encode("ascii"), + get_local_ip_by_remote().encode("ascii"), + str(self.kv_mgr.rank_port).encode("ascii"), + self.kv_mgr.agent.get_agent_metadata(), + packed_kv_data_ptrs, + kv_indices.tobytes(), + packed_aux_data_ptrs, + str(aux_index).encode("ascii"), + str(self.kv_mgr.kv_args.gpu_id).encode("ascii"), + ] + ) + self.started_transfer = True + + def poll(self) -> KVPoll: + if not self.started_transfer: + return KVPoll.WaitingForInput 
+ + self.kv_mgr.update_transfer_status() + + if self.kv_mgr.check_transfer_done(self.bootstrap_room): + return KVPoll.Success + return KVPoll.WaitingForInput + + def failure_exception(self): + raise Exception("Fake KVReceiver Exception") + + +class NixlKVBootstrapServer(BaseKVBootstrapServer): + def __init__(self, port: int): + self.port = port + self.app = web.Application() + self.store = dict() + self.lock = asyncio.Lock() + self._setup_routes() + self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {} + + # Start bootstrap server + self.thread = threading.Thread(target=self._run_server, daemon=True) + self.run() + + def run(self): + self.thread.start() + + def _setup_routes(self): + self.app.router.add_route("*", "/metadata", self._handle_metadata) + self.app.router.add_route("*", "/route", self._handle_route) + + async def _handle_metadata(self, request: web.Request): + key = request.query.get("key", "") + + if request.method == "GET": + return await self._handle_metadata_get(key) + elif request.method == "PUT": + return await self._handle_metadata_put(key, request) + elif request.method == "DELETE": + return await self._handle_metadata_delete(key) + return web.Response( + text="Method not allowed", status=405, content_type="application/json" + ) + + async def _handle_metadata_get(self, key): + async with self.lock: + value = self.store.get(key) + if value is None: + return web.Response( + text="metadata not found", status=404, content_type="application/json" + ) + return web.Response(body=value, status=200, content_type="application/json") + + async def _handle_metadata_put(self, key, request): + data = await request.read() + async with self.lock: + self.store[key] = data + return web.Response( + text="metadata updated", status=200, content_type="application/json" + ) + + async def _handle_metadata_delete(self, key): + async with self.lock: + if key not in self.store: + return web.Response( + text="metadata not found", + status=404, + 
content_type="application/json", + ) + del self.store[key] + return web.Response( + text="metadata deleted", status=200, content_type="application/json" + ) + + async def _handle_route(self, request: web.Request): + method = request.method + if method == "PUT": + return await self._handle_route_put(request) + elif method == "GET": + return await self._handle_route_get(request) + else: + return web.Response( + text="Method not allowed", status=405, content_type="application/json" + ) + + async def _handle_route_put(self, request: web.Request): + data = await request.json() + role = data["role"] + rank_ip = data["rank_ip"] + rank_port = int(data["rank_port"]) + engine_rank = int(data["engine_rank"]) + agent_name = data["agent_name"] + + # Add lock to make sure thread-safe + if role == "Prefill": + self.prefill_port_table[engine_rank] = { + "rank_ip": rank_ip, + "rank_port": rank_port, + "agent_name": agent_name, + } + logger.info( + f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port} and name: {agent_name}" + ) + + return web.Response(text="OK", status=200) + + async def _handle_route_get(self, request: web.Request): + engine_rank = request.query.get("engine_rank") + if not engine_rank: + return web.Response(text="Missing rank", status=400) + + # Find corresponding prefill info + async with self.lock: + bootstrap_info = self.prefill_port_table.get(int(engine_rank)) + if bootstrap_info is not None: + return web.json_response(bootstrap_info, status=200) + else: + return web.Response(text="Not Found", status=404) + + def _run_server(self): + try: + # Event Loop + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + + self._runner = web.AppRunner(self.app) + self._loop.run_until_complete(self._runner.setup()) + + site = web.TCPSite(self._runner, port=self.port) + self._loop.run_until_complete(site.start()) + self._loop.run_forever() + except Exception as e: + logger.error(f"Server error: {str(e)}") + 
finally: + # Cleanup + self._loop.run_until_complete(self._runner.cleanup()) + self._loop.close() + + def close(self): + """Shutdown""" + if self._loop is not None and self._loop.is_running(): + self._loop.call_soon_threadsafe(self._loop.stop) + logger.info("Stopping server loop...") + + if self.thread.is_alive(): + self.thread.join(timeout=2) + logger.info("Server thread stopped") + + def poll(self) -> KVPoll: ... diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index 4836c9a77..cb7654088 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -47,6 +47,7 @@ class ReqToMetadataIdxAllocator: class TransferBackend(Enum): MOONCAKE = "mooncake" + NIXL = "nixl" FAKE = "fake" @@ -73,6 +74,21 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType): KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer, } return class_mapping.get(class_type) + if transfer_backend == TransferBackend.NIXL: + from sglang.srt.disaggregation.nixl import ( + NixlKVBootstrapServer, + NixlKVManager, + NixlKVReceiver, + NixlKVSender, + ) + + class_mapping = { + KVClassType.MANAGER: NixlKVManager, + KVClassType.SENDER: NixlKVSender, + KVClassType.RECEIVER: NixlKVReceiver, + KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer, + } + return class_mapping.get(class_type) raise ValueError(f"Unsupported transfer backend: {transfer_backend}") diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ee9250351..4fea1c824 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1211,6 +1211,7 @@ class ServerArgs: "--disaggregation-transfer-backend", type=str, default=ServerArgs.disaggregation_transfer_backend, + choices=["mooncake", "nixl"], help="The backend for disaggregation transfer. Default is mooncake.", ) parser.add_argument(