diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 462727fff..c15c1eff0 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -460,6 +460,7 @@ class SchedulerDisaggregationPrefillMixin: # We need to remove the sync in the following function for overlap schedule. self.set_next_batch_sampling_info_done(batch) + self.maybe_send_health_check_signal() def process_disagg_prefill_inflight_queue( self: Scheduler, rids_to_check: Optional[List[str]] = None diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index b58987bcb..180d33820 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -45,6 +45,7 @@ from fastapi.responses import ORJSONResponse, Response, StreamingResponse from sglang.srt.disaggregation.utils import ( FAKE_BOOTSTRAP_HOST, + DisaggregationMode, register_disaggregation_server, ) from sglang.srt.entrypoints.engine import _launch_subprocesses @@ -88,7 +89,7 @@ from sglang.srt.managers.io_struct import ( VertexGenerateReqInput, ) from sglang.srt.managers.template_manager import TemplateManager -from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager from sglang.srt.metrics.func_timer import enable_func_timer from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.server_args import ServerArgs @@ -230,23 +231,28 @@ async def validate_json_request(raw_request: Request): @app.get("/health") -async def health() -> Response: - """Check the health of the http server.""" - return Response(status_code=200) - - @app.get("/health_generate") async def health_generate(request: Request) -> Response: - """Check the health of the inference server by generating one token.""" + """ + Check the health of the inference server by sending a special 
request to generate one token. + + If the server is running something, this request will be ignored, so it creates zero overhead. + If the server is not running anything, this request will be run, so we know whether the server is healthy. + """ + if _global_state.tokenizer_manager.gracefully_exit: logger.info("Health check request received during shutdown. Returning 503.") return Response(status_code=503) + if not _global_state.tokenizer_manager.server_status.is_healthy(): + return Response(status_code=503) + sampling_params = {"max_new_tokens": 1, "temperature": 0.0} rid = f"HEALTH_CHECK_{time.time()}" if _global_state.tokenizer_manager.is_image_gen: - raise NotImplementedError() + # Keep this branch for some internal use cases. + raise NotImplementedError("Image generation is not supported yet.") elif _global_state.tokenizer_manager.is_generation: gri = GenerateReqInput( rid=rid, @@ -254,6 +260,12 @@ async def health_generate(request: Request) -> Response: sampling_params=sampling_params, log_metrics=False, ) + if ( + _global_state.tokenizer_manager.server_args.disaggregation_mode + != DisaggregationMode.NULL + ): + gri.bootstrap_host = FAKE_BOOTSTRAP_HOST + gri.bootstrap_room = 0 else: gri = EmbeddingReqInput( rid=rid, input_ids=[0], sampling_params=sampling_params, log_metrics=False @@ -263,9 +275,6 @@ async def health_generate(request: Request) -> Response: async for _ in _global_state.tokenizer_manager.generate_request(gri, request): break - # This request is a special request. - # If the server already has something running, this request will be ignored, so it creates zero overhead. - # If the server is not running, this request will be run, so we know whether the server is healthy. task = asyncio.create_task(gen()) # As long as we receive any response from the detokenizer/scheduler, we consider the server is healthy. 
@@ -1032,8 +1041,10 @@ def _execute_server_warmup( timeout=600, ) assert res.status_code == 200, f"{res}" + _global_state.tokenizer_manager.server_status = ServerStatus.Up + else: - logger.info(f"Start of prefill warmup ...") + logger.info(f"Start of pd disaggregation warmup ...") json_data = { "sampling_params": { "temperature": 0.0, @@ -1055,9 +1066,18 @@ headers=headers, timeout=1800, # because of deep gemm precache is very long if not precached. ) - logger.info( - f"End of prefill warmup with status {res.status_code}, resp: {res.json()}" - ) + if res.status_code == 200: + logger.info( + f"End of prefill disaggregation mode warmup with status {res.status_code}, resp: {res.json()}" + ) + _global_state.tokenizer_manager.server_status = ServerStatus.Up + else: + logger.info( + "Prefill disaggregation mode warmup failed, status code: {}".format( + res.status_code + ) + ) + _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy except Exception: last_traceback = get_exception_traceback() diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 0249acd8d..5f9b7f20f 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1781,6 +1781,9 @@ class Scheduler( elif batch.forward_mode.is_dummy_first(): self.set_next_batch_sampling_info_done(batch) + self.maybe_send_health_check_signal() + + def maybe_send_health_check_signal(self): if self.return_health_check_ct: # Return some signal for the health check. # This is used to prevent the health check signal being blocked by long context prefill. 
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 89326bf06..cbd1c7332 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -29,6 +29,7 @@ import uuid from collections import deque from contextlib import nullcontext from datetime import datetime +from enum import Enum from http import HTTPStatus from typing import ( Any, @@ -115,6 +116,7 @@ from sglang.srt.managers.io_struct import ( ) from sglang.srt.managers.mm_utils import TensorTransportMode from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors +from sglang.srt.managers.scheduler import is_health_check_generate_req from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.sampling.sampling_params import SamplingParams @@ -270,6 +272,7 @@ class TokenizerManager: self.health_check_failed = False self.gracefully_exit = False self.last_receive_tstamp = 0 + self.server_status = ServerStatus.Starting # Dumping self.dump_requests_folder = "" # By default do not dump @@ -1804,6 +1807,8 @@ class TokenizerManager: asyncio.create_task(asyncio.to_thread(background_task)) def _handle_abort_req(self, recv_obj): + if is_health_check_generate_req(recv_obj): + return state = self.rid_to_state[recv_obj.rid] state.finished = True if recv_obj.finished_reason: @@ -1938,6 +1943,16 @@ class TokenizerManager: return scores +class ServerStatus(Enum): + Up = "Up" + Starting = "Starting" + UnHealthy = "UnHealthy" + Crashed = "Crashed" + + def is_healthy(self) -> bool: + return self == ServerStatus.Up + + def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode: is_cross_node = server_args.dist_init_addr diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index cc1ed8431..db841b3fd 100644 --- a/python/sglang/srt/utils.py +++ 
b/python/sglang/srt/utils.py @@ -44,7 +44,6 @@ import traceback import warnings from collections import OrderedDict, defaultdict from contextlib import contextmanager -from enum import Enum from functools import lru_cache from importlib.metadata import PackageNotFoundError, version from importlib.util import find_spec @@ -93,6 +92,7 @@ logger = logging.getLogger(__name__) show_time_cost = False time_infos = {} + HIP_FP8_E4M3_FNUZ_MAX = 224.0