diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index 10b6093b9..096a1db59 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -22,12 +22,18 @@ from sglang.srt.disaggregation.base.conn import ( KVPoll, ) from sglang.srt.disaggregation.utils import DisaggregationMode +from sglang.srt.distributed import get_pp_group +from sglang.srt.layers.dp_attention import ( + get_attention_dp_rank, + get_attention_dp_size, + get_attention_tp_rank, + get_attention_tp_size, +) from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( format_tcp_address, get_free_port, - get_ip, - get_local_ip_by_remote, + get_local_ip_auto, is_valid_ipv6_address, maybe_wrap_ipv6_address, ) @@ -50,63 +56,44 @@ class CommonKVManager(BaseKVManager): self.bootstrap_host = server_args.host self.bootstrap_port = server_args.disaggregation_bootstrap_port self.dist_init_addr = server_args.dist_init_addr - self.tp_size = server_args.tp_size - self.dp_size = server_args.dp_size - self.enable_dp_attention = server_args.enable_dp_attention - if not server_args.enable_dp_attention and server_args.dp_size != 1: - raise ValueError( - "If dp_attention is not enabled, dp size must be 1 in disaggregation mode." 
- ) - + self.attn_tp_size = get_attention_tp_size() + self.attn_tp_rank = get_attention_tp_rank() + self.attn_dp_size = get_attention_dp_size() + self.attn_dp_rank = get_attention_dp_rank() + self.system_dp_size = ( + 1 if server_args.enable_dp_attention else server_args.dp_size + ) + self.system_dp_rank = ( + self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0 + ) + self.pp_size = server_args.pp_size + self.pp_rank = self.kv_args.pp_rank self.rank_port = get_free_port() + self.local_ip = get_local_ip_auto() + self.server_socket = zmq.Context().socket(zmq.PULL) + if is_valid_ipv6_address(self.local_ip): + self.server_socket.setsockopt(zmq.IPV6, 1) + self.request_status: Dict[int, KVPoll] = {} + if self.disaggregation_mode == DisaggregationMode.PREFILL: self._register_to_bootstrap() + self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {} + self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {} + self.pp_group = get_pp_group() elif self.disaggregation_mode == DisaggregationMode.DECODE: self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {} - self.prefill_tp_size_table: Dict[str, int] = {} + self.connection_lock = threading.Lock() + self.required_prefill_response_num_table: Dict[int, int] = {} + self.prefill_attn_tp_size_table: Dict[str, int] = {} self.prefill_dp_size_table: Dict[str, int] = {} + self.prefill_pp_size_table: Dict[str, int] = {} else: raise ValueError( f"Unsupported DisaggregationMode: {self.disaggregation_mode}" ) - def _register_to_bootstrap(self): - """Register KVSender to bootstrap server via HTTP POST.""" - if self.dist_init_addr: - # multi node: bootstrap server's host is dist_init_addr - if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6] - if self.dist_init_addr.endswith("]"): - host = self.dist_init_addr - else: - host, _ = self.dist_init_addr.rsplit(":", 1) - else: - host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0]) - else: - # single node: bootstrap server's host is same as 
http server's host - host = self.bootstrap_host - host = maybe_wrap_ipv6_address(host) - - bootstrap_server_url = f"{host}:{self.bootstrap_port}" - url = f"http://{bootstrap_server_url}/route" - payload = { - "role": "Prefill", - "tp_size": self.tp_size, - "dp_size": self.dp_size, - "rank_ip": get_local_ip_by_remote(), - "rank_port": self.rank_port, - "engine_rank": self.kv_args.engine_rank, - } - - try: - response = requests.put(url, json=payload) - if response.status_code == 200: - logger.debug("Prefill successfully registered to bootstrap server.") - else: - logger.error( - f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}" - ) - except Exception as e: - logger.error(f"Prefill Failed to register to bootstrap server: {e}") + def _bind_server_socket(self): + self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port)) @cache def _connect(self, endpoint: str, is_ipv6: bool = False): @@ -116,6 +103,94 @@ class CommonKVManager(BaseKVManager): socket.connect(endpoint) return socket + def _register_to_bootstrap(self): + """Register KVSender to bootstrap server via HTTP PUT.""" + if self.dist_init_addr: + # Multi-node case: bootstrap server's host is dist_init_addr + if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6] + if self.dist_init_addr.endswith("]"): + host = self.dist_init_addr + else: + host, _ = self.dist_init_addr.rsplit(":", 1) + else: + host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0]) + else: + # Single-node case: bootstrap server's host is the same as http server's host + host = self.bootstrap_host + host = maybe_wrap_ipv6_address(host) + + bootstrap_server_url = f"{host}:{self.bootstrap_port}" + url = f"http://{bootstrap_server_url}/route" + payload = { + "role": "Prefill", + "attn_tp_size": self.attn_tp_size, + "attn_tp_rank": self.attn_tp_rank, + "attn_dp_size": self.attn_dp_size, + "attn_dp_rank": self.attn_dp_rank, + "pp_size": self.pp_size, + "pp_rank": self.pp_rank, 
+ "system_dp_size": self.system_dp_size, + "system_dp_rank": self.system_dp_rank, + "rank_ip": self.local_ip, + "rank_port": self.rank_port, + } + + try: + response = requests.put(url, json=payload, timeout=5) + if response.status_code == 200: + logger.debug("Prefill successfully registered to bootstrap server.") + else: + logger.error( + f"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}" + ) + except Exception as e: + logger.error( + f"Prefill instance failed to register to bootstrap server: {e}" + ) + + @cache + def _connect(self, endpoint: str, is_ipv6: bool = False): + socket = zmq.Context().socket(zmq.PUSH) + if is_ipv6: + socket.setsockopt(zmq.IPV6, 1) + socket.connect(endpoint) + return socket + + +class CommonKVSender(BaseKVSender): + + def __init__( + self, + mgr: BaseKVManager, + bootstrap_addr: str, + bootstrap_room: int, + dest_tp_ranks: List[int], + pp_rank: int, + ): + self.kv_mgr = mgr + self.bootstrap_room = bootstrap_room + self.aux_index = None + self.bootstrap_server_url = bootstrap_addr + # inner state + self.curr_idx = 0 + self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping) + + def init(self, num_kv_indices: int, aux_index: Optional[int] = None): + self.num_kv_indices = num_kv_indices + self.aux_index = aux_index + + def send( + self, + kv_indices: npt.NDArray[np.int32], + ): + pass + + def poll(self) -> KVPoll: + pass + + def failure_exception(self): + raise Exception("Fake KVReceiver Exception") + class CommonKVReceiver(BaseKVReceiver): _ctx = zmq.Context() @@ -133,61 +208,88 @@ class CommonKVReceiver(BaseKVReceiver): self.bootstrap_room = bootstrap_room self.bootstrap_addr = bootstrap_addr self.kv_mgr = mgr + self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping) if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table: - self.prefill_tp_size, self.prefill_dp_size = ( - self._get_prefill_dp_size_from_server() - ) - if self.prefill_tp_size is None or 
self.prefill_dp_size is None: - logger.error( - f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}" + ( + self.prefill_attn_tp_size, + self.prefill_dp_size, + self.prefill_pp_size, + ) = self._get_prefill_parallel_info_from_server() + if ( + self.prefill_attn_tp_size is None + or self.prefill_dp_size is None + or self.prefill_pp_size is None + ): + self.kv_mgr.record_failure( + self.bootstrap_room, + f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}", ) + self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) + return else: - self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = ( - self.prefill_tp_size + logger.debug( + f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}" + ) + self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = ( + self.prefill_attn_tp_size ) self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = ( self.prefill_dp_size ) + self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = ( + self.prefill_pp_size + ) else: - self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[ + self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[ self.bootstrap_addr ] self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[ self.bootstrap_addr ] + self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[ + self.bootstrap_addr + ] # Currently, we don't allow prefill instance and decode instance to # have different TP sizes per DP rank, except for models using MLA. 
- local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size - prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size - if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank: + if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size: self.target_tp_rank = ( - self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size ) self.required_dst_info_num = 1 + self.required_prefill_response_num = 1 * ( + self.prefill_pp_size // self.kv_mgr.pp_size + ) self.target_tp_ranks = [self.target_tp_rank] - elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank: + elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size: + if not self.kv_mgr.is_mla_backend: + logger.warning_once( + "Performance is NOT guaranteed when using different TP sizes for non-MLA models. " + ) self.target_tp_rank = ( - self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank - ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank) + self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + ) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size) self.required_dst_info_num = ( - local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank + self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size + ) + self.required_prefill_response_num = 1 * ( + self.prefill_pp_size // self.kv_mgr.pp_size ) self.target_tp_ranks = [self.target_tp_rank] else: - assert ( - self.kv_mgr.is_mla_backend - ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models" - + if not self.kv_mgr.is_mla_backend: + logger.warning_once( + "Performance is NOT guaranteed when using different TP sizes for non-MLA models. 
" + ) # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models; self.target_tp_ranks = [ rank for rank in range( - (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank) - * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank), - (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1) - * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank), + (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size) + * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size), + (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1) + * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size), ) ] @@ -196,6 +298,14 @@ class CommonKVReceiver(BaseKVReceiver): # or the KVPoll will never be set correctly self.target_tp_rank = self.target_tp_ranks[0] self.required_dst_info_num = 1 + if self.kv_mgr.is_mla_backend: + self.required_prefill_response_num = ( + self.prefill_pp_size // self.kv_mgr.pp_size + ) + else: + self.required_prefill_response_num = ( + self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size + ) * (self.prefill_pp_size // self.kv_mgr.pp_size) if prefill_dp_rank is not None: logger.debug(f"Targeting DP rank: {prefill_dp_rank}") @@ -206,6 +316,9 @@ class CommonKVReceiver(BaseKVReceiver): # FIXME: alias here: target_dp_group -> prefill_dp_rank self.target_dp_group = self.prefill_dp_rank + self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = ( + self.required_prefill_response_num + ) # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank bootstrap_key = ( f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}" @@ -214,41 +327,49 @@ class CommonKVReceiver(BaseKVReceiver): if bootstrap_key not in self.kv_mgr.connection_pool: bootstrap_infos = [] for target_tp_rank in self.target_tp_ranks: - bootstrap_info = self._get_bootstrap_info_from_server( - target_tp_rank, - self.target_dp_group, - ) - if bootstrap_info is not None: - # 
NOTE: only support MLA for now: select one prefill rank as real rank - bootstrap_info["is_dummy"] = not bool( - target_tp_rank == self.target_tp_rank - or self.target_tp_rank is None + for target_pp_rank in range(self.prefill_pp_size): + bootstrap_info = self._get_bootstrap_info_from_server( + target_tp_rank, self.target_dp_group, target_pp_rank ) - bootstrap_infos.append(bootstrap_info) - else: - logger.error( - f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}" - ) - self.bootstrap_infos = bootstrap_infos + if bootstrap_info is not None: + if self.kv_mgr.is_mla_backend: + # For MLA: target_tp_rank is the selected real rank, others are dummy ranks + bootstrap_info["is_dummy"] = not bool( + target_tp_rank == self.target_tp_rank + or self.target_tp_rank is None + ) + else: + # For non-MLA: all target_tp_ranks are selected real ranks + bootstrap_info["is_dummy"] = False + logger.debug( + f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}" + ) + bootstrap_infos.append(bootstrap_info) + else: + self.kv_mgr.record_failure( + self.bootstrap_room, + f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}", + ) + self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) + return - if len(self.bootstrap_infos) == 0: - logger.error( - f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}" - ) - else: - self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos - # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server - self._register_kv_args() + self.bootstrap_infos = bootstrap_infos + self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos + + # Register kv_args only once to prefill KVManager according to the info fetched from the 
bootstrap server + self._register_kv_args() else: self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key] assert len(self.bootstrap_infos) > 0 - def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group): + def _get_bootstrap_info_from_server( + self, engine_rank, target_dp_group, target_pp_rank + ): """Fetch the bootstrap info from the bootstrap server.""" try: - url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}" - response = requests.get(url) + url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}&target_pp_rank={target_pp_rank}" + response = requests.get(url, timeout=5) if response.status_code == 200: bootstrap_info = response.json() return bootstrap_info @@ -261,24 +382,28 @@ class CommonKVReceiver(BaseKVReceiver): logger.error(f"Error fetching prefill info from bootstrap: {e}") return None - def _get_prefill_dp_size_from_server(self) -> int: + def _get_prefill_parallel_info_from_server( + self, + ) -> Tuple[Optional[int], Optional[int], Optional[int]]: """Fetch the prefill parallel info from the bootstrap server.""" try: - url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}" + url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}" response = requests.get(url) if response.status_code == 200: prefill_parallel_info = response.json() - return int(prefill_parallel_info["prefill_tp_size"]), int( - prefill_parallel_info["prefill_dp_size"] + return ( + int(prefill_parallel_info["prefill_attn_tp_size"]), + int(prefill_parallel_info["prefill_dp_size"]), + int(prefill_parallel_info["prefill_pp_size"]), ) else: logger.error( f"Failed to get prefill parallel info: {response.status_code}, {response.text}" ) - return None + return None, None, None except Exception as e: logger.error(f"Error fetching prefill parallel info from bootstrap: {e}") - return None + return None, None, 
None @classmethod def _connect(cls, endpoint: str, is_ipv6: bool = False): @@ -317,10 +442,12 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer): self.store = dict() self.lock = asyncio.Lock() self._setup_routes() - self.tp_size = None + self.pp_size = None + self.attn_tp_size = None self.dp_size = None - self.tp_size_per_dp_rank = None - self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {} + self.prefill_port_table: Dict[ + int, Dict[int, Dict[int, Dict[str, Union[str, int]]]] + ] = {} # Start bootstrap server self.thread = threading.Thread(target=self._run_server, daemon=True) @@ -331,6 +458,10 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer): def _setup_routes(self): self.app.router.add_route("*", "/route", self._handle_route) + self.app.router.add_get("/health", self._handle_health_check) + + async def _handle_health_check(self, request): + return web.Response(text="OK", status=200) async def _handle_route(self, request: web.Request): method = request.method @@ -346,37 +477,45 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer): async def _handle_route_put(self, request: web.Request): data = await request.json() role = data["role"] - tp_size = data["tp_size"] - dp_size = data["dp_size"] + attn_tp_size = data["attn_tp_size"] + attn_tp_rank = data["attn_tp_rank"] + attn_dp_size = data["attn_dp_size"] + attn_dp_rank = data["attn_dp_rank"] + pp_size = data["pp_size"] + pp_rank = data["pp_rank"] + system_dp_size = data["system_dp_size"] + system_dp_rank = data["system_dp_rank"] rank_ip = data["rank_ip"] rank_port = int(data["rank_port"]) - engine_rank = int(data["engine_rank"]) - if self.tp_size is None: - self.tp_size = tp_size + if self.attn_tp_size is None: + self.attn_tp_size = attn_tp_size if self.dp_size is None: - self.dp_size = dp_size + self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size - tp_size_per_dp_rank = tp_size // dp_size - if self.tp_size_per_dp_rank == None: - self.tp_size_per_dp_rank = 
tp_size_per_dp_rank + if self.pp_size is None: + self.pp_size = pp_size - # Add lock to make sure thread-safe if role == "Prefill": - dp_group = engine_rank // tp_size_per_dp_rank - tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank + if system_dp_size == 1: + dp_group = attn_dp_rank + else: + dp_group = system_dp_rank + # Add lock to make sure thread-safe async with self.lock: if dp_group not in self.prefill_port_table: self.prefill_port_table[dp_group] = {} + if attn_tp_rank not in self.prefill_port_table[dp_group]: + self.prefill_port_table[dp_group][attn_tp_rank] = {} - self.prefill_port_table[dp_group][tp_rank_in_dp_group] = { + self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = { "rank_ip": rank_ip, "rank_port": rank_port, } logger.debug( - f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}" + f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}" ) return web.Response(text="OK", status=200) @@ -384,14 +523,20 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer): async def _handle_route_get(self, request: web.Request): engine_rank = request.query.get("engine_rank") target_dp_group = request.query.get("target_dp_group") - if not engine_rank or not target_dp_group: + target_pp_rank = request.query.get("target_pp_rank") + if not engine_rank or not target_dp_group or not target_pp_rank: return web.Response(text="Missing inputs for bootstrap server.", status=400) # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size - if int(engine_rank) == -1 and int(target_dp_group) == -1: + if ( + int(engine_rank) == -1 + and int(target_dp_group) == -1 + and int(target_pp_rank) == -1 + ): prefill_parallel_info = { - "prefill_tp_size": self.tp_size, + "prefill_attn_tp_size": self.attn_tp_size, "prefill_dp_size": self.dp_size, + "prefill_pp_size": self.pp_size, } return web.json_response(prefill_parallel_info, status=200) @@ 
-399,7 +544,7 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer): async with self.lock: bootstrap_info = self.prefill_port_table[int(target_dp_group)][ int(engine_rank) - ] + ][int(target_pp_rank)] if bootstrap_info is not None: return web.json_response(bootstrap_info, status=200) @@ -412,7 +557,11 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer): self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) - self._runner = web.AppRunner(self.app) + access_log = None + if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG: + access_log = self.app.logger + + self._runner = web.AppRunner(self.app, access_log=access_log) self._loop.run_until_complete(self._runner.setup()) site = web.TCPSite(self._runner, host=self.host, port=self.port) diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index f69d29622..f779e1fee 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -1,33 +1,27 @@ from __future__ import annotations -import asyncio import concurrent.futures import ctypes import dataclasses import logging import os -import queue -import socket import struct import threading import time from collections import defaultdict -from functools import cache -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple import numpy as np import numpy.typing as npt import requests import zmq -from aiohttp import web -from sglang.srt.disaggregation.base.conn import ( - BaseKVBootstrapServer, - BaseKVManager, - BaseKVReceiver, - BaseKVSender, - KVArgs, - KVPoll, +from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll +from sglang.srt.disaggregation.common.conn import ( + CommonKVBootstrapServer, + CommonKVManager, + CommonKVReceiver, + CommonKVSender, ) from sglang.srt.disaggregation.common.utils import ( FastQueue, @@ -35,23 +29,12 @@ from 
sglang.srt.disaggregation.common.utils import ( ) from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.disaggregation.utils import DisaggregationMode -from sglang.srt.distributed import get_pp_group -from sglang.srt.layers.dp_attention import ( - get_attention_dp_rank, - get_attention_dp_size, - get_attention_tp_rank, - get_attention_tp_size, -) from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( format_tcp_address, get_bool_env_var, - get_free_port, get_int_env_var, - get_ip, - get_local_ip_auto, is_valid_ipv6_address, - maybe_wrap_ipv6_address, ) logger = logging.getLogger(__name__) @@ -159,7 +142,7 @@ class AuxDataCodec: return -class MooncakeKVManager(BaseKVManager): +class MooncakeKVManager(CommonKVManager): AUX_DATA_HEADER = b"AUX_DATA" def __init__( @@ -169,43 +152,14 @@ class MooncakeKVManager(BaseKVManager): server_args: ServerArgs, is_mla_backend: Optional[bool] = False, ): - self.kv_args = args - self.local_ip = get_local_ip_auto() - self.is_mla_backend = is_mla_backend - self.disaggregation_mode = disaggregation_mode + super().__init__(args, disaggregation_mode, server_args, is_mla_backend) self.init_engine() - # for p/d multi node infer - self.bootstrap_host = server_args.host - self.bootstrap_port = server_args.disaggregation_bootstrap_port - self.dist_init_addr = server_args.dist_init_addr - self.attn_tp_size = get_attention_tp_size() - self.attn_tp_rank = get_attention_tp_rank() - self.attn_dp_size = get_attention_dp_size() - self.attn_dp_rank = get_attention_dp_rank() - self.system_dp_size = ( - 1 if server_args.enable_dp_attention else server_args.dp_size - ) - self.system_dp_rank = ( - self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0 - ) - self.pp_size = server_args.pp_size - self.pp_rank = self.kv_args.pp_rank - self.request_status: Dict[int, KVPoll] = {} - self.rank_port = None - self.server_socket = zmq.Context().socket(zmq.PULL) - if 
is_valid_ipv6_address(self.local_ip): - self.server_socket.setsockopt(zmq.IPV6, 1) - self.register_buffer_to_engine() if self.disaggregation_mode == DisaggregationMode.PREFILL: - self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {} - self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {} self.start_prefill_thread() - self._register_to_bootstrap() self.session_failures = defaultdict(int) self.failed_sessions = set() self.session_lock = threading.Lock() - self.pp_group = get_pp_group() # Determine the number of threads to use for kv sender cpu_count = os.cpu_count() transfer_thread_pool_size = get_int_env_var( @@ -245,8 +199,6 @@ class MooncakeKVManager(BaseKVManager): self.session_pool = defaultdict(requests.Session) self.session_pool_lock = threading.Lock() self.addr_to_rooms_tracker = defaultdict(set) - self.connection_lock = threading.Lock() - self.required_prefill_response_num_table: Dict[int, int] = {} self.prefill_response_tracker: Dict[int, Set[int]] = defaultdict(set) # Heartbeat interval should be at least 2 seconds self.heartbeat_interval = max( @@ -257,20 +209,12 @@ class MooncakeKVManager(BaseKVManager): get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1 ) self.start_decode_thread() - self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {} - self.prefill_attn_tp_size_table: Dict[str, int] = {} - self.prefill_dp_size_table: Dict[str, int] = {} - self.prefill_pp_size_table: Dict[str, int] = {} # If a timeout happens on the decode side, it means decode instances # fail to receive the KV Cache transfer done signal after bootstrapping. # These timeout requests should be aborted to release the tree cache. 
self.waiting_timeout = get_int_env_var( "SGLANG_DISAGGREGATION_WAITING_TIMEOUT", 300 ) - else: - raise ValueError( - f"Unsupported DisaggregationMode: {self.disaggregation_mode}" - ) self.failure_records: Dict[int, str] = {} self.failure_lock = threading.Lock() @@ -295,14 +239,6 @@ class MooncakeKVManager(BaseKVManager): self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens ) - @cache - def _connect(self, endpoint: str, is_ipv6: bool = False): - socket = zmq.Context().socket(zmq.PUSH) - if is_ipv6: - socket.setsockopt(zmq.IPV6, 1) - socket.connect(endpoint) - return socket - def _transfer_data(self, mooncake_session_id, transfer_blocks): if not transfer_blocks: return 0 @@ -654,6 +590,26 @@ class MooncakeKVManager(BaseKVManager): ] ) + def _handle_aux_data(self, msg: List[bytes]): + """Handle AUX_DATA messages received by the decode thread.""" + room = int(msg[1].decode("ascii")) + buffer_index = int(msg[2].decode("ascii")) + aux_index = int(msg[3].decode("ascii")) + data_length = struct.unpack(">I", msg[4])[0] + data = msg[5] + + if len(data) != data_length: + logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}") + return + + AuxDataCodec.deserialize_data_to_buffer( + self.kv_args, buffer_index, aux_index, data + ) + + logger.debug( + f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}" + ) + def sync_status_to_decode_endpoint( self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int ): @@ -802,11 +758,7 @@ class MooncakeKVManager(BaseKVManager): f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead." 
) - def _bind_server_socket(self): - self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port)) - def start_prefill_thread(self): - self.rank_port = get_free_port() self._bind_server_socket() def bootstrap_thread(): @@ -844,28 +796,7 @@ class MooncakeKVManager(BaseKVManager): threading.Thread(target=bootstrap_thread).start() - def _handle_aux_data(self, msg: List[bytes]): - """Handle AUX_DATA messages received by the decode thread.""" - room = int(msg[1].decode("ascii")) - buffer_index = int(msg[2].decode("ascii")) - aux_index = int(msg[3].decode("ascii")) - data_length = struct.unpack(">I", msg[4])[0] - data = msg[5] - - if len(data) != data_length: - logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}") - return - - AuxDataCodec.deserialize_data_to_buffer( - self.kv_args, buffer_index, aux_index, data - ) - - logger.debug( - f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}" - ) - def start_decode_thread(self): - self.rank_port = get_free_port() self._bind_server_socket() def decode_thread(): @@ -1020,51 +951,6 @@ class MooncakeKVManager(BaseKVManager): def get_session_id(self): return self.engine.get_session_id() - def _register_to_bootstrap(self): - """Register KVSender to bootstrap server via HTTP POST.""" - if self.dist_init_addr: - # multi node case: bootstrap server's host is dist_init_addr - if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6] - if self.dist_init_addr.endswith("]"): - host = self.dist_init_addr - else: - host, _ = self.dist_init_addr.rsplit(":", 1) - else: - host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0]) - else: - # single node case: bootstrap server's host is same as http server's host - host = self.bootstrap_host - host = maybe_wrap_ipv6_address(host) - - bootstrap_server_url = f"{host}:{self.bootstrap_port}" - url = f"http://{bootstrap_server_url}/route" - payload = { - "role": "Prefill", - "attn_tp_size": self.attn_tp_size, - "attn_tp_rank": 
self.attn_tp_rank, - "attn_dp_size": self.attn_dp_size, - "attn_dp_rank": self.attn_dp_rank, - "pp_size": self.pp_size, - "pp_rank": self.pp_rank, - "system_dp_size": self.system_dp_size, - "system_dp_rank": self.system_dp_rank, - "rank_ip": self.local_ip, - "rank_port": self.rank_port, - } - - try: - response = requests.put(url, json=payload, timeout=5) - if response.status_code == 200: - logger.debug("Prefill successfully registered to bootstrap server.") - else: - logger.error( - f"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}" - ) - except Exception as e: - logger.error( - f"Prefill instance failed to register to bootstrap server: {e}" - ) - def _handle_node_failure(self, failed_bootstrap_addr): with self.connection_lock: keys_to_remove = [ @@ -1103,7 +989,7 @@ class MooncakeKVManager(BaseKVManager): ) -class MooncakeKVSender(BaseKVSender): +class MooncakeKVSender(CommonKVSender): def __init__( self, @@ -1113,19 +999,9 @@ class MooncakeKVSender(BaseKVSender): dest_tp_ranks: List[int], pp_rank: int, ): - self.kv_mgr = mgr - self.bootstrap_room = bootstrap_room - self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping) - self.aux_index = None - self.bootstrap_server_url = bootstrap_addr + super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank) self.conclude_state = None self.init_time = time.time() - # inner state - self.curr_idx = 0 - - def init(self, num_kv_indices: int, aux_index: Optional[int] = None): - self.num_kv_indices = num_kv_indices - self.aux_index = aux_index def send( self, @@ -1203,7 +1079,7 @@ class MooncakeKVSender(BaseKVSender): self.conclude_state = KVPoll.Failed -class MooncakeKVReceiver(BaseKVReceiver): +class MooncakeKVReceiver(CommonKVReceiver): _ctx = zmq.Context() _socket_cache = {} _socket_locks = {} @@ -1216,166 +1092,11 @@ class MooncakeKVReceiver(BaseKVReceiver): bootstrap_room: Optional[int] = None, prefill_dp_rank: Optional[int] = None, ): - 
self.bootstrap_room = bootstrap_room - self.bootstrap_addr = bootstrap_addr - self.kv_mgr = mgr - self.session_id = self.kv_mgr.get_session_id() - self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping) + self.session_id = mgr.get_session_id() self.conclude_state = None self.init_time = None + super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank) - if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table: - ( - self.prefill_attn_tp_size, - self.prefill_dp_size, - self.prefill_pp_size, - ) = self._get_prefill_parallel_info_from_server() - if ( - self.prefill_attn_tp_size is None - or self.prefill_dp_size is None - or self.prefill_pp_size is None - ): - self.kv_mgr.record_failure( - self.bootstrap_room, - f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}", - ) - self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) - return - else: - logger.debug( - f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}" - ) - self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = ( - self.prefill_attn_tp_size - ) - self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = ( - self.prefill_dp_size - ) - self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = ( - self.prefill_pp_size - ) - else: - self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[ - self.bootstrap_addr - ] - self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[ - self.bootstrap_addr - ] - self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[ - self.bootstrap_addr - ] - - # Currently, we don't allow prefill instance and decode instance to - # have different TP sizes per DP rank, except for models using MLA. 
- if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size: - self.target_tp_rank = ( - self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size - ) - self.required_dst_info_num = 1 - self.required_prefill_response_num = 1 * ( - self.prefill_pp_size // self.kv_mgr.pp_size - ) - self.target_tp_ranks = [self.target_tp_rank] - elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size: - if not self.kv_mgr.is_mla_backend: - logger.warning_once( - "Performance is NOT guaranteed when using different TP sizes for non-MLA models. " - ) - self.target_tp_rank = ( - self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size - ) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size) - self.required_dst_info_num = ( - self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size - ) - self.required_prefill_response_num = 1 * ( - self.prefill_pp_size // self.kv_mgr.pp_size - ) - self.target_tp_ranks = [self.target_tp_rank] - else: - if not self.kv_mgr.is_mla_backend: - logger.warning_once( - "Performance is NOT guaranteed when using different TP sizes for non-MLA models. 
" - ) - # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models; - self.target_tp_ranks = [ - rank - for rank in range( - (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size) - * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size), - (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1) - * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size), - ) - ] - - # For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain - # multiple connections in the connection pool and have to send dummy requests to other prefill ranks, - # or the KVPoll will never be set correctly - self.target_tp_rank = self.target_tp_ranks[0] - self.required_dst_info_num = 1 - if self.kv_mgr.is_mla_backend: - self.required_prefill_response_num = ( - self.prefill_pp_size // self.kv_mgr.pp_size - ) - else: - self.required_prefill_response_num = ( - self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size - ) * (self.prefill_pp_size // self.kv_mgr.pp_size) - - if prefill_dp_rank is not None: - logger.debug(f"Targeting DP rank: {prefill_dp_rank}") - self.prefill_dp_rank = prefill_dp_rank - else: - self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size - - # FIXME: alias here: target_dp_group -> prefill_dp_rank - self.target_dp_group = self.prefill_dp_rank - - self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = ( - self.required_prefill_response_num - ) - # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank - bootstrap_key = ( - f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}" - ) - - if bootstrap_key not in self.kv_mgr.connection_pool: - bootstrap_infos = [] - for target_tp_rank in self.target_tp_ranks: - for target_pp_rank in range(self.prefill_pp_size): - bootstrap_info = self._get_bootstrap_info_from_server( - target_tp_rank, self.target_dp_group, target_pp_rank - ) - if bootstrap_info is not None: - if 
self.kv_mgr.is_mla_backend: - # For MLA: target_tp_rank is the selected real rank, others are dummy ranks - bootstrap_info["is_dummy"] = not bool( - target_tp_rank == self.target_tp_rank - or self.target_tp_rank is None - ) - else: - # For non-MLA: all target_tp_ranks are selected real ranks - bootstrap_info["is_dummy"] = False - logger.debug( - f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}" - ) - bootstrap_infos.append(bootstrap_info) - else: - self.kv_mgr.record_failure( - self.bootstrap_room, - f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}", - ) - self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) - return - - self.bootstrap_infos = bootstrap_infos - self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos - - # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server - self._register_kv_args() - else: - self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key] - - assert len(self.bootstrap_infos) > 0 self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room) self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput) @@ -1398,29 +1119,6 @@ class MooncakeKVReceiver(BaseKVReceiver): logger.error(f"Error fetching prefill info from bootstrap: {e}") return None - def _get_prefill_parallel_info_from_server( - self, - ) -> Tuple[Optional[int], Optional[int], Optional[int]]: - """Fetch the prefill parallel info from the bootstrap server.""" - try: - url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}" - response = requests.get(url) - if response.status_code == 200: - prefill_parallel_info = response.json() - return ( - int(prefill_parallel_info["prefill_attn_tp_size"]), - int(prefill_parallel_info["prefill_dp_size"]), - 
int(prefill_parallel_info["prefill_pp_size"]), - ) - else: - logger.error( - f"Failed to get prefill parallel info: {response.status_code}, {response.text}" - ) - return None, None, None - except Exception as e: - logger.error(f"Error fetching prefill parallel info from bootstrap: {e}") - return None, None, None - def _register_kv_args(self): for bootstrap_info in self.bootstrap_infos: packed_kv_data_ptrs = b"".join( @@ -1452,28 +1150,6 @@ class MooncakeKVReceiver(BaseKVReceiver): ] ) - @classmethod - def _connect(cls, endpoint: str, is_ipv6: bool = False): - with cls._global_lock: - if endpoint not in cls._socket_cache: - sock = cls._ctx.socket(zmq.PUSH) - if is_ipv6: - sock.setsockopt(zmq.IPV6, 1) - sock.connect(endpoint) - cls._socket_cache[endpoint] = sock - cls._socket_locks[endpoint] = threading.Lock() - return cls._socket_cache[endpoint], cls._socket_locks[endpoint] - - @classmethod - def _connect_to_bootstrap_server(cls, bootstrap_info: dict): - ip_address = bootstrap_info["rank_ip"] - port = bootstrap_info["rank_port"] - is_ipv6_address = is_valid_ipv6_address(ip_address) - sock, lock = cls._connect( - format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address - ) - return sock, lock - def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None): for bootstrap_info in self.bootstrap_infos: sock, lock = self._connect_to_bootstrap_server(bootstrap_info) @@ -1551,154 +1227,5 @@ class MooncakeKVReceiver(BaseKVReceiver): self.conclude_state = KVPoll.Failed -class MooncakeKVBootstrapServer(BaseKVBootstrapServer): - def __init__(self, host: str, port: int): - self.host = host - self.port = port - self.app = web.Application() - self.store = dict() - self.lock = asyncio.Lock() - self._setup_routes() - self.pp_size = None - self.attn_tp_size = None - self.dp_size = None - self.prefill_port_table: Dict[ - int, Dict[int, Dict[int, Dict[str, Union[str, int]]]] - ] = {} - - # Start bootstrap server - self.thread = 
threading.Thread(target=self._run_server, daemon=True) - self.run() - - def run(self): - self.thread.start() - - def _setup_routes(self): - self.app.router.add_route("*", "/route", self._handle_route) - self.app.router.add_get("/health", self._handle_health_check) - - async def _handle_health_check(self, request): - return web.Response(text="OK", status=200) - - async def _handle_route(self, request: web.Request): - method = request.method - if method == "PUT": - return await self._handle_route_put(request) - elif method == "GET": - return await self._handle_route_get(request) - else: - return web.Response( - text="Method not allowed", status=405, content_type="application/json" - ) - - async def _handle_route_put(self, request: web.Request): - data = await request.json() - role = data["role"] - attn_tp_size = data["attn_tp_size"] - attn_tp_rank = data["attn_tp_rank"] - attn_dp_size = data["attn_dp_size"] - attn_dp_rank = data["attn_dp_rank"] - pp_size = data["pp_size"] - pp_rank = data["pp_rank"] - system_dp_size = data["system_dp_size"] - system_dp_rank = data["system_dp_rank"] - rank_ip = data["rank_ip"] - rank_port = int(data["rank_port"]) - - if self.attn_tp_size is None: - self.attn_tp_size = attn_tp_size - - if self.dp_size is None: - self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size - - if self.pp_size is None: - self.pp_size = pp_size - - if role == "Prefill": - if system_dp_size == 1: - dp_group = attn_dp_rank - else: - dp_group = system_dp_rank - - # Add lock to make sure thread-safe - async with self.lock: - if dp_group not in self.prefill_port_table: - self.prefill_port_table[dp_group] = {} - if attn_tp_rank not in self.prefill_port_table[dp_group]: - self.prefill_port_table[dp_group][attn_tp_rank] = {} - - self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = { - "rank_ip": rank_ip, - "rank_port": rank_port, - } - logger.debug( - f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} 
and rank_port: {rank_port}" - ) - - return web.Response(text="OK", status=200) - - async def _handle_route_get(self, request: web.Request): - engine_rank = request.query.get("engine_rank") - target_dp_group = request.query.get("target_dp_group") - target_pp_rank = request.query.get("target_pp_rank") - if not engine_rank or not target_dp_group or not target_pp_rank: - return web.Response(text="Missing inputs for bootstrap server.", status=400) - - # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size - if ( - int(engine_rank) == -1 - and int(target_dp_group) == -1 - and int(target_pp_rank) == -1 - ): - prefill_parallel_info = { - "prefill_attn_tp_size": self.attn_tp_size, - "prefill_dp_size": self.dp_size, - "prefill_pp_size": self.pp_size, - } - return web.json_response(prefill_parallel_info, status=200) - - # Find corresponding prefill info - async with self.lock: - bootstrap_info = self.prefill_port_table[int(target_dp_group)][ - int(engine_rank) - ][int(target_pp_rank)] - - if bootstrap_info is not None: - return web.json_response(bootstrap_info, status=200) - else: - return web.Response(text="Bootstrap info not Found", status=404) - - def _run_server(self): - try: - # Event Loop - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) - - access_log = None - if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG: - access_log = self.app.logger - - self._runner = web.AppRunner(self.app, access_log=access_log) - self._loop.run_until_complete(self._runner.setup()) - - site = web.TCPSite(self._runner, host=self.host, port=self.port) - self._loop.run_until_complete(site.start()) - self._loop.run_forever() - except Exception as e: - logger.error(f"Server error: {str(e)}") - finally: - # Cleanup - self._loop.run_until_complete(self._runner.cleanup()) - self._loop.close() - - def close(self): - """Shutdown""" - if self._loop is not None and self._loop.is_running(): - 
self._loop.call_soon_threadsafe(self._loop.stop) - logger.info("Stopping server loop...") - - if self.thread.is_alive(): - self.thread.join(timeout=2) - logger.info("Server thread stopped") - - def poll(self) -> KVPoll: ... +class MooncakeKVBootstrapServer(CommonKVBootstrapServer): + pass diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index c911319ea..871bdcbfc 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -1,37 +1,26 @@ from __future__ import annotations -import asyncio import dataclasses import logging -import queue -import socket import struct import threading import uuid from collections import defaultdict -from functools import cache -from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union +from typing import Dict, List, Optional, Set import numpy as np import numpy.typing as npt -import requests -import zmq -from aiohttp import web -from sglang.srt.disaggregation.base.conn import BaseKVSender, KVArgs, KVPoll +from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll from sglang.srt.disaggregation.common.conn import ( CommonKVBootstrapServer, CommonKVManager, CommonKVReceiver, + CommonKVSender, ) from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import ( - format_tcp_address, - get_local_ip_auto, - is_valid_ipv6_address, -) logger = logging.getLogger(__name__) @@ -134,16 +123,9 @@ class NixlKVManager(CommonKVManager): "to run SGLang with NixlTransferEngine." 
) from e self.agent = nixl_agent(str(uuid.uuid4())) - self.local_ip = get_local_ip_auto() - self.server_socket = zmq.Context().socket(zmq.PULL) - if is_valid_ipv6_address(self.local_ip): - self.server_socket.setsockopt(zmq.IPV6, 1) self.register_buffer_to_engine() if self.disaggregation_mode == DisaggregationMode.PREFILL: - self.request_status: Dict[int, KVPoll] = {} - self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {} - self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {} self._start_bootstrap_thread() elif self.disaggregation_mode == DisaggregationMode.DECODE: self.transfer_statuses: Dict[int, TransferStatus] = defaultdict( @@ -166,6 +148,9 @@ class NixlKVManager(CommonKVManager): self.request_status[bootstrap_room], status ) + def record_failure(self, bootstrap_room: int, failure_reason: str): + pass + def register_buffer_to_engine(self): kv_addrs = [] for kv_data_ptr, kv_data_len in zip( @@ -438,7 +423,7 @@ class NixlKVManager(CommonKVManager): notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))]) decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size - if decode_tp_size == self.tp_size: + if decode_tp_size == self.attn_tp_size: kv_xfer_handle = self.send_kvcache( req.agent_name, kv_indices, @@ -455,7 +440,7 @@ class NixlKVManager(CommonKVManager): chunked_dst_kv_indice, self.decode_kv_args_table[req.agent_name].gpu_id, notif, - prefill_tp_size=self.tp_size, + prefill_tp_size=self.attn_tp_size, decode_tp_size=decode_tp_size, decode_tp_rank=self.decode_kv_args_table[ req.agent_name @@ -505,9 +490,6 @@ class NixlKVManager(CommonKVManager): return False return self.transfer_statuses[room].is_done() - def _bind_server_socket(self): - self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port)) - def _start_bootstrap_thread(self): self._bind_server_socket() @@ -548,7 +530,7 @@ class NixlKVManager(CommonKVManager): threading.Thread(target=bootstrap_thread).start() -class 
NixlKVSender(BaseKVSender): +class NixlKVSender(CommonKVSender): def __init__( self, @@ -558,20 +540,10 @@ class NixlKVSender(BaseKVSender): dest_tp_ranks: List[int], pp_rank: int, ): - self.kv_mgr = mgr - self.bootstrap_room = bootstrap_room - self.aux_index = None - self.bootstrap_server_url = bootstrap_addr + super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank) self.xfer_handles = [] self.has_sent = False self.chunk_id = 0 - self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping) - # inner state - self.curr_idx = 0 - - def init(self, num_kv_indices: int, aux_index: Optional[int] = None): - self.num_kv_indices = num_kv_indices - self.aux_index = aux_index def send( self,