[router][grpc] add warm up to grpc server (#11627)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
@@ -3,13 +3,13 @@ Standalone gRPC Server for SGLang - Fully separated from HTTP server.
|
||||
Uses GrpcRequestManager for orchestration without tokenization.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
from concurrent import futures
|
||||
from typing import AsyncIterator, Dict, Optional, Tuple
|
||||
@@ -35,7 +35,11 @@ from sglang.srt.managers.io_struct import (
|
||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import configure_logger, prepare_model_and_tokenizer
|
||||
from sglang.srt.utils import (
|
||||
configure_logger,
|
||||
kill_process_tree,
|
||||
prepare_model_and_tokenizer,
|
||||
)
|
||||
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
@@ -884,6 +888,13 @@ async def serve_grpc(
|
||||
await server.start()
|
||||
logger.info(f"gRPC server listening on {listen_addr}")
|
||||
|
||||
# Start warmup in a separate thread
|
||||
warmup_thread = threading.Thread(
|
||||
target=_wait_and_warmup_grpc,
|
||||
args=(server_args, None),
|
||||
)
|
||||
warmup_thread.start()
|
||||
|
||||
# Handle shutdown signals
|
||||
loop = asyncio.get_running_loop()
|
||||
stop_event = asyncio.Event()
|
||||
@@ -906,6 +917,11 @@ async def serve_grpc(
|
||||
# Stop the gRPC server
|
||||
await server.stop(5.0)
|
||||
|
||||
# Wait for warmup thread to finish
|
||||
if warmup_thread.is_alive():
|
||||
logger.info("Waiting for warmup thread to finish...")
|
||||
warmup_thread.join(timeout=5.0)
|
||||
|
||||
# Terminate scheduler processes before exiting to avoid atexit hang
|
||||
# The scheduler processes have SIGINT ignored, so they won't get KeyboardInterrupt
|
||||
for i, proc in enumerate(scheduler_procs):
|
||||
@@ -921,3 +937,158 @@ async def serve_grpc(
|
||||
proc.join(timeout=1.0)
|
||||
|
||||
logger.info("All scheduler processes terminated")
|
||||
|
||||
|
||||
def _execute_grpc_server_warmup(
|
||||
server_args: ServerArgs,
|
||||
pipe_finish_writer: Optional[mp.connection.Connection],
|
||||
):
|
||||
"""Execute warmup for gRPC server by checking health and sending test request."""
|
||||
try:
|
||||
# Connect to the gRPC server
|
||||
grpc_url = f"{server_args.host}:{server_args.port}"
|
||||
channel = grpc.insecure_channel(
|
||||
grpc_url,
|
||||
options=[
|
||||
("grpc.max_send_message_length", 1024 * 1024 * 256),
|
||||
("grpc.max_receive_message_length", 1024 * 1024 * 256),
|
||||
],
|
||||
)
|
||||
stub = sglang_scheduler_pb2_grpc.SglangSchedulerStub(channel)
|
||||
|
||||
# Wait until the server is launched (poll GetModelInfo)
|
||||
success = False
|
||||
last_error = None
|
||||
for _ in range(120):
|
||||
time.sleep(1)
|
||||
try:
|
||||
request = sglang_scheduler_pb2.GetModelInfoRequest()
|
||||
response = stub.GetModelInfo(request, timeout=5)
|
||||
success = True
|
||||
break
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
pass
|
||||
|
||||
if not success:
|
||||
error_msg = f"gRPC server warmup failed: Could not connect to server after 120 seconds. Last error: {last_error}"
|
||||
logger.error(error_msg)
|
||||
if pipe_finish_writer is not None:
|
||||
pipe_finish_writer.send(error_msg)
|
||||
channel.close()
|
||||
kill_process_tree(os.getpid())
|
||||
return False
|
||||
|
||||
# Get model info to determine if it's generation or embedding
|
||||
is_generation = response.is_generation
|
||||
|
||||
# Send a warmup request
|
||||
logger.info("Sending warmup request to gRPC server...")
|
||||
max_new_tokens = 8 if is_generation else 1
|
||||
|
||||
if is_generation:
|
||||
# Create tokenized input for warmup
|
||||
warmup_request = sglang_scheduler_pb2.GenerateRequest(
|
||||
request_id=f"WARMUP_{time.time()}",
|
||||
tokenized=sglang_scheduler_pb2.TokenizedInput(
|
||||
input_ids=[
|
||||
954,
|
||||
15541,
|
||||
2181,
|
||||
23496,
|
||||
1476,
|
||||
64710,
|
||||
280,
|
||||
], # Simple token sequence
|
||||
original_text="The capital city of France is",
|
||||
),
|
||||
sampling_params=sglang_scheduler_pb2.SamplingParams(
|
||||
temperature=0.0,
|
||||
max_new_tokens=max_new_tokens,
|
||||
),
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Send the warmup request
|
||||
try:
|
||||
responses = list(stub.Generate(warmup_request, timeout=600))
|
||||
# Check if we got a valid response
|
||||
if responses and not responses[-1].HasField("error"):
|
||||
logger.info("gRPC warmup request completed successfully")
|
||||
success = True
|
||||
else:
|
||||
error_msg = (
|
||||
responses[-1].error.message if responses else "No response"
|
||||
)
|
||||
logger.warning(f"gRPC warmup request returned error: {error_msg}")
|
||||
success = False
|
||||
except Exception as e:
|
||||
error_msg = f"gRPC warmup request failed: {e}"
|
||||
logger.error(error_msg)
|
||||
if pipe_finish_writer is not None:
|
||||
pipe_finish_writer.send(error_msg)
|
||||
channel.close()
|
||||
kill_process_tree(os.getpid())
|
||||
return False
|
||||
else:
|
||||
# For embedding models
|
||||
warmup_request = sglang_scheduler_pb2.EmbedRequest(
|
||||
request_id=f"WARMUP_{time.time()}",
|
||||
tokenized=sglang_scheduler_pb2.TokenizedInput(
|
||||
input_ids=[10, 11, 12],
|
||||
original_text="test embedding",
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
response = stub.Embed(warmup_request, timeout=600)
|
||||
if not response.HasField("error"):
|
||||
logger.info("gRPC warmup request completed successfully")
|
||||
success = True
|
||||
else:
|
||||
logger.warning(
|
||||
f"gRPC warmup request returned error: {response.error.message}"
|
||||
)
|
||||
success = False
|
||||
except Exception as e:
|
||||
error_msg = f"gRPC warmup request failed: {e}"
|
||||
logger.error(error_msg)
|
||||
if pipe_finish_writer is not None:
|
||||
pipe_finish_writer.send(error_msg)
|
||||
channel.close()
|
||||
kill_process_tree(os.getpid())
|
||||
return False
|
||||
|
||||
channel.close()
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
error_msg = (
|
||||
f"gRPC warmup failed with exception: {e}\n{get_exception_traceback()}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
if pipe_finish_writer is not None:
|
||||
pipe_finish_writer.send(error_msg)
|
||||
try:
|
||||
channel.close()
|
||||
except Exception:
|
||||
pass
|
||||
kill_process_tree(os.getpid())
|
||||
return False
|
||||
|
||||
|
||||
def _wait_and_warmup_grpc(
|
||||
server_args: ServerArgs,
|
||||
pipe_finish_writer: Optional[mp.connection.Connection],
|
||||
):
|
||||
"""Wait for gRPC server to be ready and execute warmup."""
|
||||
if not server_args.skip_server_warmup:
|
||||
if not _execute_grpc_server_warmup(server_args, pipe_finish_writer):
|
||||
return
|
||||
else:
|
||||
logger.info("Skipping gRPC server warmup (skip_server_warmup=True)")
|
||||
|
||||
logger.info("The server is fired up and ready to roll!")
|
||||
|
||||
if pipe_finish_writer is not None:
|
||||
pipe_finish_writer.send("ready")
|
||||
|
||||
Reference in New Issue
Block a user