[router] add PD (prefill/decode) disaggregation bootstrap service to the gRPC router (#11120)
This commit is contained in:
@@ -19,6 +19,7 @@ import grpc
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from sglang.srt.managers.disagg_service import start_disagg_service
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BatchEmbeddingOut,
|
||||
@@ -146,11 +147,19 @@ class GrpcRequestManager:
|
||||
self.crash_dump_request_list = []
|
||||
self.crash_dump_performed = False
|
||||
|
||||
# Bootstrap server for disaggregation mode
|
||||
self.bootstrap_server = start_disagg_service(server_args)
|
||||
|
||||
logger.info(
|
||||
f"GrpcRequestManager initialized with ZMQ IPC: "
|
||||
f"recv={port_args.detokenizer_ipc_name}, "
|
||||
f"send={port_args.scheduler_input_ipc_name}"
|
||||
)
|
||||
if self.bootstrap_server:
|
||||
logger.info(
|
||||
f"Bootstrap server started for disaggregation mode: "
|
||||
f"{server_args.disaggregation_mode}"
|
||||
)
|
||||
|
||||
async def generate_request(
|
||||
self,
|
||||
@@ -759,6 +768,22 @@ class GrpcRequestManager:
|
||||
state.finished = True
|
||||
state.event.set()
|
||||
|
||||
# Wait for tasks to complete
|
||||
if self.asyncio_tasks:
|
||||
await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)
|
||||
|
||||
# Shutdown bootstrap server if running
|
||||
if self.bootstrap_server:
|
||||
logger.info("Shutting down bootstrap server")
|
||||
try:
|
||||
if hasattr(self.bootstrap_server, "shutdown"):
|
||||
if asyncio.iscoroutinefunction(self.bootstrap_server.shutdown):
|
||||
await self.bootstrap_server.shutdown()
|
||||
else:
|
||||
self.bootstrap_server.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error shutting down bootstrap server: {e}")
|
||||
|
||||
# Close ZMQ sockets
|
||||
self.recv_from_scheduler.close()
|
||||
self.send_to_scheduler.close()
|
||||
|
||||
@@ -793,6 +793,28 @@ def main():
|
||||
# Logging
|
||||
parser.add_argument("--log-level", type=str, default="INFO", help="Logging level")
|
||||
|
||||
# Disaggregation mode arguments
|
||||
parser.add_argument(
|
||||
"--disaggregation-mode",
|
||||
type=str,
|
||||
default="null",
|
||||
choices=["null", "prefill", "decode"],
|
||||
help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disaggregation-transfer-backend",
|
||||
type=str,
|
||||
default="mooncake",
|
||||
choices=["mooncake", "nixl", "ascend", "fake"],
|
||||
help="The backend for disaggregation transfer. Default is mooncake.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disaggregation-bootstrap-port",
|
||||
type=int,
|
||||
default=8998,
|
||||
help="Bootstrap server port on the prefill server. Default is 8998.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Convert to ServerArgs with gRPC host and port
|
||||
@@ -808,7 +830,9 @@ def main():
|
||||
attention_backend=args.attention_backend,
|
||||
lora_paths=args.lora_paths.split(",") if args.lora_paths else None,
|
||||
log_level=args.log_level,
|
||||
# Override with gRPC server host and port
|
||||
disaggregation_mode=args.disaggregation_mode,
|
||||
disaggregation_transfer_backend=args.disaggregation_transfer_backend,
|
||||
disaggregation_bootstrap_port=args.disaggregation_bootstrap_port,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user