[router][grpc] add warm up to grpc server (#11627)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
@@ -3,13 +3,13 @@ Standalone gRPC Server for SGLang - Fully separated from HTTP server.
|
|||||||
Uses GrpcRequestManager for orchestration without tokenization.
|
Uses GrpcRequestManager for orchestration without tokenization.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
from typing import AsyncIterator, Dict, Optional, Tuple
|
from typing import AsyncIterator, Dict, Optional, Tuple
|
||||||
@@ -35,7 +35,11 @@ from sglang.srt.managers.io_struct import (
|
|||||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import configure_logger, prepare_model_and_tokenizer
|
from sglang.srt.utils import (
|
||||||
|
configure_logger,
|
||||||
|
kill_process_tree,
|
||||||
|
prepare_model_and_tokenizer,
|
||||||
|
)
|
||||||
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
@@ -884,6 +888,13 @@ async def serve_grpc(
|
|||||||
await server.start()
|
await server.start()
|
||||||
logger.info(f"gRPC server listening on {listen_addr}")
|
logger.info(f"gRPC server listening on {listen_addr}")
|
||||||
|
|
||||||
|
# Start warmup in a separate thread
|
||||||
|
warmup_thread = threading.Thread(
|
||||||
|
target=_wait_and_warmup_grpc,
|
||||||
|
args=(server_args, None),
|
||||||
|
)
|
||||||
|
warmup_thread.start()
|
||||||
|
|
||||||
# Handle shutdown signals
|
# Handle shutdown signals
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
stop_event = asyncio.Event()
|
stop_event = asyncio.Event()
|
||||||
@@ -906,6 +917,11 @@ async def serve_grpc(
|
|||||||
# Stop the gRPC server
|
# Stop the gRPC server
|
||||||
await server.stop(5.0)
|
await server.stop(5.0)
|
||||||
|
|
||||||
|
# Wait for warmup thread to finish
|
||||||
|
if warmup_thread.is_alive():
|
||||||
|
logger.info("Waiting for warmup thread to finish...")
|
||||||
|
warmup_thread.join(timeout=5.0)
|
||||||
|
|
||||||
# Terminate scheduler processes before exiting to avoid atexit hang
|
# Terminate scheduler processes before exiting to avoid atexit hang
|
||||||
# The scheduler processes have SIGINT ignored, so they won't get KeyboardInterrupt
|
# The scheduler processes have SIGINT ignored, so they won't get KeyboardInterrupt
|
||||||
for i, proc in enumerate(scheduler_procs):
|
for i, proc in enumerate(scheduler_procs):
|
||||||
@@ -921,3 +937,158 @@ async def serve_grpc(
|
|||||||
proc.join(timeout=1.0)
|
proc.join(timeout=1.0)
|
||||||
|
|
||||||
logger.info("All scheduler processes terminated")
|
logger.info("All scheduler processes terminated")
|
||||||
|
|
||||||
|
|
||||||
|
def _execute_grpc_server_warmup(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[mp.connection.Connection],
):
    """Execute warmup for gRPC server by checking health and sending a test request.

    Polls ``GetModelInfo`` until the server answers (up to 120 seconds), then
    issues one small ``Generate`` request (generation models) or ``Embed``
    request (embedding models) so the execution path is exercised before the
    server is declared ready.

    Args:
        server_args: Server configuration; ``host`` and ``port`` locate the
            gRPC endpoint to warm up.
        pipe_finish_writer: Optional pipe back to the parent process; on fatal
            failure the error message is sent through it before shutdown.

    Returns:
        True if the warmup request completed without error, False otherwise.
        On fatal failures the whole process tree is killed via
        ``kill_process_tree``, so callers may never observe the False return.
    """
    # Hoisted so the failure helper can close the channel even when the
    # exception happened before (or during) channel creation.
    channel = None

    def _abort(error_msg: str) -> bool:
        # Single fatal-failure path (previously copy-pasted at every error
        # site): log, notify the parent if a pipe exists, release the
        # channel, and tear down the process tree.
        logger.error(error_msg)
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(error_msg)
        if channel is not None:
            channel.close()
        kill_process_tree(os.getpid())
        return False

    try:
        # Connect to the gRPC server.
        grpc_url = f"{server_args.host}:{server_args.port}"
        channel = grpc.insecure_channel(
            grpc_url,
            options=[
                ("grpc.max_send_message_length", 1024 * 1024 * 256),
                ("grpc.max_receive_message_length", 1024 * 1024 * 256),
            ],
        )
        stub = sglang_scheduler_pb2_grpc.SglangSchedulerStub(channel)

        # Wait until the server is launched (poll GetModelInfo once per
        # second for up to 120 seconds).
        response = None
        last_error = None
        for _ in range(120):
            time.sleep(1)
            try:
                response = stub.GetModelInfo(
                    sglang_scheduler_pb2.GetModelInfoRequest(), timeout=5
                )
                break
            except Exception as e:
                last_error = str(e)

        if response is None:
            return _abort(
                f"gRPC server warmup failed: Could not connect to server after 120 seconds. Last error: {last_error}"
            )

        # Model info tells us whether this is a generation or embedding model.
        is_generation = response.is_generation

        # Send a warmup request.
        logger.info("Sending warmup request to gRPC server...")
        max_new_tokens = 8 if is_generation else 1

        if is_generation:
            # Create tokenized input for warmup.
            warmup_request = sglang_scheduler_pb2.GenerateRequest(
                request_id=f"WARMUP_{time.time()}",
                tokenized=sglang_scheduler_pb2.TokenizedInput(
                    input_ids=[
                        954,
                        15541,
                        2181,
                        23496,
                        1476,
                        64710,
                        280,
                    ],  # Simple token sequence
                    original_text="The capital city of France is",
                ),
                sampling_params=sglang_scheduler_pb2.SamplingParams(
                    temperature=0.0,
                    max_new_tokens=max_new_tokens,
                ),
                stream=False,
            )

            # Send the warmup request; keep the try body to just the RPC.
            try:
                responses = list(stub.Generate(warmup_request, timeout=600))
            except Exception as e:
                return _abort(f"gRPC warmup request failed: {e}")

            # Check if we got a valid response.
            if responses and not responses[-1].HasField("error"):
                logger.info("gRPC warmup request completed successfully")
                success = True
            else:
                error_msg = (
                    responses[-1].error.message if responses else "No response"
                )
                logger.warning(f"gRPC warmup request returned error: {error_msg}")
                success = False
        else:
            # For embedding models.
            warmup_request = sglang_scheduler_pb2.EmbedRequest(
                request_id=f"WARMUP_{time.time()}",
                tokenized=sglang_scheduler_pb2.TokenizedInput(
                    input_ids=[10, 11, 12],
                    original_text="test embedding",
                ),
            )

            try:
                embed_response = stub.Embed(warmup_request, timeout=600)
            except Exception as e:
                return _abort(f"gRPC warmup request failed: {e}")

            if not embed_response.HasField("error"):
                logger.info("gRPC warmup request completed successfully")
                success = True
            else:
                logger.warning(
                    f"gRPC warmup request returned error: {embed_response.error.message}"
                )
                success = False

        channel.close()
        return success

    except Exception as e:
        return _abort(
            f"gRPC warmup failed with exception: {e}\n{get_exception_traceback()}"
        )
||||||
|
|
||||||
|
|
||||||
|
def _wait_and_warmup_grpc(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[mp.connection.Connection],
):
    """Wait for gRPC server to be ready and execute warmup."""
    if server_args.skip_server_warmup:
        logger.info("Skipping gRPC server warmup (skip_server_warmup=True)")
    elif not _execute_grpc_server_warmup(server_args, pipe_finish_writer):
        # Warmup failed: the failure path has already reported the error
        # (and initiated shutdown), so do not signal readiness.
        return

    logger.info("The server is fired up and ready to roll!")

    if pipe_finish_writer is not None:
        pipe_finish_writer.send("ready")
|
||||||
|
|||||||
Reference in New Issue
Block a user