From 325951460ff297145f7e86eb2dc6a1700a301dde Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Tue, 14 Oct 2025 19:11:16 -0400 Subject: [PATCH] [router][grpc] add warm up to grpc server (#11627) Co-authored-by: Chang Su --- python/sglang/srt/entrypoints/grpc_server.py | 175 ++++++++++++++++++- 1 file changed, 173 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/entrypoints/grpc_server.py b/python/sglang/srt/entrypoints/grpc_server.py index fa47f3860..9cec138f4 100644 --- a/python/sglang/srt/entrypoints/grpc_server.py +++ b/python/sglang/srt/entrypoints/grpc_server.py @@ -3,13 +3,13 @@ Standalone gRPC Server for SGLang - Fully separated from HTTP server. Uses GrpcRequestManager for orchestration without tokenization. """ -import argparse import asyncio import dataclasses import logging import multiprocessing as mp import os import signal +import threading import time from concurrent import futures from typing import AsyncIterator, Dict, Optional, Tuple @@ -35,7 +35,11 @@ from sglang.srt.managers.io_struct import ( from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import configure_logger, prepare_model_and_tokenizer +from sglang.srt.utils import ( + configure_logger, + kill_process_tree, + prepare_model_and_tokenizer, +) from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.utils import get_exception_traceback @@ -884,6 +888,13 @@ async def serve_grpc( await server.start() logger.info(f"gRPC server listening on {listen_addr}") + # Start warmup in a separate thread + warmup_thread = threading.Thread( + target=_wait_and_warmup_grpc, + args=(server_args, None), + ) + warmup_thread.start() + # Handle shutdown signals loop = asyncio.get_running_loop() stop_event = asyncio.Event() @@ -906,6 +917,11 @@ async def serve_grpc( # Stop the gRPC server await server.stop(5.0)
+ # Wait for warmup thread to finish + if warmup_thread.is_alive(): + logger.info("Waiting for warmup thread to finish...") + warmup_thread.join(timeout=5.0) + # Terminate scheduler processes before exiting to avoid atexit hang # The scheduler processes have SIGINT ignored, so they won't get KeyboardInterrupt for i, proc in enumerate(scheduler_procs): @@ -921,3 +937,158 @@ async def serve_grpc( proc.join(timeout=1.0) logger.info("All scheduler processes terminated") + + +def _execute_grpc_server_warmup( + server_args: ServerArgs, + pipe_finish_writer: Optional[mp.connection.Connection], +): + """Execute warmup for gRPC server by checking health and sending test request.""" + try: + # Connect to the gRPC server + grpc_url = f"{server_args.host}:{server_args.port}" + channel = grpc.insecure_channel( + grpc_url, + options=[ + ("grpc.max_send_message_length", 1024 * 1024 * 256), + ("grpc.max_receive_message_length", 1024 * 1024 * 256), + ], + ) + stub = sglang_scheduler_pb2_grpc.SglangSchedulerStub(channel) + + # Wait until the server is launched (poll GetModelInfo) + success = False + last_error = None + for _ in range(120): + time.sleep(1) + try: + request = sglang_scheduler_pb2.GetModelInfoRequest() + response = stub.GetModelInfo(request, timeout=5) + success = True + break + except Exception as e: + last_error = str(e) + pass + + if not success: + error_msg = f"gRPC server warmup failed: Could not connect to server after 120 seconds. Last error: {last_error}"
+ logger.error(error_msg) + if pipe_finish_writer is not None: + pipe_finish_writer.send(error_msg) + channel.close() + kill_process_tree(os.getpid()) + return False + + # Get model info to determine if it's generation or embedding + is_generation = response.is_generation + + # Send a warmup request + logger.info("Sending warmup request to gRPC server...") + max_new_tokens = 8 if is_generation else 1 + + if is_generation: + # Create tokenized input for warmup + warmup_request = sglang_scheduler_pb2.GenerateRequest( + request_id=f"WARMUP_{time.time()}", + tokenized=sglang_scheduler_pb2.TokenizedInput( + input_ids=[ + 954, + 15541, + 2181, + 23496, + 1476, + 64710, + 280, + ], # Simple token sequence + original_text="The capital city of France is", + ), + sampling_params=sglang_scheduler_pb2.SamplingParams( + temperature=0.0, + max_new_tokens=max_new_tokens, + ), + stream=False, + ) + + # Send the warmup request + try: + responses = list(stub.Generate(warmup_request, timeout=600)) + # Check if we got a valid response + if responses and not responses[-1].HasField("error"): + logger.info("gRPC warmup request completed successfully") + success = True + else: + error_msg = ( + responses[-1].error.message if responses else "No response" + ) + logger.warning(f"gRPC warmup request returned error: {error_msg}") + success = False + except Exception as e: + error_msg = f"gRPC warmup request failed: {e}" + logger.error(error_msg) + if pipe_finish_writer is not None: + pipe_finish_writer.send(error_msg) + channel.close() + kill_process_tree(os.getpid()) + return False + else: + # For embedding models + warmup_request = sglang_scheduler_pb2.EmbedRequest( + request_id=f"WARMUP_{time.time()}", + tokenized=sglang_scheduler_pb2.TokenizedInput( + input_ids=[10, 11, 12], + original_text="test embedding", + ), + ) + + try: + response = stub.Embed(warmup_request, timeout=600) + if not response.HasField("error"): + logger.info("gRPC warmup request completed successfully")
+ success = True + else: + logger.warning( + f"gRPC warmup request returned error: {response.error.message}" + ) + success = False + except Exception as e: + error_msg = f"gRPC warmup request failed: {e}" + logger.error(error_msg) + if pipe_finish_writer is not None: + pipe_finish_writer.send(error_msg) + channel.close() + kill_process_tree(os.getpid()) + return False + + channel.close() + return success + + except Exception as e: + error_msg = ( + f"gRPC warmup failed with exception: {e}\n{get_exception_traceback()}" + ) + logger.error(error_msg) + if pipe_finish_writer is not None: + pipe_finish_writer.send(error_msg) + try: + channel.close() + except Exception: + pass + kill_process_tree(os.getpid()) + return False + + +def _wait_and_warmup_grpc( + server_args: ServerArgs, + pipe_finish_writer: Optional[mp.connection.Connection], +): + """Wait for gRPC server to be ready and execute warmup.""" + if not server_args.skip_server_warmup: + if not _execute_grpc_server_warmup(server_args, pipe_finish_writer): + return + else: + logger.info("Skipping gRPC server warmup (skip_server_warmup=True)") + + logger.info("The server is fired up and ready to roll!") + + if pipe_finish_writer is not None: + pipe_finish_writer.send("ready")