[router][grpc] Add serve_grpc to launch_server and log id for HealthCheck (#11564)
This commit is contained in:
@@ -1,9 +1,9 @@
|
|||||||
"""Launch the inference server."""
|
"""Launch the inference server."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from sglang.srt.entrypoints.http_server import launch_server
|
|
||||||
from sglang.srt.server_args import prepare_server_args
|
from sglang.srt.server_args import prepare_server_args
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
|
||||||
@@ -11,6 +11,13 @@ if __name__ == "__main__":
|
|||||||
server_args = prepare_server_args(sys.argv[1:])
|
server_args = prepare_server_args(sys.argv[1:])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
launch_server(server_args)
|
if server_args.grpc_mode:
|
||||||
|
from sglang.srt.entrypoints.grpc_server import serve_grpc
|
||||||
|
|
||||||
|
asyncio.run(serve_grpc(server_args))
|
||||||
|
else:
|
||||||
|
from sglang.srt.entrypoints.http_server import launch_server
|
||||||
|
|
||||||
|
launch_server(server_args)
|
||||||
finally:
|
finally:
|
||||||
kill_process_tree(os.getpid(), include_parent=False)
|
kill_process_tree(os.getpid(), include_parent=False)
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ from grpc_reflection.v1alpha import reflection
|
|||||||
|
|
||||||
import sglang
|
import sglang
|
||||||
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
|
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
|
||||||
from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
|
|
||||||
from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
|
from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
|
||||||
|
from sglang.srt.grpc.grpc_request_manager import GrpcRequestManager
|
||||||
from sglang.srt.managers.data_parallel_controller import (
|
from sglang.srt.managers.data_parallel_controller import (
|
||||||
run_data_parallel_controller_process,
|
run_data_parallel_controller_process,
|
||||||
)
|
)
|
||||||
@@ -68,6 +68,8 @@ def _launch_scheduler_process_only(
|
|||||||
# Configure global environment
|
# Configure global environment
|
||||||
configure_logger(server_args)
|
configure_logger(server_args)
|
||||||
server_args.check_server_args()
|
server_args.check_server_args()
|
||||||
|
# Fix CUDA multiprocessing issues - must be called before any CUDA operations
|
||||||
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
# Allocate ports for inter-process communications
|
# Allocate ports for inter-process communications
|
||||||
if port_args is None:
|
if port_args is None:
|
||||||
@@ -317,7 +319,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
Check the health of the inference server by sending a special request to generate one token.
|
Check the health of the inference server by sending a special request to generate one token.
|
||||||
Similar to HTTP server's /health endpoint.
|
Similar to HTTP server's /health endpoint.
|
||||||
"""
|
"""
|
||||||
logger.info("Receive health check request")
|
rid = f"HEALTH_CHECK_{time.time()}"
|
||||||
|
logger.info(f"Receive health check request: {rid}")
|
||||||
|
|
||||||
if self.request_manager.gracefully_exit:
|
if self.request_manager.gracefully_exit:
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -328,7 +331,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create a special health check request
|
# Create a special health check request
|
||||||
rid = f"HEALTH_CHECK_{time.time()}"
|
|
||||||
sampling_params = SGLSamplingParams(max_new_tokens=1, temperature=0.0)
|
sampling_params = SGLSamplingParams(max_new_tokens=1, temperature=0.0)
|
||||||
sampling_params.normalize(tokenizer=None)
|
sampling_params.normalize(tokenizer=None)
|
||||||
|
|
||||||
@@ -919,25 +921,3 @@ async def serve_grpc(
|
|||||||
proc.join(timeout=1.0)
|
proc.join(timeout=1.0)
|
||||||
|
|
||||||
logger.info("All scheduler processes terminated")
|
logger.info("All scheduler processes terminated")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Main entry point for standalone gRPC server."""
|
|
||||||
# Fix CUDA multiprocessing issues - must be called before any CUDA operations
|
|
||||||
mp.set_start_method("spawn", force=True)
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="SGLang Standalone gRPC Server")
|
|
||||||
ServerArgs.add_cli_args(parser)
|
|
||||||
args = parser.parse_args()
|
|
||||||
server_args = ServerArgs.from_cli_args(args)
|
|
||||||
|
|
||||||
# Run server
|
|
||||||
asyncio.run(
|
|
||||||
serve_grpc(
|
|
||||||
server_args=server_args,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|||||||
@@ -326,10 +326,7 @@ message EmbedError {
|
|||||||
// Management Operations
|
// Management Operations
|
||||||
// =====================
|
// =====================
|
||||||
|
|
||||||
message HealthCheckRequest {
|
message HealthCheckRequest {}
|
||||||
// Input for health test generation (must be tokenized)
|
|
||||||
TokenizedInput tokenized = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
message HealthCheckResponse {
|
message HealthCheckResponse {
|
||||||
bool healthy = 1;
|
bool healthy = 1;
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -320,10 +320,8 @@ class EmbedError(_message.Message):
|
|||||||
def __init__(self, message: _Optional[str] = ..., code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
|
def __init__(self, message: _Optional[str] = ..., code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
|
||||||
|
|
||||||
class HealthCheckRequest(_message.Message):
|
class HealthCheckRequest(_message.Message):
|
||||||
__slots__ = ("tokenized",)
|
__slots__ = ()
|
||||||
TOKENIZED_FIELD_NUMBER: _ClassVar[int]
|
def __init__(self) -> None: ...
|
||||||
tokenized: TokenizedInput
|
|
||||||
def __init__(self, tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ...) -> None: ...
|
|
||||||
|
|
||||||
class HealthCheckResponse(_message.Message):
|
class HealthCheckResponse(_message.Message):
|
||||||
__slots__ = ("healthy", "message")
|
__slots__ = ("healthy", "message")
|
||||||
|
|||||||
@@ -194,6 +194,7 @@ class ServerArgs:
|
|||||||
# HTTP server
|
# HTTP server
|
||||||
host: str = "127.0.0.1"
|
host: str = "127.0.0.1"
|
||||||
port: int = 30000
|
port: int = 30000
|
||||||
|
grpc_mode: bool = False
|
||||||
skip_server_warmup: bool = False
|
skip_server_warmup: bool = False
|
||||||
warmups: Optional[str] = None
|
warmups: Optional[str] = None
|
||||||
nccl_port: Optional[int] = None
|
nccl_port: Optional[int] = None
|
||||||
@@ -1516,6 +1517,11 @@ class ServerArgs:
|
|||||||
default=ServerArgs.port,
|
default=ServerArgs.port,
|
||||||
help="The port of the HTTP server.",
|
help="The port of the HTTP server.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--grpc-mode",
|
||||||
|
action="store_true",
|
||||||
|
help="If set, use gRPC server instead of HTTP server.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--skip-server-warmup",
|
"--skip-server-warmup",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
@@ -169,8 +169,8 @@ impl SglangSchedulerClient {
|
|||||||
&self,
|
&self,
|
||||||
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
debug!("Sending health check request");
|
debug!("Sending health check request");
|
||||||
// Server ignores the request body and creates its own health check internally
|
// HealthCheckRequest is now empty - server generates its own health check internally
|
||||||
let request = Request::new(proto::HealthCheckRequest { tokenized: None });
|
let request = Request::new(proto::HealthCheckRequest {});
|
||||||
|
|
||||||
let mut client = self.client.clone();
|
let mut client = self.client.clone();
|
||||||
let response = client.health_check(request).await?;
|
let response = client.health_check(request).await?;
|
||||||
@@ -510,13 +510,8 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_proto_types_compilation() {
|
fn test_proto_types_compilation() {
|
||||||
let health_req = proto::HealthCheckRequest {
|
let _health_req = proto::HealthCheckRequest {};
|
||||||
tokenized: Some(proto::TokenizedInput {
|
// HealthCheckRequest is now empty - no fields to test
|
||||||
original_text: "test".to_string(),
|
|
||||||
input_ids: vec![1296],
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
assert!(health_req.tokenized.is_some());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -558,13 +553,8 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_health_check_request() {
|
fn test_health_check_request() {
|
||||||
let health_req = proto::HealthCheckRequest {
|
let _health_req = proto::HealthCheckRequest {};
|
||||||
tokenized: Some(proto::TokenizedInput {
|
// HealthCheckRequest is now empty - server generates its own test internally
|
||||||
original_text: "test".to_string(),
|
|
||||||
input_ids: vec![1296], // Mock token ID for "test"
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
assert!(health_req.tokenized.is_some());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -326,10 +326,7 @@ message EmbedError {
|
|||||||
// Management Operations
|
// Management Operations
|
||||||
// =====================
|
// =====================
|
||||||
|
|
||||||
message HealthCheckRequest {
|
message HealthCheckRequest {}
|
||||||
// Input for health test generation (must be tokenized)
|
|
||||||
TokenizedInput tokenized = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
message HealthCheckResponse {
|
message HealthCheckResponse {
|
||||||
bool healthy = 1;
|
bool healthy = 1;
|
||||||
|
|||||||
Reference in New Issue
Block a user