[router] add grpc router pd mode for chat and generate (#11140)
This commit is contained in:
@@ -67,8 +67,8 @@ dependencies = [
|
||||
"uvicorn",
|
||||
"uvloop",
|
||||
"xgrammar==0.1.24",
|
||||
"grpcio==1.74.0", # keep it align with compile_proto.py
|
||||
"grpcio-tools==1.74.0" # keep it align with compile_proto.py
|
||||
"grpcio==1.75.1", # keep it align with compile_proto.py
|
||||
"grpcio-tools==1.75.1" # keep it align with compile_proto.py
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@@ -19,7 +19,6 @@ import grpc
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from sglang.srt.managers.disagg_service import start_disagg_service
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BatchEmbeddingOutput,
|
||||
@@ -111,6 +110,7 @@ class GrpcRequestManager:
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
port_args: PortArgs,
|
||||
bootstrap_server=None,
|
||||
):
|
||||
"""Initialize the gRPC request manager."""
|
||||
self.server_args = server_args
|
||||
@@ -147,8 +147,8 @@ class GrpcRequestManager:
|
||||
self.crash_dump_request_list = []
|
||||
self.crash_dump_performed = False
|
||||
|
||||
# Bootstrap server for disaggregation mode
|
||||
self.bootstrap_server = start_disagg_service(server_args)
|
||||
# Bootstrap server (passed from serve_grpc, not started here)
|
||||
self.bootstrap_server = bootstrap_server
|
||||
|
||||
logger.info(
|
||||
f"GrpcRequestManager initialized with ZMQ IPC: "
|
||||
@@ -157,7 +157,7 @@ class GrpcRequestManager:
|
||||
)
|
||||
if self.bootstrap_server:
|
||||
logger.info(
|
||||
f"Bootstrap server started for disaggregation mode: "
|
||||
f"Bootstrap server initialized for disaggregation mode: "
|
||||
f"{server_args.disaggregation_mode}"
|
||||
)
|
||||
|
||||
|
||||
@@ -16,11 +16,13 @@ from typing import AsyncIterator, Dict, Optional, Tuple
|
||||
import grpc
|
||||
from grpc_reflection.v1alpha import reflection
|
||||
|
||||
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
|
||||
from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
|
||||
from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
|
||||
from sglang.srt.managers.data_parallel_controller import (
|
||||
run_data_parallel_controller_process,
|
||||
)
|
||||
from sglang.srt.managers.disagg_service import start_disagg_service
|
||||
from sglang.srt.managers.io_struct import (
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
@@ -331,6 +333,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
token_ids_logprob=None,
|
||||
)
|
||||
|
||||
if self.server_args.disaggregation_mode != DisaggregationMode.NULL:
|
||||
health_request.bootstrap_host = FAKE_BOOTSTRAP_HOST
|
||||
health_request.bootstrap_room = 0
|
||||
|
||||
logger.info(f"Sending health check request to request manager...")
|
||||
|
||||
# Submit and wait for response
|
||||
@@ -406,6 +412,15 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
# Convert sampling params
|
||||
sampling_params = self._convert_sampling_params(grpc_req.sampling_params)
|
||||
|
||||
# Extract disaggregated params if present
|
||||
bootstrap_host = None
|
||||
bootstrap_port = None
|
||||
bootstrap_room = None
|
||||
if grpc_req.HasField("disaggregated_params"):
|
||||
bootstrap_host = grpc_req.disaggregated_params.bootstrap_host or None
|
||||
bootstrap_port = grpc_req.disaggregated_params.bootstrap_port or None
|
||||
bootstrap_room = grpc_req.disaggregated_params.bootstrap_room or None
|
||||
|
||||
# Create request
|
||||
return TokenizedGenerateReqInput(
|
||||
rid=grpc_req.request_id,
|
||||
@@ -425,6 +440,9 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
token_ids_logprob=(
|
||||
list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
|
||||
),
|
||||
bootstrap_host=bootstrap_host,
|
||||
bootstrap_port=bootstrap_port,
|
||||
bootstrap_room=bootstrap_room,
|
||||
)
|
||||
|
||||
def _convert_embed_request(
|
||||
@@ -659,6 +677,16 @@ async def serve_grpc(
|
||||
):
|
||||
"""Start the standalone gRPC server with integrated scheduler."""
|
||||
|
||||
# Start bootstrap server BEFORE launching scheduler processes (only in PREFILL mode)
|
||||
# This ensures the bootstrap server is ready when prefill schedulers try to register
|
||||
bootstrap_server = None
|
||||
if server_args.disaggregation_mode == "prefill":
|
||||
bootstrap_server = start_disagg_service(server_args)
|
||||
if bootstrap_server:
|
||||
logger.info(
|
||||
f"Bootstrap server started for disaggregation mode on {server_args.host}:{server_args.disaggregation_bootstrap_port}"
|
||||
)
|
||||
|
||||
# Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
|
||||
logger.info("Launching scheduler process(es)...")
|
||||
scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only(
|
||||
@@ -682,9 +710,11 @@ async def serve_grpc(
|
||||
}
|
||||
|
||||
# Create request manager with the correct port args
|
||||
# Note: We pass None for bootstrap_server since it's already started above
|
||||
request_manager = GrpcRequestManager(
|
||||
server_args=server_args,
|
||||
port_args=port_args,
|
||||
bootstrap_server=bootstrap_server,
|
||||
)
|
||||
|
||||
# Create gRPC server
|
||||
@@ -764,79 +794,9 @@ def main():
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
parser = argparse.ArgumentParser(description="SGLang Standalone gRPC Server")
|
||||
|
||||
# Server arguments
|
||||
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
|
||||
parser.add_argument("--port", type=int, default=30000, help="gRPC server port")
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument("--model-path", type=str, required=True, help="Model path")
|
||||
parser.add_argument("--tokenizer-path", type=str, help="Tokenizer path")
|
||||
parser.add_argument("--context-length", type=int, help="Context length")
|
||||
parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
|
||||
parser.add_argument("--dp-size", type=int, default=1, help="Data parallel size")
|
||||
|
||||
# Runtime arguments
|
||||
parser.add_argument(
|
||||
"--max-running-requests", type=int, default=2048, help="Max concurrent requests"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-total-tokens", type=int, default=1000000, help="Max total tokens"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-prefill-tokens", type=int, default=16384, help="Max prefill tokens"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attention-backend", type=str, default="flashinfer", help="Attention backend"
|
||||
)
|
||||
parser.add_argument("--lora-paths", type=str, help="LoRA adapter paths")
|
||||
|
||||
# Logging
|
||||
parser.add_argument("--log-level", type=str, default="INFO", help="Logging level")
|
||||
|
||||
# Disaggregation mode arguments
|
||||
parser.add_argument(
|
||||
"--disaggregation-mode",
|
||||
type=str,
|
||||
default="null",
|
||||
choices=["null", "prefill", "decode"],
|
||||
help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disaggregation-transfer-backend",
|
||||
type=str,
|
||||
default="mooncake",
|
||||
choices=["mooncake", "nixl", "ascend", "fake"],
|
||||
help="The backend for disaggregation transfer. Default is mooncake.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disaggregation-bootstrap-port",
|
||||
type=int,
|
||||
default=8998,
|
||||
help="Bootstrap server port on the prefill server. Default is 8998.",
|
||||
)
|
||||
|
||||
ServerArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Convert to ServerArgs with gRPC host and port
|
||||
server_args = ServerArgs(
|
||||
model_path=args.model_path,
|
||||
tokenizer_path=args.tokenizer_path or args.model_path,
|
||||
context_length=args.context_length,
|
||||
tp_size=args.tp_size,
|
||||
dp_size=args.dp_size,
|
||||
max_running_requests=args.max_running_requests,
|
||||
max_total_tokens=args.max_total_tokens,
|
||||
max_prefill_tokens=args.max_prefill_tokens,
|
||||
attention_backend=args.attention_backend,
|
||||
lora_paths=args.lora_paths.split(",") if args.lora_paths else None,
|
||||
log_level=args.log_level,
|
||||
disaggregation_mode=args.disaggregation_mode,
|
||||
disaggregation_transfer_backend=args.disaggregation_transfer_backend,
|
||||
disaggregation_bootstrap_port=args.disaggregation_bootstrap_port,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
)
|
||||
server_args = ServerArgs.from_cli_args(args)
|
||||
|
||||
# Run server
|
||||
asyncio.run(
|
||||
|
||||
Reference in New Issue
Block a user