[Improvements] Merge health check route (#8444)
Signed-off-by: ybyang <ybyang7@iflytek.com> Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com> Co-authored-by: Kan Wu <wukanustc@gmail.com>
This commit is contained in:
@@ -460,6 +460,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
|
|
||||||
# We need to remove the sync in the following function for overlap schedule.
|
# We need to remove the sync in the following function for overlap schedule.
|
||||||
self.set_next_batch_sampling_info_done(batch)
|
self.set_next_batch_sampling_info_done(batch)
|
||||||
|
self.maybe_send_health_check_signal()
|
||||||
|
|
||||||
def process_disagg_prefill_inflight_queue(
|
def process_disagg_prefill_inflight_queue(
|
||||||
self: Scheduler, rids_to_check: Optional[List[str]] = None
|
self: Scheduler, rids_to_check: Optional[List[str]] = None
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
|||||||
|
|
||||||
from sglang.srt.disaggregation.utils import (
|
from sglang.srt.disaggregation.utils import (
|
||||||
FAKE_BOOTSTRAP_HOST,
|
FAKE_BOOTSTRAP_HOST,
|
||||||
|
DisaggregationMode,
|
||||||
register_disaggregation_server,
|
register_disaggregation_server,
|
||||||
)
|
)
|
||||||
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
||||||
@@ -88,7 +89,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
VertexGenerateReqInput,
|
VertexGenerateReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.template_manager import TemplateManager
|
from sglang.srt.managers.template_manager import TemplateManager
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
|
||||||
from sglang.srt.metrics.func_timer import enable_func_timer
|
from sglang.srt.metrics.func_timer import enable_func_timer
|
||||||
from sglang.srt.reasoning_parser import ReasoningParser
|
from sglang.srt.reasoning_parser import ReasoningParser
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
@@ -230,23 +231,28 @@ async def validate_json_request(raw_request: Request):
|
|||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health() -> Response:
|
|
||||||
"""Check the health of the http server."""
|
|
||||||
return Response(status_code=200)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health_generate")
|
@app.get("/health_generate")
|
||||||
async def health_generate(request: Request) -> Response:
|
async def health_generate(request: Request) -> Response:
|
||||||
"""Check the health of the inference server by generating one token."""
|
"""
|
||||||
|
Check the health of the inference server by sending a special request to generate one token.
|
||||||
|
|
||||||
|
If the server is running something, this request will be ignored, so it creates zero overhead.
|
||||||
|
If the server is not running anything, this request will be run, so we know whether the server is healthy.
|
||||||
|
"""
|
||||||
|
|
||||||
if _global_state.tokenizer_manager.gracefully_exit:
|
if _global_state.tokenizer_manager.gracefully_exit:
|
||||||
logger.info("Health check request received during shutdown. Returning 503.")
|
logger.info("Health check request received during shutdown. Returning 503.")
|
||||||
return Response(status_code=503)
|
return Response(status_code=503)
|
||||||
|
|
||||||
|
if not _global_state.tokenizer_manager.server_status.is_healthy():
|
||||||
|
return Response(status_code=503)
|
||||||
|
|
||||||
sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
|
sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
|
||||||
rid = f"HEALTH_CHECK_{time.time()}"
|
rid = f"HEALTH_CHECK_{time.time()}"
|
||||||
|
|
||||||
if _global_state.tokenizer_manager.is_image_gen:
|
if _global_state.tokenizer_manager.is_image_gen:
|
||||||
raise NotImplementedError()
|
# Keep this branch for some internal use cases.
|
||||||
|
raise NotImplementedError("Image generation is not supported yet.")
|
||||||
elif _global_state.tokenizer_manager.is_generation:
|
elif _global_state.tokenizer_manager.is_generation:
|
||||||
gri = GenerateReqInput(
|
gri = GenerateReqInput(
|
||||||
rid=rid,
|
rid=rid,
|
||||||
@@ -254,6 +260,12 @@ async def health_generate(request: Request) -> Response:
|
|||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
log_metrics=False,
|
log_metrics=False,
|
||||||
)
|
)
|
||||||
|
if (
|
||||||
|
_global_state.tokenizer_manager.server_args.disaggregation_mode
|
||||||
|
!= DisaggregationMode.NULL
|
||||||
|
):
|
||||||
|
gri.bootstrap_host = FAKE_BOOTSTRAP_HOST
|
||||||
|
gri.bootstrap_room = 0
|
||||||
else:
|
else:
|
||||||
gri = EmbeddingReqInput(
|
gri = EmbeddingReqInput(
|
||||||
rid=rid, input_ids=[0], sampling_params=sampling_params, log_metrics=False
|
rid=rid, input_ids=[0], sampling_params=sampling_params, log_metrics=False
|
||||||
@@ -263,9 +275,6 @@ async def health_generate(request: Request) -> Response:
|
|||||||
async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
|
async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
|
||||||
break
|
break
|
||||||
|
|
||||||
# This request is a special request.
|
|
||||||
# If the server already has something running, this request will be ignored, so it creates zero overhead.
|
|
||||||
# If the server is not running, this request will be run, so we know whether the server is healthy.
|
|
||||||
task = asyncio.create_task(gen())
|
task = asyncio.create_task(gen())
|
||||||
|
|
||||||
# As long as we receive any response from the detokenizer/scheduler, we consider the server is healthy.
|
# As long as we receive any response from the detokenizer/scheduler, we consider the server is healthy.
|
||||||
@@ -1032,8 +1041,10 @@ def _execute_server_warmup(
|
|||||||
timeout=600,
|
timeout=600,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200, f"{res}"
|
assert res.status_code == 200, f"{res}"
|
||||||
|
_global_state.tokenizer_manager.server_status = ServerStatus.Up
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.info(f"Start of prefill warmup ...")
|
logger.info(f"Start of pd disaggregation warmup ...")
|
||||||
json_data = {
|
json_data = {
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"temperature": 0.0,
|
"temperature": 0.0,
|
||||||
@@ -1055,9 +1066,18 @@ def _execute_server_warmup(
|
|||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=1800, # because of deep gemm precache is very long if not precache.
|
timeout=1800, # because of deep gemm precache is very long if not precache.
|
||||||
)
|
)
|
||||||
logger.info(
|
if res.status_code == 200:
|
||||||
f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
|
logger.info(
|
||||||
)
|
f"End of prefill disaggregation mode warmup with status {res.status_code}, resp: {res.json()}"
|
||||||
|
)
|
||||||
|
_global_state.tokenizer_manager.server_status = ServerStatus.Up
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Prefill disaggregation mode warm Up Failed, status code: {}".format(
|
||||||
|
res.status_code
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
last_traceback = get_exception_traceback()
|
last_traceback = get_exception_traceback()
|
||||||
|
|||||||
@@ -1781,6 +1781,9 @@ class Scheduler(
|
|||||||
elif batch.forward_mode.is_dummy_first():
|
elif batch.forward_mode.is_dummy_first():
|
||||||
self.set_next_batch_sampling_info_done(batch)
|
self.set_next_batch_sampling_info_done(batch)
|
||||||
|
|
||||||
|
self.maybe_send_health_check_signal()
|
||||||
|
|
||||||
|
def maybe_send_health_check_signal(self):
|
||||||
if self.return_health_check_ct:
|
if self.return_health_check_ct:
|
||||||
# Return some signal for the health check.
|
# Return some signal for the health check.
|
||||||
# This is used to prevent the health check signal being blocked by long context prefill.
|
# This is used to prevent the health check signal being blocked by long context prefill.
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ import uuid
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
@@ -115,6 +116,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.managers.mm_utils import TensorTransportMode
|
from sglang.srt.managers.mm_utils import TensorTransportMode
|
||||||
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
|
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
|
||||||
|
from sglang.srt.managers.scheduler import is_health_check_generate_req
|
||||||
from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
|
from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
|
||||||
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
@@ -270,6 +272,7 @@ class TokenizerManager:
|
|||||||
self.health_check_failed = False
|
self.health_check_failed = False
|
||||||
self.gracefully_exit = False
|
self.gracefully_exit = False
|
||||||
self.last_receive_tstamp = 0
|
self.last_receive_tstamp = 0
|
||||||
|
self.server_status = ServerStatus.Starting
|
||||||
|
|
||||||
# Dumping
|
# Dumping
|
||||||
self.dump_requests_folder = "" # By default do not dump
|
self.dump_requests_folder = "" # By default do not dump
|
||||||
@@ -1804,6 +1807,8 @@ class TokenizerManager:
|
|||||||
asyncio.create_task(asyncio.to_thread(background_task))
|
asyncio.create_task(asyncio.to_thread(background_task))
|
||||||
|
|
||||||
def _handle_abort_req(self, recv_obj):
|
def _handle_abort_req(self, recv_obj):
|
||||||
|
if is_health_check_generate_req(recv_obj):
|
||||||
|
return
|
||||||
state = self.rid_to_state[recv_obj.rid]
|
state = self.rid_to_state[recv_obj.rid]
|
||||||
state.finished = True
|
state.finished = True
|
||||||
if recv_obj.finished_reason:
|
if recv_obj.finished_reason:
|
||||||
@@ -1938,6 +1943,16 @@ class TokenizerManager:
|
|||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class ServerStatus(Enum):
|
||||||
|
Up = "Up"
|
||||||
|
Starting = "Starting"
|
||||||
|
UnHealthy = "UnHealthy"
|
||||||
|
Crashed = "Crashed"
|
||||||
|
|
||||||
|
def is_healthy(self) -> bool:
|
||||||
|
return self == ServerStatus.Up
|
||||||
|
|
||||||
|
|
||||||
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
|
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
|
||||||
is_cross_node = server_args.dist_init_addr
|
is_cross_node = server_args.dist_init_addr
|
||||||
|
|
||||||
|
|||||||
@@ -44,7 +44,6 @@ import traceback
|
|||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict, defaultdict
|
from collections import OrderedDict, defaultdict
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from enum import Enum
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
from importlib.util import find_spec
|
from importlib.util import find_spec
|
||||||
@@ -93,6 +92,7 @@ logger = logging.getLogger(__name__)
|
|||||||
show_time_cost = False
|
show_time_cost = False
|
||||||
time_infos = {}
|
time_infos = {}
|
||||||
|
|
||||||
|
|
||||||
HIP_FP8_E4M3_FNUZ_MAX = 224.0
|
HIP_FP8_E4M3_FNUZ_MAX = 224.0
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user