[Improvements] Merge health check route (#8444)
Signed-off-by: ybyang <ybyang7@iflytek.com> Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com> Co-authored-by: Kan Wu <wukanustc@gmail.com>
This commit is contained in:
@@ -45,6 +45,7 @@ from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
FAKE_BOOTSTRAP_HOST,
|
||||
DisaggregationMode,
|
||||
register_disaggregation_server,
|
||||
)
|
||||
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
||||
@@ -88,7 +89,7 @@ from sglang.srt.managers.io_struct import (
|
||||
VertexGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.managers.template_manager import TemplateManager
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
|
||||
from sglang.srt.metrics.func_timer import enable_func_timer
|
||||
from sglang.srt.reasoning_parser import ReasoningParser
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
@@ -230,23 +231,28 @@ async def validate_json_request(raw_request: Request):
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> Response:
|
||||
"""Check the health of the http server."""
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.get("/health_generate")
|
||||
async def health_generate(request: Request) -> Response:
|
||||
"""Check the health of the inference server by generating one token."""
|
||||
"""
|
||||
Check the health of the inference server by sending a special request to generate one token.
|
||||
|
||||
If the server is running something, this request will be ignored, so it creates zero overhead.
|
||||
If the server is not running anything, this request will be run, so we know whether the server is healthy.
|
||||
"""
|
||||
|
||||
if _global_state.tokenizer_manager.gracefully_exit:
|
||||
logger.info("Health check request received during shutdown. Returning 503.")
|
||||
return Response(status_code=503)
|
||||
|
||||
if not _global_state.tokenizer_manager.server_status.is_healthy():
|
||||
return Response(status_code=503)
|
||||
|
||||
sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
|
||||
rid = f"HEALTH_CHECK_{time.time()}"
|
||||
|
||||
if _global_state.tokenizer_manager.is_image_gen:
|
||||
raise NotImplementedError()
|
||||
# Keep this branch for some internal use cases.
|
||||
raise NotImplementedError("Image generation is not supported yet.")
|
||||
elif _global_state.tokenizer_manager.is_generation:
|
||||
gri = GenerateReqInput(
|
||||
rid=rid,
|
||||
@@ -254,6 +260,12 @@ async def health_generate(request: Request) -> Response:
|
||||
sampling_params=sampling_params,
|
||||
log_metrics=False,
|
||||
)
|
||||
if (
|
||||
_global_state.tokenizer_manager.server_args.disaggregation_mode
|
||||
!= DisaggregationMode.NULL
|
||||
):
|
||||
gri.bootstrap_host = FAKE_BOOTSTRAP_HOST
|
||||
gri.bootstrap_room = 0
|
||||
else:
|
||||
gri = EmbeddingReqInput(
|
||||
rid=rid, input_ids=[0], sampling_params=sampling_params, log_metrics=False
|
||||
@@ -263,9 +275,6 @@ async def health_generate(request: Request) -> Response:
|
||||
async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
|
||||
break
|
||||
|
||||
# This request is a special request.
|
||||
# If the server already has something running, this request will be ignored, so it creates zero overhead.
|
||||
# If the server is not running, this request will be run, so we know whether the server is healthy.
|
||||
task = asyncio.create_task(gen())
|
||||
|
||||
# As long as we receive any response from the detokenizer/scheduler, we consider the server is healthy.
|
||||
@@ -1032,8 +1041,10 @@ def _execute_server_warmup(
|
||||
timeout=600,
|
||||
)
|
||||
assert res.status_code == 200, f"{res}"
|
||||
_global_state.tokenizer_manager.server_status = ServerStatus.Up
|
||||
|
||||
else:
|
||||
logger.info(f"Start of prefill warmup ...")
|
||||
logger.info(f"Start of pd disaggregation warmup ...")
|
||||
json_data = {
|
||||
"sampling_params": {
|
||||
"temperature": 0.0,
|
||||
@@ -1055,9 +1066,18 @@ def _execute_server_warmup(
|
||||
headers=headers,
|
||||
timeout=1800, # because of deep gemm precache is very long if not precache.
|
||||
)
|
||||
logger.info(
|
||||
f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
|
||||
)
|
||||
if res.status_code == 200:
|
||||
logger.info(
|
||||
f"End of prefill disaggregation mode warmup with status {res.status_code}, resp: {res.json()}"
|
||||
)
|
||||
_global_state.tokenizer_manager.server_status = ServerStatus.Up
|
||||
else:
|
||||
logger.info(
|
||||
"Prefill disaggregation mode warm Up Failed, status code: {}".format(
|
||||
res.status_code
|
||||
)
|
||||
)
|
||||
_global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
|
||||
|
||||
except Exception:
|
||||
last_traceback = get_exception_traceback()
|
||||
|
||||
Reference in New Issue
Block a user