diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py index bdf5f5027..42e0b2ae5 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -5,6 +5,7 @@ import numpy as np import numpy.typing as npt from sglang.srt.disaggregation.utils import DisaggregationMode +from sglang.srt.server_args import ServerArgs class KVArgs: @@ -16,6 +17,7 @@ class KVArgs: aux_data_lens: list[int] aux_item_lens: list[int] ib_device: str + gpu_id: int class KVPoll: @@ -30,7 +32,12 @@ class BaseKVManager(ABC): """Base class for managing transfers states""" @abstractmethod - def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode): ... + def __init__( + self, + args: KVArgs, + disaggregation_mode: DisaggregationMode, + server_args: ServerArgs, + ): ... class BaseKVSender(ABC): diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 1f4f9cfa7..699de5524 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -128,8 +128,11 @@ class DecodePreallocQueue: metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers ] kv_args.ib_device = "mock-ib-device" + kv_args.gpu_id = self.scheduler.gpu_id kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER) - kv_manager = kv_manager_class(kv_args, DisaggregationMode.DECODE) + kv_manager = kv_manager_class( + kv_args, DisaggregationMode.DECODE, self.scheduler.server_args + ) return kv_manager def add(self, req: Req) -> None: diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 957fdc559..062d43a59 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -2,10 +2,9 @@ from __future__ import annotations import asyncio import dataclasses -import json import logging import 
queue -import random +import socket import struct import threading from functools import cache @@ -27,24 +26,12 @@ from sglang.srt.disaggregation.base.conn import ( ) from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.disaggregation.utils import DisaggregationMode -from sglang.srt.utils import is_port_available +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote logger = logging.getLogger(__name__) -def find_available_ports(base_port: int, count: int) -> List[int]: - """Find consecutive available ports starting from base_port.""" - available_ports = [] - current_port = base_port - - while len(available_ports) < count: - if is_port_available(current_port): - available_ports.append(current_port) - current_port += random.randint(100, 1000) - - return available_ports - - def group_concurrent_contiguous( src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64] ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: @@ -82,10 +69,10 @@ class TransferKVChunk: @dataclasses.dataclass class TransferInfo: - endpoint: str - decode_port: int - mooncake_session_id: str room: int + endpoint: str + dst_port: int + mooncake_session_id: str dst_kv_ptrs: list[int] dst_kv_indices: npt.NDArray[np.int64] dst_aux_ptrs: list[int] @@ -94,10 +81,10 @@ class TransferInfo: @classmethod def from_zmq(cls, msg: List[bytes]): return cls( - endpoint=msg[0].decode("ascii"), - decode_port=int(msg[1].decode("ascii")), - mooncake_session_id=msg[2].decode("ascii"), - room=int(msg[3].decode("ascii")), + room=int(msg[0].decode("ascii")), + endpoint=msg[1].decode("ascii"), + dst_port=int(msg[2].decode("ascii")), + mooncake_session_id=msg[3].decode("ascii"), dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])), dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64), dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])), @@ -106,12 +93,20 @@ 
class TransferInfo: class MooncakeKVManager(BaseKVManager): - def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode): + def __init__( + self, + args: KVArgs, + disaggregation_mode: DisaggregationMode, + server_args: ServerArgs, + ): self.engine = MooncakeTransferEngine() self.kv_args = args self.disaggregation_mode = disaggregation_mode + # for p/d multi node infer + self.bootstrap_port = server_args.disaggregation_bootstrap_port + self.dist_init_addr = server_args.dist_init_addr self.request_status: Dict[int, KVPoll] = {} - self.connection_pool: Dict[int, Dict[str, Union[str, int]]] = {} + self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {} self.rank_port = None self.server_socket = zmq.Context().socket(zmq.PULL) self.register_buffer_to_engine() @@ -119,6 +114,7 @@ class MooncakeKVManager(BaseKVManager): self.transfer_queue = queue.Queue() self.transfer_infos: Dict[int, TransferInfo] = {} self.start_prefill_thread() + self._register_to_bootstrap() elif self.disaggregation_mode == DisaggregationMode.DECODE: self.start_decode_thread() else: @@ -150,54 +146,29 @@ class MooncakeKVManager(BaseKVManager): dst_kv_ptrs: list[int], dst_kv_indices: npt.NDArray[np.int64], ): - layer_num = int(len(self.kv_args.kv_data_ptrs) / 2) + # group by indices prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous( prefill_kv_indices, dst_kv_indices ) - for layer_id in range(layer_num): - prefill_key_layer_ptr = self.kv_args.kv_data_ptrs[layer_id] - key_item_len = self.kv_args.kv_item_lens[layer_id] - prefill_value_layer_ptr = self.kv_args.kv_data_ptrs[layer_num + layer_id] - value_item_len = self.kv_args.kv_item_lens[layer_num + layer_id] - decode_key_layer_ptr = dst_kv_ptrs[layer_id] - decode_value_layer_ptr = dst_kv_ptrs[layer_num + layer_id] + num_layers = len(self.kv_args.kv_data_ptrs) + for layer_id in range(num_layers): + src_ptr = self.kv_args.kv_data_ptrs[layer_id] + dst_ptr = dst_kv_ptrs[layer_id] + item_len = 
self.kv_args.kv_item_lens[layer_id] for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks): - prefill_key_addr = ( - prefill_key_layer_ptr + int(prefill_index[0]) * key_item_len - ) - decode_key_addr = ( - decode_key_layer_ptr + int(decode_index[0]) * key_item_len - ) + src_addr = src_ptr + int(prefill_index[0]) * item_len + dst_addr = dst_ptr + int(decode_index[0]) * item_len + length = item_len * len(prefill_index) - # TODO: mooncake transfer engine can do async transfer. Do async later + # TODO: make async later status = self.engine.transfer_sync( - mooncake_session_id, - prefill_key_addr, - decode_key_addr, - key_item_len * len(prefill_index), + mooncake_session_id, src_addr, dst_addr, length ) if status != 0: return status - prefill_value_addr = ( - prefill_value_layer_ptr + int(prefill_index[0]) * value_item_len - ) - - decode_value_addr = ( - decode_value_layer_ptr + int(decode_index[0]) * value_item_len - ) - - # TODO: mooncake transfer engine can do async transfer. 
# Methods of MooncakeKVManager, reconstructed from the collapsed diff.


def get_session_id(self):
    """Return the session id reported by the underlying transfer engine."""
    return self.engine.get_session_id()


def _register_to_bootstrap(self, timeout: float = 10.0):
    """Register this prefill rank with the bootstrap server via HTTP PUT.

    The bootstrap server host is resolved from ``self.dist_init_addr`` when it
    is set, otherwise from ``get_ip()``; the port is ``self.bootstrap_port``.
    The payload advertises the address of this rank's ZMQ socket
    (``rank_ip`` / ``rank_port``) under a ``bootstrap_key`` built from the
    server URL and ``self.kv_args.engine_rank``.

    Registration is best-effort: HTTP and connection failures are logged and
    swallowed, they are not raised to the caller.

    Args:
        timeout: Seconds to wait for the bootstrap server before giving up.
            Without a timeout ``requests`` waits indefinitely, which would
            hang manager construction when the server is unreachable.
    """
    if self.dist_init_addr:
        # dist_init_addr is split on ":" and the first piece is used as host.
        # NOTE(review): this does not handle bracketed IPv6 literals - confirm
        # whether such addresses can reach this code path.
        ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
    else:
        ip_address = get_ip()

    bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
    url = f"http://{bootstrap_server_url}/route"
    # NOTE(review): the key embeds the *resolved* IP; whoever looks this entry
    # up must build the identical string - verify against the decode side.
    payload = {
        "role": "Prefill",
        "rank_ip": get_local_ip_by_remote(),
        "rank_port": self.rank_port,
        "bootstrap_key": f"{bootstrap_server_url}_{self.kv_args.engine_rank}",
    }

    try:
        response = requests.put(url, json=payload, timeout=timeout)
    except requests.RequestException as e:
        logger.error(f"Prefill Failed to register to bootstrap server: {e}")
        return

    if response.status_code == 200:
        logger.debug("Prefill successfully registered to bootstrap server.")
    else:
        # The request reached the server, so this is a rejected registration,
        # not a connection failure.
        logger.error(
            f"Prefill Failed to register to bootstrap server: {response.status_code}, {response.text}"
        )
self.session_id = self.kv_mgr.get_session_id() - self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput) - self.prefill_engine_rank = None - self.decode_port = self.kv_mgr.rank_port - self.dealer_socket = None + self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping) - def _get_prefill_info_from_bootstrap(self, tp_rank: int): - """Fetch the prefill server port corresponding to tp_rank from the bootstrap server.""" + self.bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}" + + if self.bootstrap_key not in self.kv_mgr.connection_pool: + self.bootstrap_info = self._get_bootstrap_info_from_server( + self.bootstrap_key + ) + if self.bootstrap_info is None: + logger.error( + f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}" + ) + else: + self.kv_mgr.connection_pool[self.bootstrap_key] = self.bootstrap_info + else: + self.bootstrap_info = self.kv_mgr.connection_pool[self.bootstrap_key] + + assert self.bootstrap_info is not None + self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput) + + def _get_bootstrap_info_from_server(self, bootstrap_key: str): + """Fetch the bootstrap info from the bootstrap server.""" try: - url = f"http://{self.bootstrap_addr}/kv_route?tp_rank={tp_rank}" + url = f"http://{self.bootstrap_addr}/route?bootstrap_key={bootstrap_key}" response = requests.get(url) if response.status_code == 200: - prefill_info = response.json() - return prefill_info + bootstrap_info = response.json() + return bootstrap_info else: logger.error( f"Failed to get prefill server info: {response.status_code}, {response.text}" @@ -464,39 +442,13 @@ class MooncakeKVReceiver(BaseKVReceiver): return socket def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None): - prefill_info = None - logger.info(f"Decode bootstrap addr {self.bootstrap_addr}.") + self.prefill_server_url = ( + f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}" + ) + logger.debug( + 
f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}" + ) - if self.kv_mgr.kv_args.engine_rank not in self.kv_mgr.connection_pool: - prefill_info = self._get_prefill_info_from_bootstrap( - self.kv_mgr.kv_args.engine_rank - ) - if prefill_info is None: - logger.error( - logger.error( - f"Could not fetch prefill server info for tp_rank {self.kv_mgr.kv_args.engine_rank}" - ) - ) - else: - self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] = ( - prefill_info - ) - else: - prefill_info = self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] - - if prefill_info: - self.prefill_server_url = ( - f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}" - ) - - logger.info( - f"Fetched prefill server info: {prefill_info} for tp_rank {self.kv_mgr.kv_args.engine_rank}" - ) - self.handshake_prefill_server(kv_indices, aux_index) - - def handshake_prefill_server( - self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None - ): packed_kv_data_ptrs = b"".join( struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs ) @@ -505,10 +457,10 @@ class MooncakeKVReceiver(BaseKVReceiver): ) self._connect("tcp://" + self.prefill_server_url).send_multipart( [ - self.decode_ip.encode("ascii"), - str(self.decode_port).encode("ascii"), - self.session_id.encode("ascii"), str(self.bootstrap_room).encode("ascii"), + get_local_ip_by_remote().encode("ascii"), + str(self.kv_mgr.rank_port).encode("ascii"), + self.session_id.encode("ascii"), packed_kv_data_ptrs, kv_indices.tobytes(), packed_aux_data_ptrs, @@ -530,10 +482,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): self.store = dict() self.lock = asyncio.Lock() self._setup_routes() - # prefill_engine_rank -> prefill_info - self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {} - - self.context = zmq.Context() + self.prefill_port_table: Dict[str, Dict[str, Union[str, int]]] = {} self.prefill_engine_rank = None @@ -546,7 +495,7 @@ class 
# Methods of MooncakeKVBootstrapServer, reconstructed from the collapsed diff.


def _setup_routes(self):
    """Attach the ``/metadata`` and ``/route`` handlers to the aiohttp app."""
    self.app.router.add_route("*", "/metadata", self._handle_metadata)
    self.app.router.add_route("*", "/route", self._handle_route)


async def _handle_route(self, request: web.Request):
    """Dispatch a ``/route`` request by HTTP method.

    PUT registers a prefill rank, GET looks one up; any other method gets a
    405 response.
    """
    handlers = {
        "PUT": self._handle_route_put,
        "GET": self._handle_route_get,
    }
    handler = handlers.get(request.method)
    if handler is None:
        return web.Response(
            text="Method not allowed", status=405, content_type="application/json"
        )
    return await handler(request)


async def _handle_route_put(self, request: web.Request):
    """Store the address a prefill rank advertises under its bootstrap key.

    Expects a JSON body with ``role``, ``rank_ip``, ``rank_port`` and
    ``bootstrap_key``. Only ``role == "Prefill"`` entries are recorded; other
    roles are acknowledged with 200 without being stored.

    Returns:
        400 when the body is not valid JSON, a field is missing, or
        ``rank_port`` is not an integer; otherwise 200 ``OK``.
    """
    try:
        data = await request.json()
        role = data["role"]
        rank_ip = data["rank_ip"]
        rank_port = int(data["rank_port"])
        bootstrap_key = data["bootstrap_key"]
    except (KeyError, TypeError, ValueError) as e:
        # Malformed client input is a client error, not a server crash.
        return web.Response(text=f"Invalid registration payload: {e}", status=400)

    if role == "Prefill":
        # Hold the same lock the GET handler takes so that reads and writes
        # of the table are serialized between coroutines.
        async with self.lock:
            self.prefill_port_table[bootstrap_key] = {
                "rank_ip": rank_ip,
                "rank_port": rank_port,
            }
        logger.debug(
            f"Registered Prefill bootstrap_key: {bootstrap_key} with rank_ip: {rank_ip} and rank_port: {rank_port}"
        )

    return web.Response(text="OK", status=200)


async def _handle_route_get(self, request: web.Request):
    """Return the registered address for the ``bootstrap_key`` query parameter.

    Returns:
        400 when ``bootstrap_key`` is missing or empty, 404 when no entry is
        registered for it, otherwise 200 with the stored
        ``{"rank_ip", "rank_port"}`` mapping as JSON.
    """
    bootstrap_key = request.query.get("bootstrap_key")
    if not bootstrap_key:
        return web.Response(text="Missing bootstrap_key", status=400)

    # Find corresponding prefill info
    async with self.lock:
        bootstrap_info = self.prefill_port_table.get(bootstrap_key)

    if bootstrap_info is None:
        return web.Response(text="Not Found", status=404)
    return web.json_response(bootstrap_info, status=200)
def get_free_port() -> int:
    """Return a TCP port number that was free when this function ran.

    Binds a throwaway socket to port 0 so the OS picks an unused port, reads
    the chosen port back, and closes the socket. IPv4 is tried first; IPv6 is
    used when the IPv4 attempt raises ``OSError``.

    The port is released before returning, so another process can claim it
    before the caller binds to it.

    Raises:
        OSError: If neither an IPv4 nor an IPv6 socket could be bound.
    """
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]


def _local_ip_towards(family: int, remote: tuple) -> str:
    """Return the local address the OS selects for a UDP socket aimed at *remote*."""
    # The context manager closes the socket on both the success and the
    # failure path, so no file descriptor is leaked per call.
    with socket.socket(family, socket.SOCK_DGRAM) as s:
        s.connect(remote)  # Doesn't need to be reachable
        return s.getsockname()[0]


def get_local_ip_by_remote() -> str:
    """Return this host's outward-facing IP address as a string.

    Connects a UDP socket towards Google's public DNS address (IPv4 first,
    then IPv6) and reports the local address the socket ended up with.

    Raises:
        ValueError: If neither the IPv4 nor the IPv6 probe succeeds.
    """
    # try ipv4
    try:
        return _local_ip_towards(socket.AF_INET, ("8.8.8.8", 80))
    except OSError:
        pass

    # try ipv6
    try:
        # Google's public DNS server, see
        # https://developers.google.com/speed/public-dns/docs/using#addresses
        return _local_ip_towards(socket.AF_INET6, ("2001:4860:4860::8888", 80))
    except OSError as err:
        raise ValueError("Can not get local ip") from err