[router] add pd service in grpc router for pd (#11120)
This commit is contained in:
@@ -19,6 +19,7 @@ import grpc
|
|||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
|
|
||||||
|
from sglang.srt.managers.disagg_service import start_disagg_service
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
AbortReq,
|
AbortReq,
|
||||||
BatchEmbeddingOut,
|
BatchEmbeddingOut,
|
||||||
@@ -146,11 +147,19 @@ class GrpcRequestManager:
|
|||||||
self.crash_dump_request_list = []
|
self.crash_dump_request_list = []
|
||||||
self.crash_dump_performed = False
|
self.crash_dump_performed = False
|
||||||
|
|
||||||
|
# Bootstrap server for disaggregation mode
|
||||||
|
self.bootstrap_server = start_disagg_service(server_args)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"GrpcRequestManager initialized with ZMQ IPC: "
|
f"GrpcRequestManager initialized with ZMQ IPC: "
|
||||||
f"recv={port_args.detokenizer_ipc_name}, "
|
f"recv={port_args.detokenizer_ipc_name}, "
|
||||||
f"send={port_args.scheduler_input_ipc_name}"
|
f"send={port_args.scheduler_input_ipc_name}"
|
||||||
)
|
)
|
||||||
|
if self.bootstrap_server:
|
||||||
|
logger.info(
|
||||||
|
f"Bootstrap server started for disaggregation mode: "
|
||||||
|
f"{server_args.disaggregation_mode}"
|
||||||
|
)
|
||||||
|
|
||||||
async def generate_request(
|
async def generate_request(
|
||||||
self,
|
self,
|
||||||
@@ -759,6 +768,22 @@ class GrpcRequestManager:
|
|||||||
state.finished = True
|
state.finished = True
|
||||||
state.event.set()
|
state.event.set()
|
||||||
|
|
||||||
|
# Wait for tasks to complete
|
||||||
|
if self.asyncio_tasks:
|
||||||
|
await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)
|
||||||
|
|
||||||
|
# Shutdown bootstrap server if running
|
||||||
|
if self.bootstrap_server:
|
||||||
|
logger.info("Shutting down bootstrap server")
|
||||||
|
try:
|
||||||
|
if hasattr(self.bootstrap_server, "shutdown"):
|
||||||
|
if asyncio.iscoroutinefunction(self.bootstrap_server.shutdown):
|
||||||
|
await self.bootstrap_server.shutdown()
|
||||||
|
else:
|
||||||
|
self.bootstrap_server.shutdown()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error shutting down bootstrap server: {e}")
|
||||||
|
|
||||||
# Close ZMQ sockets
|
# Close ZMQ sockets
|
||||||
self.recv_from_scheduler.close()
|
self.recv_from_scheduler.close()
|
||||||
self.send_to_scheduler.close()
|
self.send_to_scheduler.close()
|
||||||
|
|||||||
@@ -793,6 +793,28 @@ def main():
|
|||||||
# Logging
|
# Logging
|
||||||
parser.add_argument("--log-level", type=str, default="INFO", help="Logging level")
|
parser.add_argument("--log-level", type=str, default="INFO", help="Logging level")
|
||||||
|
|
||||||
|
# Disaggregation mode arguments
|
||||||
|
parser.add_argument(
|
||||||
|
"--disaggregation-mode",
|
||||||
|
type=str,
|
||||||
|
default="null",
|
||||||
|
choices=["null", "prefill", "decode"],
|
||||||
|
help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disaggregation-transfer-backend",
|
||||||
|
type=str,
|
||||||
|
default="mooncake",
|
||||||
|
choices=["mooncake", "nixl", "ascend", "fake"],
|
||||||
|
help="The backend for disaggregation transfer. Default is mooncake.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disaggregation-bootstrap-port",
|
||||||
|
type=int,
|
||||||
|
default=8998,
|
||||||
|
help="Bootstrap server port on the prefill server. Default is 8998.",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Convert to ServerArgs with gRPC host and port
|
# Convert to ServerArgs with gRPC host and port
|
||||||
@@ -808,7 +830,9 @@ def main():
|
|||||||
attention_backend=args.attention_backend,
|
attention_backend=args.attention_backend,
|
||||||
lora_paths=args.lora_paths.split(",") if args.lora_paths else None,
|
lora_paths=args.lora_paths.split(",") if args.lora_paths else None,
|
||||||
log_level=args.log_level,
|
log_level=args.log_level,
|
||||||
# Override with gRPC server host and port
|
disaggregation_mode=args.disaggregation_mode,
|
||||||
|
disaggregation_transfer_backend=args.disaggregation_transfer_backend,
|
||||||
|
disaggregation_bootstrap_port=args.disaggregation_bootstrap_port,
|
||||||
host=args.host,
|
host=args.host,
|
||||||
port=args.port,
|
port=args.port,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user