From 96fe2d0f15a3907f3c083d70807f2d081b9a748c Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Wed, 1 Oct 2025 11:09:21 -0400 Subject: [PATCH] [router] add pd service in grpc router for pd (#11120) --- .../srt/entrypoints/grpc_request_manager.py | 25 ++++++++++++++++++ python/sglang/srt/entrypoints/grpc_server.py | 26 ++++++++++++++++++- 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/grpc_request_manager.py b/python/sglang/srt/entrypoints/grpc_request_manager.py index 1891e9e2a..71d77bfc2 100644 --- a/python/sglang/srt/entrypoints/grpc_request_manager.py +++ b/python/sglang/srt/entrypoints/grpc_request_manager.py @@ -19,6 +19,7 @@ import grpc import zmq import zmq.asyncio +from sglang.srt.managers.disagg_service import start_disagg_service from sglang.srt.managers.io_struct import ( AbortReq, BatchEmbeddingOut, @@ -146,11 +147,19 @@ class GrpcRequestManager: self.crash_dump_request_list = [] self.crash_dump_performed = False + # Bootstrap server for disaggregation mode + self.bootstrap_server = start_disagg_service(server_args) + logger.info( f"GrpcRequestManager initialized with ZMQ IPC: " f"recv={port_args.detokenizer_ipc_name}, " f"send={port_args.scheduler_input_ipc_name}" ) + if self.bootstrap_server: + logger.info( + f"Bootstrap server started for disaggregation mode: " + f"{server_args.disaggregation_mode}" + ) async def generate_request( self, @@ -759,6 +768,22 @@ class GrpcRequestManager: state.finished = True state.event.set() + # Wait for tasks to complete + if self.asyncio_tasks: + await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True) + + # Shutdown bootstrap server if running + if self.bootstrap_server: + logger.info("Shutting down bootstrap server") + try: + if hasattr(self.bootstrap_server, "shutdown"): + if asyncio.iscoroutinefunction(self.bootstrap_server.shutdown): + await self.bootstrap_server.shutdown() + else: + self.bootstrap_server.shutdown() + except Exception as e: + logger.warning(f"Error shutting down bootstrap server: {e}") + # Close ZMQ sockets self.recv_from_scheduler.close() self.send_to_scheduler.close() diff --git a/python/sglang/srt/entrypoints/grpc_server.py b/python/sglang/srt/entrypoints/grpc_server.py index 232461893..55712c177 100644 --- a/python/sglang/srt/entrypoints/grpc_server.py +++ b/python/sglang/srt/entrypoints/grpc_server.py @@ -793,6 +793,28 @@ def main(): # Logging parser.add_argument("--log-level", type=str, default="INFO", help="Logging level") + # Disaggregation mode arguments + parser.add_argument( + "--disaggregation-mode", + type=str, + default="null", + choices=["null", "prefill", "decode"], + help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated', + ) + parser.add_argument( + "--disaggregation-transfer-backend", + type=str, + default="mooncake", + choices=["mooncake", "nixl", "ascend", "fake"], + help="The backend for disaggregation transfer. Default is mooncake.", + ) + parser.add_argument( + "--disaggregation-bootstrap-port", + type=int, + default=8998, + help="Bootstrap server port on the prefill server. Default is 8998.", + ) + args = parser.parse_args() # Convert to ServerArgs with gRPC host and port @@ -808,7 +830,9 @@ def main(): attention_backend=args.attention_backend, lora_paths=args.lora_paths.split(",") if args.lora_paths else None, log_level=args.log_level, - # Override with gRPC server host and port + disaggregation_mode=args.disaggregation_mode, + disaggregation_transfer_backend=args.disaggregation_transfer_backend, + disaggregation_bootstrap_port=args.disaggregation_bootstrap_port, host=args.host, port=args.port, )