[router] add grpc router pd mode for chat and generate (#11140)

This commit is contained in:
Simo Lin
2025-10-04 09:58:28 -04:00
committed by GitHub
parent ffd03a9bd3
commit d736e0b65e
11 changed files with 3169 additions and 1080 deletions

View File

@@ -67,8 +67,8 @@ dependencies = [
"uvicorn",
"uvloop",
"xgrammar==0.1.24",
"grpcio==1.74.0", # keep it align with compile_proto.py
"grpcio-tools==1.74.0" # keep it align with compile_proto.py
"grpcio==1.75.1", # keep it align with compile_proto.py
"grpcio-tools==1.75.1" # keep it align with compile_proto.py
]
[project.optional-dependencies]

View File

@@ -19,7 +19,6 @@ import grpc
import zmq
import zmq.asyncio
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOutput,
@@ -111,6 +110,7 @@ class GrpcRequestManager:
self,
server_args: ServerArgs,
port_args: PortArgs,
bootstrap_server=None,
):
"""Initialize the gRPC request manager."""
self.server_args = server_args
@@ -147,8 +147,8 @@ class GrpcRequestManager:
self.crash_dump_request_list = []
self.crash_dump_performed = False
# Bootstrap server for disaggregation mode
self.bootstrap_server = start_disagg_service(server_args)
# Bootstrap server (passed from serve_grpc, not started here)
self.bootstrap_server = bootstrap_server
logger.info(
f"GrpcRequestManager initialized with ZMQ IPC: "
@@ -157,7 +157,7 @@ class GrpcRequestManager:
)
if self.bootstrap_server:
logger.info(
f"Bootstrap server started for disaggregation mode: "
f"Bootstrap server initialized for disaggregation mode: "
f"{server_args.disaggregation_mode}"
)

View File

@@ -16,11 +16,13 @@ from typing import AsyncIterator, Dict, Optional, Tuple
import grpc
from grpc_reflection.v1alpha import reflection
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
from sglang.srt.managers.data_parallel_controller import (
run_data_parallel_controller_process,
)
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
@@ -331,6 +333,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
token_ids_logprob=None,
)
if self.server_args.disaggregation_mode != DisaggregationMode.NULL:
health_request.bootstrap_host = FAKE_BOOTSTRAP_HOST
health_request.bootstrap_room = 0
logger.info(f"Sending health check request to request manager...")
# Submit and wait for response
@@ -406,6 +412,15 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
# Convert sampling params
sampling_params = self._convert_sampling_params(grpc_req.sampling_params)
# Extract disaggregated params if present
bootstrap_host = None
bootstrap_port = None
bootstrap_room = None
if grpc_req.HasField("disaggregated_params"):
bootstrap_host = grpc_req.disaggregated_params.bootstrap_host or None
bootstrap_port = grpc_req.disaggregated_params.bootstrap_port or None
bootstrap_room = grpc_req.disaggregated_params.bootstrap_room or None
# Create request
return TokenizedGenerateReqInput(
rid=grpc_req.request_id,
@@ -425,6 +440,9 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
token_ids_logprob=(
list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
),
bootstrap_host=bootstrap_host,
bootstrap_port=bootstrap_port,
bootstrap_room=bootstrap_room,
)
def _convert_embed_request(
@@ -659,6 +677,16 @@ async def serve_grpc(
):
"""Start the standalone gRPC server with integrated scheduler."""
# Start bootstrap server BEFORE launching scheduler processes (only in PREFILL mode)
# This ensures the bootstrap server is ready when prefill schedulers try to register
bootstrap_server = None
if server_args.disaggregation_mode == "prefill":
bootstrap_server = start_disagg_service(server_args)
if bootstrap_server:
logger.info(
f"Bootstrap server started for disaggregation mode on {server_args.host}:{server_args.disaggregation_bootstrap_port}"
)
# Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
logger.info("Launching scheduler process(es)...")
scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only(
@@ -682,9 +710,11 @@ async def serve_grpc(
}
# Create request manager with the correct port args
# Note: We pass None for bootstrap_server since it's already started above
request_manager = GrpcRequestManager(
server_args=server_args,
port_args=port_args,
bootstrap_server=bootstrap_server,
)
# Create gRPC server
@@ -764,79 +794,9 @@ def main():
mp.set_start_method("spawn", force=True)
parser = argparse.ArgumentParser(description="SGLang Standalone gRPC Server")
# Server arguments
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
parser.add_argument("--port", type=int, default=30000, help="gRPC server port")
# Model arguments
parser.add_argument("--model-path", type=str, required=True, help="Model path")
parser.add_argument("--tokenizer-path", type=str, help="Tokenizer path")
parser.add_argument("--context-length", type=int, help="Context length")
parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
parser.add_argument("--dp-size", type=int, default=1, help="Data parallel size")
# Runtime arguments
parser.add_argument(
"--max-running-requests", type=int, default=2048, help="Max concurrent requests"
)
parser.add_argument(
"--max-total-tokens", type=int, default=1000000, help="Max total tokens"
)
parser.add_argument(
"--max-prefill-tokens", type=int, default=16384, help="Max prefill tokens"
)
parser.add_argument(
"--attention-backend", type=str, default="flashinfer", help="Attention backend"
)
parser.add_argument("--lora-paths", type=str, help="LoRA adapter paths")
# Logging
parser.add_argument("--log-level", type=str, default="INFO", help="Logging level")
# Disaggregation mode arguments
parser.add_argument(
"--disaggregation-mode",
type=str,
default="null",
choices=["null", "prefill", "decode"],
help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
)
parser.add_argument(
"--disaggregation-transfer-backend",
type=str,
default="mooncake",
choices=["mooncake", "nixl", "ascend", "fake"],
help="The backend for disaggregation transfer. Default is mooncake.",
)
parser.add_argument(
"--disaggregation-bootstrap-port",
type=int,
default=8998,
help="Bootstrap server port on the prefill server. Default is 8998.",
)
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
# Convert to ServerArgs with gRPC host and port
server_args = ServerArgs(
model_path=args.model_path,
tokenizer_path=args.tokenizer_path or args.model_path,
context_length=args.context_length,
tp_size=args.tp_size,
dp_size=args.dp_size,
max_running_requests=args.max_running_requests,
max_total_tokens=args.max_total_tokens,
max_prefill_tokens=args.max_prefill_tokens,
attention_backend=args.attention_backend,
lora_paths=args.lora_paths.split(",") if args.lora_paths else None,
log_level=args.log_level,
disaggregation_mode=args.disaggregation_mode,
disaggregation_transfer_backend=args.disaggregation_transfer_backend,
disaggregation_bootstrap_port=args.disaggregation_bootstrap_port,
host=args.host,
port=args.port,
)
server_args = ServerArgs.from_cli_args(args)
# Run server
asyncio.run(