From dca90f1db88b83a30726186215e4b8abe6799ace Mon Sep 17 00:00:00 2001 From: shangmingc Date: Sat, 19 Apr 2025 19:31:00 +0800 Subject: [PATCH] [PD] Remove the requirement of config file for mooncake backend (#5460) --- python/sglang/srt/disaggregation/decode.py | 2 +- .../srt/disaggregation/mooncake/conn.py | 50 +--------- .../mooncake/transfer_engine.py | 91 ++++++------------- python/sglang/srt/disaggregation/prefill.py | 2 +- python/sglang/srt/server_args.py | 7 ++ 5 files changed, 44 insertions(+), 108 deletions(-) diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index d106e42d4..23acf5222 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -121,7 +121,7 @@ class DecodePreallocQueue: kv_args.aux_item_lens = [ metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers ] - kv_args.ib_device = "mock-ib-device" + kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device kv_args.gpu_id = self.scheduler.gpu_id kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER) kv_manager = kv_manager_class( diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index ef9e127c0..11b712acc 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -99,8 +99,12 @@ class MooncakeKVManager(BaseKVManager): disaggregation_mode: DisaggregationMode, server_args: ServerArgs, ): - self.engine = MooncakeTransferEngine() self.kv_args = args + self.engine = MooncakeTransferEngine( + hostname=get_local_ip_by_remote(), + gpu_id=self.kv_args.gpu_id, + ib_device=self.kv_args.ib_device, + ) self.disaggregation_mode = disaggregation_mode # for p/d multi node infer self.bootstrap_port = server_args.disaggregation_bootstrap_port @@ -503,52 +507,8 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): self.thread.start() def _setup_routes(self): - self.app.router.add_route("*", "/metadata", self._handle_metadata) self.app.router.add_route("*", "/route", self._handle_route) - async def _handle_metadata(self, request: web.Request): - key = request.query.get("key", "") - - if request.method == "GET": - return await self._handle_metadata_get(key) - elif request.method == "PUT": - return await self._handle_metadata_put(key, request) - elif request.method == "DELETE": - return await self._handle_metadata_delete(key) - return web.Response( - text="Method not allowed", status=405, content_type="application/json" - ) - - async def _handle_metadata_get(self, key): - async with self.lock: - value = self.store.get(key) - if value is None: - return web.Response( - text="metadata not found", status=404, content_type="application/json" - ) - return web.Response(body=value, status=200, content_type="application/json") - - async def _handle_metadata_put(self, key, request): - data = await request.read() - async with self.lock: - self.store[key] = data - return web.Response( - text="metadata updated", status=200, content_type="application/json" - ) - - async def _handle_metadata_delete(self, key): - async with self.lock: - if key not in self.store: - return web.Response( - text="metadata not found", - status=404, - content_type="application/json", - ) - del self.store[key] - return web.Response( - text="metadata deleted", status=200, content_type="application/json" - ) - async def _handle_route(self, request: web.Request): method = request.method if method == "PUT": diff --git a/python/sglang/srt/disaggregation/mooncake/transfer_engine.py b/python/sglang/srt/disaggregation/mooncake/transfer_engine.py index bdba72579..8c9f910b3 100644 --- a/python/sglang/srt/disaggregation/mooncake/transfer_engine.py +++ b/python/sglang/srt/disaggregation/mooncake/transfer_engine.py @@ -1,45 +1,14 @@ import json import logging -import os -import uuid from dataclasses import dataclass +from typing import Optional logger = logging.getLogger(__name__) -@dataclass -class MooncakeTransferEngineConfig: - local_hostname: str - metadata_server: str - protocol: str - device_name: str - - @staticmethod - def from_file(file_path: str) -> "MooncakeTransferEngineConfig": - """Load the config from a JSON file.""" - with open(file_path) as fin: - config = json.load(fin) - return MooncakeTransferEngineConfig( - local_hostname=config.get("local_hostname", None), - metadata_server=config.get("metadata_server"), - protocol=config.get("protocol", "rdma"), - device_name=config.get("device_name", ""), - ) - - @staticmethod - def load_from_env() -> "MooncakeTransferEngineConfig": - """Load config from a file specified in the environment variable.""" - config_file_path = os.getenv("MOONCAKE_CONFIG_PATH") - if config_file_path is None: - raise ValueError( - "The environment variable 'MOONCAKE_CONFIG_PATH' is not set." - ) - return MooncakeTransferEngineConfig.from_file(config_file_path) - - class MooncakeTransferEngine: - def __init__(self): + def __init__(self, hostname: str, gpu_id: int, ib_device: Optional[str] = None): try: from mooncake.engine import TransferEngine except ImportError as e: @@ -50,43 +19,43 @@ class MooncakeTransferEngine: ) from e self.engine = TransferEngine() + self.hostname = hostname + self.gpu_id = gpu_id + self.ib_device = ib_device - try: - self.config = MooncakeTransferEngineConfig.load_from_env() - logger.info("Mooncake Configuration loaded successfully.") - except ValueError as e: - logger.error(e) - raise - except Exception as exc: - logger.error("An error occurred while loading the configuration: %s", exc) - raise - - self.config = MooncakeTransferEngineConfig.load_from_env() - - session_suffix = "_" + str(uuid.uuid4()) - self.session_id = self.config.local_hostname + session_suffix self.initialize( - self.session_id, - self.config.metadata_server, - self.config.protocol, - self.config.device_name, + hostname=self.hostname, + device_name=self.ib_device, ) + self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}" def register(self, ptr, length): - self.engine.register_memory(ptr, length) + ret_value = self.engine.register_memory(ptr, length) + if ret_value != 0: + logger.error("Mooncake memory registration failed.") + raise RuntimeError("Mooncake memory registration failed.") def deregister(self, ptr): - self.engine.unregister_memory(ptr) + ret_value = self.engine.unregister_memory(ptr) + if ret_value != 0: + logger.error("Mooncake memory deregistration failed.") + raise RuntimeError("Mooncake memory deregistration failed.") def initialize( self, - local_hostname: str, - metadata_server: str, - protocol: str, - device_name: str, + hostname: str, + device_name: Optional[str], ) -> None: """Initialize the mooncake instance.""" - self.engine.initialize(local_hostname, metadata_server, protocol, device_name) + ret_value = self.engine.initialize( + hostname, + "P2PHANDSHAKE", + "rdma", + device_name if device_name is not None else "", + ) + if ret_value != 0: + logger.error("Mooncake Transfer Engine initialization failed.") + raise RuntimeError("Mooncake Transfer Engine initialization failed.") def transfer_sync( self, session_id: str, buffer: int, peer_buffer_address: int, length: int @@ -97,12 +66,12 @@ class MooncakeTransferEngine: session_id, buffer, peer_buffer_address, length ) if ret < 0: - logger.error("Transfer Return Error") - raise Exception("Transfer Return Error") + logger.error("Mooncake Transfer Engine Return Error.") + raise RuntimeError("Mooncake Transfer Engine Return Error.") return ret def get_localhost(self): - return self.config.local_hostname + return self.hostname def get_session_id(self): return self.session_id diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index d513b13dd..692d014bb 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -103,7 +103,7 @@ class PrefillBootstrapQueue: kv_args.aux_item_lens = [ metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers ] - kv_args.ib_device = "mock-ib-device" + kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device kv_args.gpu_id = self.scheduler.gpu_id kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER) kv_manager = kv_manager_class( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 41bb65117..25ad02a77 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -196,6 +196,7 @@ class ServerArgs: disaggregation_mode: str = "null" disaggregation_bootstrap_port: int = 8998 disaggregation_transfer_backend: str = "mooncake" + disaggregation_ib_device: Optional[str] = None def __post_init__(self): # Expert parallelism @@ -1193,6 +1194,12 @@ class ServerArgs: default=ServerArgs.disaggregation_transfer_backend, help="The backend for disaggregation transfer. Default is mooncake.", ) + parser.add_argument( + "--disaggregation-ib-device", + type=str, + default=ServerArgs.disaggregation_ib_device, + help="The ib device for disaggregation transfer. Default is None, it will be detected automatically if using the mooncake backend.", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace):