def group_concurrent_contiguous(
    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
    """Split paired index arrays into runs that are contiguous on both sides.

    Position ``i`` stays in the same run as position ``i - 1`` only when
    ``src_indices[i] == src_indices[i - 1] + 1`` AND
    ``dst_indices[i] == dst_indices[i - 1] + 1``. Each run can then be moved
    with one transfer of ``len(run)`` items starting at ``run[0]``.

    Args:
        src_indices: Source-side indices, one per item.
        dst_indices: Destination-side indices, same length as ``src_indices``.

    Returns:
        ``(src_groups, dst_groups)``: two lists of equal length whose entries
        are the per-run index arrays. Both lists are empty for empty input.

    Raises:
        ValueError: If the two inputs differ in length (pairing would be
            ambiguous and a silent truncation would drop items).
    """
    src = np.asarray(src_indices)
    dst = np.asarray(dst_indices)
    if len(src) != len(dst):
        raise ValueError(
            "src_indices and dst_indices must have the same length, "
            f"got {len(src)} and {len(dst)}"
        )
    # Nothing to group; the previous implementation crashed on src_indices[0].
    if len(src) == 0:
        return [], []

    # A run ends wherever either side stops increasing by exactly one.
    # np.diff(...)[k] compares element k + 1 with element k, hence the + 1.
    is_break = (np.diff(src) != 1) | (np.diff(dst) != 1)
    split_points = np.flatnonzero(is_break) + 1

    src_groups = np.split(src, split_points)
    dst_groups = np.split(dst, split_points)
    return src_groups, dst_groups
# bootstrap room -> (local KV indices, local aux index)
RequestPoolType = Dict[int, Tuple[npt.NDArray[np.int64], Optional[int]]]
# bootstrap room -> (decode endpoint, mooncake session id, dst KV pointers,
#                    dst KV indices, dst aux pointers, dst aux index)
WaitingPoolType = Dict[
    int, Tuple[str, str, List[int], npt.NDArray[np.int64], List[int], int]
]
KVSENDER_POLLING_PORT = 17788
KVRECEIVER_POLLING_PORT = 27788


class KVManager:
    """Per-rank coordinator for KV-cache transfer between prefill and decode.

    In PREFILL mode it listens (ZMQ PULL on ``KVSENDER_POLLING_PORT + rank``)
    for transfer requests from decode ranks, pairs them with locally enqueued
    requests by bootstrap room, pushes the data through the transfer engine and
    reports status back to the decode rank. In DECODE mode it listens (ZMQ PULL
    on ``KVRECEIVER_POLLING_PORT + rank``) for those status reports.
    """

    # TODO: make it general and support multiple transfer backend before merging
    def __init__(self, args: "KVArgs", disaggregation_mode: "DisaggregationMode"):
        """Register buffers with the engine and start the mode's worker threads.

        Raises:
            ValueError: If ``disaggregation_mode`` is neither PREFILL nor DECODE.
        """
        self.engine = MooncakeTransferEngine()
        self.kv_args = args
        self.disaggregation_mode = disaggregation_mode
        self.request_pool: RequestPoolType = {}
        self.request_status: Dict[int, KVPoll] = {}
        # One context for every socket of this manager; outgoing PUSH sockets
        # are cached per endpoint in a plain dict owned by the instance.
        self._zmq_context = zmq.Context()
        self._push_sockets = {}
        self.server_socket = self._zmq_context.socket(zmq.PULL)
        self.register_buffer_to_engine()
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self.waiting_pool: WaitingPoolType = {}
            self.transfer_event = threading.Event()
            self.start_prefill_thread()
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            self.start_decode_thread()
        else:
            raise ValueError(
                f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
            )

    def register_buffer_to_engine(self):
        """Register every KV and aux buffer (pointer, length) with the engine."""
        for kv_data_ptr, kv_data_len in zip(
            self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
        ):
            self.engine.register(kv_data_ptr, kv_data_len)

        for aux_data_ptr, aux_data_len in zip(
            self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
        ):
            self.engine.register(aux_data_ptr, aux_data_len)

    def _connect(self, endpoint: str):
        """Return the PUSH socket connected to ``endpoint``, creating it once.

        A per-instance dict replaces ``functools.cache`` on the method: that
        cache is keyed on ``self`` and never evicts, so it kept every instance
        and its sockets alive for the life of the process.
        """
        socket = self._push_sockets.get(endpoint)
        if socket is None:
            socket = self._zmq_context.socket(zmq.PUSH)
            socket.connect(endpoint)
            self._push_sockets[endpoint] = socket
        return socket

    def send_kvcache(
        self,
        mooncake_session_id: str,
        prefill_kv_indices: npt.NDArray[np.int64],
        dst_ptrs: List[int],
        dst_kv_indices: npt.NDArray[np.int64],
    ):
        """Transfer the key and value entries of every layer to the peer.

        ``kv_data_ptrs`` / ``dst_ptrs`` are read as ``layer_num`` key pointers
        followed by ``layer_num`` value pointers. Indices are grouped into
        contiguous runs so each run needs one transfer per key and per value.

        Returns:
            0 when every transfer returned 0, otherwise the first non-zero
            status returned by the engine.
        """
        layer_num = len(self.kv_args.kv_data_ptrs) // 2
        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
            prefill_kv_indices, dst_kv_indices
        )
        for layer_id in range(layer_num):
            prefill_key_layer_ptr = self.kv_args.kv_data_ptrs[layer_id]
            key_item_len = self.kv_args.kv_item_lens[layer_id]
            prefill_value_layer_ptr = self.kv_args.kv_data_ptrs[layer_num + layer_id]
            value_item_len = self.kv_args.kv_item_lens[layer_num + layer_id]

            decode_key_layer_ptr = dst_ptrs[layer_id]
            decode_value_layer_ptr = dst_ptrs[layer_num + layer_id]

            for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                # int(...) keeps the address arithmetic in Python integers
                # instead of fixed-width numpy scalars.
                src_first = int(prefill_index[0])
                dst_first = int(decode_index[0])
                run_len = len(prefill_index)

                # TODO: mooncake transfer engine can do async transfer. Do async later
                status = self.engine.transfer_sync(
                    mooncake_session_id,
                    prefill_key_layer_ptr + src_first * key_item_len,
                    decode_key_layer_ptr + dst_first * key_item_len,
                    key_item_len * run_len,
                )
                if status != 0:
                    return status

                # TODO: mooncake transfer engine can do async transfer. Do async later
                status = self.engine.transfer_sync(
                    mooncake_session_id,
                    prefill_value_layer_ptr + src_first * value_item_len,
                    decode_value_layer_ptr + dst_first * value_item_len,
                    value_item_len * run_len,
                )
                if status != 0:
                    return status
        return 0

    def send_aux(
        self,
        mooncake_session_id: str,
        prefill_aux_index: int,
        dst_aux_ptrs: List[int],
        dst_aux_index: int,
    ):
        """Transfer one aux item from the first local aux buffer to the peer's.

        Only ``aux_data_ptrs[0]`` / ``dst_aux_ptrs[0]`` are used.

        Returns:
            The status returned by the engine's ``transfer_sync``.
        """
        aux_item_len = self.kv_args.aux_item_lens[0]
        prefill_aux_addr = (
            self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
        )
        decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
        # TODO: mooncake transfer engine can do async transfer. Do async later
        # Not sure about the amount of aux data, maybe transfer it by zmq is more effective
        status = self.engine.transfer_sync(
            mooncake_session_id, prefill_aux_addr, decode_aux_addr, aux_item_len
        )
        return status

    def sync_status_to_decode_endpoint(self, remote: str, room: int):
        """Push the current ``request_status[room]`` to the decode rank.

        Any ``:port`` suffix on ``remote`` is dropped; the message goes to
        ``KVRECEIVER_POLLING_PORT + engine_rank`` on that host.
        """
        host = remote.split(":")[0]
        port = KVRECEIVER_POLLING_PORT + self.kv_args.engine_rank
        self._connect("tcp://" + host + ":" + str(port)).send_multipart(
            [
                str(room).encode("ascii"),
                str(self.request_status[room]).encode("ascii"),
            ]
        )

    @staticmethod
    def _parse_transfer_request(frames) -> Optional[Tuple[int, tuple]]:
        """Decode one 7-part transfer request sent by ``KVReceiver.init``.

        Returns:
            ``None`` when the room field is the literal ``"None"`` (request
            without a room); otherwise ``(room, waiting_pool_entry)``.

        Raises:
            ValueError, struct.error: If the message is malformed (wrong number
                of parts, non-integer fields, truncated pointer data).
        """
        (
            endpoint,
            mooncake_session_id,
            bootstrap_room,
            dst_ptrs,
            dst_kv_indices,
            dst_aux_ptrs,
            dst_aux_index,
        ) = frames
        room_text = bootstrap_room.decode("ascii")
        if room_text == "None":
            return None
        entry = (
            endpoint.decode("ascii"),
            mooncake_session_id.decode("ascii"),
            # Pointers arrive packed as native unsigned 64-bit integers.
            list(struct.unpack(f"{len(dst_ptrs) // 8}Q", dst_ptrs)),
            np.frombuffer(dst_kv_indices, dtype=np.int64),
            list(struct.unpack(f"{len(dst_aux_ptrs) // 8}Q", dst_aux_ptrs)),
            int(dst_aux_index.decode("ascii")),
        )
        return int(room_text), entry

    def _transfer_room(self, room: int) -> None:
        """Run the full transfer for ``room`` and publish its final status.

        The room must be present in both ``waiting_pool`` and ``request_pool``.
        The outcome (Success or Failed) is always written to ``request_status``
        before it is reported, so the local poller and the decode rank agree.
        """
        (
            endpoint,
            mooncake_session_id,
            dst_ptrs,
            dst_kv_indices,
            dst_aux_ptrs,
            dst_aux_index,
        ) = self.waiting_pool.pop(room)
        prefill_kv_indices, prefill_aux_index = self.request_pool.pop(room)

        self.request_status[room] = KVPoll.Transferring
        self.sync_status_to_decode_endpoint(endpoint, room)

        try:
            ret = self.send_kvcache(
                mooncake_session_id, prefill_kv_indices, dst_ptrs, dst_kv_indices
            )
            if ret == 0:
                ret = self.send_aux(
                    mooncake_session_id, prefill_aux_index, dst_aux_ptrs, dst_aux_index
                )
        except Exception:
            # Worker-thread boundary: one failing room must not take the
            # transfer loop down and strand every later request.
            logger.exception("KV transfer failed for bootstrap room %s", room)
            ret = -1

        self.request_status[room] = KVPoll.Success if ret == 0 else KVPoll.Failed
        self.sync_status_to_decode_endpoint(endpoint, room)

    def start_prefill_thread(self):
        """Bind the request socket and start the listener and transfer threads."""
        sender_rank_port = KVSENDER_POLLING_PORT + self.kv_args.engine_rank
        self.server_socket.bind("tcp://*:" + str(sender_rank_port))

        def prefill_thread():
            # Collects transfer requests from decode ranks into waiting_pool.
            while True:
                frames = self.server_socket.recv_multipart()
                try:
                    parsed = self._parse_transfer_request(frames)
                except (ValueError, struct.error):
                    logger.exception("Dropping malformed KV transfer request")
                    continue
                if parsed is None:
                    continue
                bootstrap_room, entry = parsed
                self.waiting_pool[bootstrap_room] = entry
                self.transfer_event.set()

        def transfer_thread():
            # Woken whenever either pool changes; transfers every room that is
            # present on both sides.
            while True:
                self.transfer_event.wait()
                self.transfer_event.clear()
                for room in list(self.waiting_pool):
                    if room in self.request_pool:
                        self._transfer_room(room)

        # Daemon threads: both loops never return, so non-daemon threads would
        # keep the interpreter from exiting.
        threading.Thread(target=prefill_thread, daemon=True).start()
        threading.Thread(target=transfer_thread, daemon=True).start()

    def start_decode_thread(self):
        """Bind the status socket and start the thread that records statuses."""
        receiver_rank_port = KVRECEIVER_POLLING_PORT + self.kv_args.engine_rank
        self.server_socket.bind("tcp://*:" + str(receiver_rank_port))

        def decode_thread():
            while True:
                (bootstrap_room, status) = self.server_socket.recv_multipart()
                status = int(status.decode("ascii"))
                bootstrap_room = int(bootstrap_room.decode("ascii"))
                self.request_status[bootstrap_room] = status

        threading.Thread(target=decode_thread, daemon=True).start()

    def enqueue_request(
        self,
        bootstrap_room: int,
        kv_indices: npt.NDArray[np.int64],
        aux_index: Optional[int],
    ):
        """Record the local side of a transfer and mark it WaitingForInput.

        In PREFILL mode this also wakes the transfer thread.
        """
        self.request_pool[bootstrap_room] = (kv_indices, aux_index)
        self.request_status[bootstrap_room] = KVPoll.WaitingForInput
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self.transfer_event.set()

    def check_status(self, bootstrap_room: int):
        """Return the current status of ``bootstrap_room``.

        In DECODE mode a Success status also drops the room from
        ``request_pool``.

        Raises:
            KeyError: If no status was ever recorded for the room.
        """
        status = self.request_status[bootstrap_room]
        if (
            self.disaggregation_mode == DisaggregationMode.DECODE
            and status == KVPoll.Success
        ):
            self.request_pool.pop(bootstrap_room, None)
        return status

    def set_status(self, bootstrap_room: int, status: "KVPoll"):
        """Overwrite the recorded status of ``bootstrap_room``."""
        self.request_status[bootstrap_room] = status

    def get_localhost(self):
        """Return the transfer engine's configured local hostname."""
        return self.engine.get_localhost()

    def get_session_id(self):
        """Return the transfer engine's session id."""
        return self.engine.get_session_id()


class KVSender:
    """Prefill-side handle for one request's KV transfer."""

    def __init__(self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: int):
        self.kv_mgr = mgr
        self.bootstrap_room = bootstrap_room
        self.kv_mgr.set_status(bootstrap_room, KVPoll.WaitingForInput)
        self.aux_index = None
        # Set by init(); defined here so the attribute always exists.
        self.num_kv_indices = None

    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
        """Remember the aux index and KV index count for the later ``send``."""
        self.aux_index = aux_index
        self.num_kv_indices = num_kv_indices

    def send(self, kv_indices: npt.NDArray[np.int64]):
        """Enqueue the local KV indices with the manager for transfer."""
        self.kv_mgr.enqueue_request(self.bootstrap_room, kv_indices, self.aux_index)

    def poll(self) -> "KVPoll":
        """Return the manager's current status for this request's room."""
        return self.kv_mgr.check_status(self.bootstrap_room)

    def failure_exception(self):
        raise Exception("Fake KVSender Exception")


class KVReceiver:
    """Decode-side handle for one request's KV transfer."""

    def __init__(
        self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None
    ):
        self.bootstrap_room = bootstrap_room
        self.bootstrap_addr = bootstrap_addr
        self.kv_mgr = mgr
        # Host part of the bootstrap address plus the prefill rank's request
        # port (same rank number as this decode rank).
        self.prefill_server_url = (
            bootstrap_addr.split(":")[0]
            + ":"
            + str(KVSENDER_POLLING_PORT + self.kv_mgr.kv_args.engine_rank)
        )
        self.decode_ip = self.kv_mgr.get_localhost()
        self.session_id = self.kv_mgr.get_session_id()
        self.kv_mgr.set_status(bootstrap_room, KVPoll.WaitingForInput)

    def _connect(self, endpoint: str):
        """Return a PUSH socket to ``endpoint`` from the manager's shared cache.

        A receiver is created per request, so a per-receiver cached socket
        (and context) would pile up without ever being released.
        """
        return self.kv_mgr._connect(endpoint)

    def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
        """Enqueue the local destination and ask the prefill rank to transfer.

        The request carries this rank's hostname, session id, room, packed
        buffer pointers, destination KV indices and aux index.
        """
        self.kv_mgr.enqueue_request(self.bootstrap_room, kv_indices, aux_index)
        kv_ptrs = self.kv_mgr.kv_args.kv_data_ptrs
        aux_ptrs = self.kv_mgr.kv_args.aux_data_ptrs
        packed_kv_data_ptrs = struct.pack(f"{len(kv_ptrs)}Q", *kv_ptrs)
        packed_aux_data_ptrs = struct.pack(f"{len(aux_ptrs)}Q", *aux_ptrs)
        # The peer decodes this buffer as int64, so serialize as int64
        # regardless of the dtype the caller passed in.
        packed_kv_indices = np.ascontiguousarray(kv_indices, dtype=np.int64).tobytes()
        self._connect("tcp://" + self.prefill_server_url).send_multipart(
            [
                self.decode_ip.encode("ascii"),
                self.session_id.encode("ascii"),
                str(self.bootstrap_room).encode("ascii"),
                packed_kv_data_ptrs,
                packed_kv_indices,
                packed_aux_data_ptrs,
                str(aux_index).encode("ascii"),
            ]
        )

    def poll(self) -> "KVPoll":
        """Return the manager's current status for this request's room."""
        return self.kv_mgr.check_status(self.bootstrap_room)

    def failure_exception(self):
        raise Exception("Fake KVReceiver Exception")
class KVBootstrapServer:
    """In-memory key/value metadata store served over HTTP on ``/metadata``.

    The aiohttp application runs on its own event loop in a daemon thread that
    is started from the constructor. GET/PUT/DELETE operate on the entry named
    by the ``key`` query parameter.
    """

    def __init__(self, port: int):
        self.port = port
        self.app = web.Application()
        self.store = dict()
        # The lock is created inside the server thread (see _run_server):
        # before Python 3.10 an asyncio.Lock binds to the constructing
        # thread's event loop, which is not the loop that serves requests.
        self.lock = None
        # Initialised here so close() and the cleanup path can test them
        # even if the server thread has not progressed that far.
        self._loop = None
        self._runner = None
        self._setup_routes()

        # Start bootstrap server
        self.thread = threading.Thread(target=self._run_server, daemon=True)
        self.run()

    def run(self):
        """Start the server thread."""
        self.thread.start()

    def _setup_routes(self):
        self.app.router.add_route("*", "/metadata", self._handle_metadata)

    async def _handle_metadata(self, request: web.Request):
        """Dispatch a ``/metadata`` request by HTTP method (405 for others)."""
        key = request.query.get("key", "")

        if request.method == "GET":
            return await self._handle_get(key)
        elif request.method == "PUT":
            return await self._handle_put(key, request)
        elif request.method == "DELETE":
            return await self._handle_delete(key)
        return web.Response(
            text="Method not allowed", status=405, content_type="application/json"
        )

    async def _handle_get(self, key):
        """Return the stored body for ``key`` or 404 when it is absent."""
        async with self.lock:
            value = self.store.get(key)
        if value is None:
            return web.Response(
                text="metadata not found", status=404, content_type="application/json"
            )
        return web.Response(body=value, status=200, content_type="application/json")

    async def _handle_put(self, key, request):
        """Store the raw request body under ``key``, replacing any old value."""
        data = await request.read()
        async with self.lock:
            self.store[key] = data
        return web.Response(
            text="metadata updated", status=200, content_type="application/json"
        )

    async def _handle_delete(self, key):
        """Remove ``key`` from the store or return 404 when it is absent."""
        async with self.lock:
            if key not in self.store:
                return web.Response(
                    text="metadata not found",
                    status=404,
                    content_type="application/json",
                )
            del self.store[key]
        return web.Response(
            text="metadata deleted", status=200, content_type="application/json"
        )

    def _run_server(self):
        """Thread target: create the loop, serve until stopped, then clean up."""
        # Event Loop
        loop = asyncio.new_event_loop()
        self._loop = loop
        try:
            asyncio.set_event_loop(loop)
            # Created after set_event_loop so the lock belongs to this loop.
            self.lock = asyncio.Lock()

            self._runner = web.AppRunner(self.app)
            loop.run_until_complete(self._runner.setup())

            site = web.TCPSite(self._runner, port=self.port)
            loop.run_until_complete(site.start())
            loop.run_forever()
        except Exception as e:
            logger.error(f"Server error: {str(e)}")
        finally:
            # Cleanup; the runner may not exist if startup failed early.
            if self._runner is not None:
                loop.run_until_complete(self._runner.cleanup())
            loop.close()

    def close(self):
        """Shutdown"""
        loop = self._loop
        if loop is not None and loop.is_running():
            # stop() must be scheduled from outside the loop's own thread.
            loop.call_soon_threadsafe(loop.stop)
            logger.info("Stopping server loop...")

        if self.thread.is_alive():
            self.thread.join(timeout=2)
            logger.info("Server thread stopped")

    def poll(self) -> KVPoll: ...
@dataclass
class MooncakeTransferEngineConfig:
    """Settings passed to the Mooncake transfer engine at initialization."""

    local_hostname: str
    metadata_server: str
    protocol: str
    device_name: str

    @staticmethod
    def from_file(file_path: str) -> "MooncakeTransferEngineConfig":
        """Load the config from a JSON file.

        ``protocol`` defaults to ``"rdma"`` and ``device_name`` to ``""``.

        Raises:
            ValueError: If the file's top level is not a JSON object, or if
                ``local_hostname`` / ``metadata_server`` is missing or empty.
                (Previously a missing value became ``None`` and only failed
                later with an unrelated ``TypeError``.)
        """
        with open(file_path, encoding="utf-8") as fin:
            config = json.load(fin)
        if not isinstance(config, dict):
            raise ValueError(
                f"Mooncake config file '{file_path}' must contain a JSON object."
            )
        missing = [
            key for key in ("local_hostname", "metadata_server") if not config.get(key)
        ]
        if missing:
            raise ValueError(
                f"Mooncake config file '{file_path}' is missing required "
                f"field(s): {', '.join(missing)}"
            )
        return MooncakeTransferEngineConfig(
            local_hostname=config["local_hostname"],
            metadata_server=config["metadata_server"],
            protocol=config.get("protocol", "rdma"),
            device_name=config.get("device_name", ""),
        )

    @staticmethod
    def load_from_env() -> "MooncakeTransferEngineConfig":
        """Load config from a file specified in the environment variable.

        Raises:
            ValueError: If ``MOONCAKE_CONFIG_PATH`` is not set, or the file it
                names fails validation in ``from_file``.
        """
        config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
        if config_file_path is None:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
            )
        return MooncakeTransferEngineConfig.from_file(config_file_path)


class MooncakeTransferEngine:
    """Thin wrapper around ``mooncake.engine.TransferEngine``.

    Construction loads the config named by ``MOONCAKE_CONFIG_PATH``, derives a
    unique session id (``<local_hostname>_<uuid4>``) and initializes the engine
    with it.
    """

    def __init__(self):
        try:
            from mooncake.engine import TransferEngine
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
                "to run SGLang with MooncakeTransferEngine."
            ) from e

        self.engine = TransferEngine()

        # The config is loaded exactly once, inside the error handling.
        try:
            self.config = MooncakeTransferEngineConfig.load_from_env()
            logger.info("Mooncake Configuration loaded successfully.")
        except ValueError as e:
            logger.error(e)
            raise
        except Exception as exc:
            logger.error("An error occurred while loading the configuration: %s", exc)
            raise

        session_suffix = "_" + str(uuid.uuid4())
        self.session_id = self.config.local_hostname + session_suffix
        self.initialize(
            self.session_id,
            self.config.metadata_server,
            self.config.protocol,
            self.config.device_name,
        )

    def register(self, ptr, length):
        """Register ``length`` bytes at ``ptr`` with the engine."""
        self.engine.register_memory(ptr, length)

    def deregister(self, ptr):
        """Unregister the memory previously registered at ``ptr``."""
        self.engine.unregister_memory(ptr)

    def initialize(
        self,
        local_hostname: str,
        metadata_server: str,
        protocol: str,
        device_name: str,
    ) -> None:
        """Initialize the mooncake instance."""
        self.engine.initialize(local_hostname, metadata_server, protocol, device_name)

    def transfer_sync(
        self, session_id: str, buffer: int, peer_buffer_address: int, length: int
    ) -> int:
        """Synchronously transfer data to the specified address.

        Returns:
            The engine's non-negative return code.

        Raises:
            RuntimeError: If the engine returns a negative code.
        """
        ret = self.engine.transfer_sync_write(
            session_id, buffer, peer_buffer_address, length
        )
        if ret < 0:
            logger.error("Transfer Return Error")
            raise RuntimeError("Transfer Return Error")
        return ret

    def get_localhost(self):
        """Return the configured local hostname."""
        return self.config.local_hostname

    def get_session_id(self):
        """Return the session id generated at construction."""
        return self.session_id
a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 591b0660f..00affa0a4 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -95,6 +95,10 @@ class GenerateReqInput: # Whether to return hidden states return_hidden_states: bool = False + # For disaggregated inference + bootstrap_host: Optional[str] = None + bootstrap_room: Optional[int] = None + def normalize_batch_and_arguments(self): """ Normalize the batch size and arguments for the request. @@ -435,6 +439,10 @@ class TokenizedGenerateReqInput: # Whether to return hidden states return_hidden_states: bool = False + # For disaggregated inference + bootstrap_host: Optional[str] = None + bootstrap_room: Optional[int] = None + @dataclass class EmbeddingReqInput: diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index bab64bae1..cce17729e 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -390,6 +390,8 @@ class Req: custom_logit_processor: Optional[str] = None, return_hidden_states: bool = False, eos_token_ids: Optional[Set[int]] = None, + bootstrap_host: Optional[str] = None, + bootstrap_room: Optional[int] = None, ): # Input and output info self.rid = rid @@ -523,8 +525,8 @@ class Req: self.lora_path = lora_path # For disaggregation - self.bootstrap_host: str = "0.0.0.0" - self.bootstrap_room: Optional[int] = None + self.bootstrap_host: str = bootstrap_host + self.bootstrap_room: Optional[int] = bootstrap_room self.disagg_kv_sender: Optional[KVSender] = None # used for warmup because we don't have a pair yet when init diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index cf3de55cc..383cd6809 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -836,6 +836,8 @@ class Scheduler( custom_logit_processor=custom_logit_processor, 
return_hidden_states=recv_req.return_hidden_states, eos_token_ids=self.model_config.hf_eos_token_id, + bootstrap_host=recv_req.bootstrap_host, + bootstrap_room=recv_req.bootstrap_room, ) req.tokenizer = self.tokenizer req.queue_time_start = time.time() diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 2b43af42f..33afffbd6 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -452,6 +452,8 @@ class TokenizerManager: top_logprobs_num, token_ids_logprob, obj.stream, + bootstrap_host=obj.bootstrap_host, + bootstrap_room=obj.bootstrap_room, lora_path=obj.lora_path, input_embeds=input_embeds, session_params=session_params,