[Improvements] Merge health check route (#8444)
Signed-off-by: ybyang <ybyang7@iflytek.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Kan Wu <wukanustc@gmail.com>
This commit is contained in:
@@ -460,6 +460,7 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
|
||||
# We need to remove the sync in the following function for overlap schedule.
|
||||
self.set_next_batch_sampling_info_done(batch)
|
||||
self.maybe_send_health_check_signal()
|
||||
|
||||
def process_disagg_prefill_inflight_queue(
|
||||
self: Scheduler, rids_to_check: Optional[List[str]] = None
|
||||
|
||||
@@ -45,6 +45,7 @@ from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
FAKE_BOOTSTRAP_HOST,
|
||||
DisaggregationMode,
|
||||
register_disaggregation_server,
|
||||
)
|
||||
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
||||
@@ -88,7 +89,7 @@ from sglang.srt.managers.io_struct import (
|
||||
VertexGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.managers.template_manager import TemplateManager
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
|
||||
from sglang.srt.metrics.func_timer import enable_func_timer
|
||||
from sglang.srt.reasoning_parser import ReasoningParser
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
@@ -230,23 +231,28 @@ async def validate_json_request(raw_request: Request):
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> Response:
|
||||
"""Check the health of the http server."""
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.get("/health_generate")
|
||||
async def health_generate(request: Request) -> Response:
|
||||
"""Check the health of the inference server by generating one token."""
|
||||
"""
|
||||
Check the health of the inference server by sending a special request to generate one token.
|
||||
|
||||
If the server is running something, this request will be ignored, so it creates zero overhead.
|
||||
If the server is not running anything, this request will be run, so we know whether the server is healthy.
|
||||
"""
|
||||
|
||||
if _global_state.tokenizer_manager.gracefully_exit:
|
||||
logger.info("Health check request received during shutdown. Returning 503.")
|
||||
return Response(status_code=503)
|
||||
|
||||
if not _global_state.tokenizer_manager.server_status.is_healthy():
|
||||
return Response(status_code=503)
|
||||
|
||||
sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
|
||||
rid = f"HEALTH_CHECK_{time.time()}"
|
||||
|
||||
if _global_state.tokenizer_manager.is_image_gen:
|
||||
raise NotImplementedError()
|
||||
# Keep this branch for some internal use cases.
|
||||
raise NotImplementedError("Image generation is not supported yet.")
|
||||
elif _global_state.tokenizer_manager.is_generation:
|
||||
gri = GenerateReqInput(
|
||||
rid=rid,
|
||||
@@ -254,6 +260,12 @@ async def health_generate(request: Request) -> Response:
|
||||
sampling_params=sampling_params,
|
||||
log_metrics=False,
|
||||
)
|
||||
if (
|
||||
_global_state.tokenizer_manager.server_args.disaggregation_mode
|
||||
!= DisaggregationMode.NULL
|
||||
):
|
||||
gri.bootstrap_host = FAKE_BOOTSTRAP_HOST
|
||||
gri.bootstrap_room = 0
|
||||
else:
|
||||
gri = EmbeddingReqInput(
|
||||
rid=rid, input_ids=[0], sampling_params=sampling_params, log_metrics=False
|
||||
@@ -263,9 +275,6 @@ async def health_generate(request: Request) -> Response:
|
||||
async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
|
||||
break
|
||||
|
||||
# This request is a special request.
|
||||
# If the server already has something running, this request will be ignored, so it creates zero overhead.
|
||||
# If the server is not running, this request will be run, so we know whether the server is healthy.
|
||||
task = asyncio.create_task(gen())
|
||||
|
||||
# As long as we receive any response from the detokenizer/scheduler, we consider the server healthy.
|
||||
@@ -1032,8 +1041,10 @@ def _execute_server_warmup(
|
||||
timeout=600,
|
||||
)
|
||||
assert res.status_code == 200, f"{res}"
|
||||
_global_state.tokenizer_manager.server_status = ServerStatus.Up
|
||||
|
||||
else:
|
||||
logger.info(f"Start of prefill warmup ...")
|
||||
logger.info(f"Start of pd disaggregation warmup ...")
|
||||
json_data = {
|
||||
"sampling_params": {
|
||||
"temperature": 0.0,
|
||||
@@ -1055,9 +1066,18 @@ def _execute_server_warmup(
|
||||
headers=headers,
|
||||
timeout=1800,  # Generous timeout: deep gemm precaching takes very long if nothing was precached beforehand.
|
||||
)
|
||||
logger.info(
|
||||
f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
|
||||
)
|
||||
if res.status_code == 200:
|
||||
logger.info(
|
||||
f"End of prefill disaggregation mode warmup with status {res.status_code}, resp: {res.json()}"
|
||||
)
|
||||
_global_state.tokenizer_manager.server_status = ServerStatus.Up
|
||||
else:
|
||||
logger.info(
|
||||
"Prefill disaggregation mode warm Up Failed, status code: {}".format(
|
||||
res.status_code
|
||||
)
|
||||
)
|
||||
_global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
|
||||
|
||||
except Exception:
|
||||
last_traceback = get_exception_traceback()
|
||||
|
||||
@@ -1781,6 +1781,9 @@ class Scheduler(
|
||||
elif batch.forward_mode.is_dummy_first():
|
||||
self.set_next_batch_sampling_info_done(batch)
|
||||
|
||||
self.maybe_send_health_check_signal()
|
||||
|
||||
def maybe_send_health_check_signal(self):
|
||||
if self.return_health_check_ct:
|
||||
# Return some signal for the health check.
|
||||
# This is used to prevent the health check signal being blocked by long context prefill.
|
||||
|
||||
@@ -29,6 +29,7 @@ import uuid
|
||||
from collections import deque
|
||||
from contextlib import nullcontext
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from http import HTTPStatus
|
||||
from typing import (
|
||||
Any,
|
||||
@@ -115,6 +116,7 @@ from sglang.srt.managers.io_struct import (
|
||||
)
|
||||
from sglang.srt.managers.mm_utils import TensorTransportMode
|
||||
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
|
||||
from sglang.srt.managers.scheduler import is_health_check_generate_req
|
||||
from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
|
||||
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
@@ -270,6 +272,7 @@ class TokenizerManager:
|
||||
self.health_check_failed = False
|
||||
self.gracefully_exit = False
|
||||
self.last_receive_tstamp = 0
|
||||
self.server_status = ServerStatus.Starting
|
||||
|
||||
# Dumping
|
||||
self.dump_requests_folder = "" # By default do not dump
|
||||
@@ -1804,6 +1807,8 @@ class TokenizerManager:
|
||||
asyncio.create_task(asyncio.to_thread(background_task))
|
||||
|
||||
def _handle_abort_req(self, recv_obj):
|
||||
if is_health_check_generate_req(recv_obj):
|
||||
return
|
||||
state = self.rid_to_state[recv_obj.rid]
|
||||
state.finished = True
|
||||
if recv_obj.finished_reason:
|
||||
@@ -1938,6 +1943,16 @@ class TokenizerManager:
|
||||
return scores
|
||||
|
||||
|
||||
class ServerStatus(Enum):
    """Coarse status of the server, stored on ``TokenizerManager.server_status``.

    Only ``Up`` counts as healthy; every other member makes ``is_healthy()``
    return ``False``.
    """

    # Member names and values are kept exactly as declared; callers refer to
    # them by name (e.g. ``ServerStatus.Up``, ``ServerStatus.UnHealthy``).
    Up = "Up"
    Starting = "Starting"
    UnHealthy = "UnHealthy"
    Crashed = "Crashed"

    def is_healthy(self) -> bool:
        """Return ``True`` if and only if this status is ``ServerStatus.Up``."""
        # Enum members are singletons, so identity comparison is the idiomatic
        # (and equivalent) way to test for a specific member.
        return self is ServerStatus.Up
|
||||
|
||||
|
||||
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
|
||||
is_cross_node = server_args.dist_init_addr
|
||||
|
||||
|
||||
@@ -44,7 +44,6 @@ import traceback
|
||||
import warnings
|
||||
from collections import OrderedDict, defaultdict
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from importlib.util import find_spec
|
||||
@@ -93,6 +92,7 @@ logger = logging.getLogger(__name__)
|
||||
show_time_cost = False
|
||||
time_infos = {}
|
||||
|
||||
|
||||
HIP_FP8_E4M3_FNUZ_MAX = 224.0
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user