Revert "[Feature] Simple Improve Health Check Mechanism for Production-Grade Stability" (#8181)
@@ -65,7 +65,6 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     MultiprocessingSerializer,
-    ServerStatus,
     assert_pkg_version,
     configure_logger,
     get_zmq_socket,
@@ -74,7 +73,6 @@ from sglang.srt.utils import (
     launch_dummy_health_check_server,
     maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
-    report_health,
     set_prometheus_multiproc_dir,
     set_ulimit,
 )
@@ -663,7 +661,6 @@ def _set_envs_and_config(server_args: ServerArgs):
     def sigchld_handler(signum, frame):
         pid, exitcode = os.waitpid(0, os.WNOHANG)
         if exitcode != 0:
-            report_health(ServerStatus.Crashed, server_args.host, server_args.port)
             logger.warning(
                 f"Child process unexpectedly failed with {exitcode=}. {pid=}"
             )
@@ -677,7 +674,6 @@ def _set_envs_and_config(server_args: ServerArgs):
         logger.error(
             "Received sigquit from a child process. It usually means the child failed."
         )
-        report_health(ServerStatus.Crashed, server_args.host, server_args.port)
         kill_process_tree(os.getpid())

     signal.signal(signal.SIGQUIT, sigquit_handler)
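For orientation, a minimal sketch of the crash-reporting wiring the _set_envs_and_config hunks above remove; the host, port, and report_crash helper are placeholders standing in for the reverted report_health():

    import os
    import signal

    import requests


    def report_crash(host: str, port: int, msg: str) -> None:
        # Stand-in for the reverted report_health(): push the new status to the
        # HTTP server so its GET /health starts answering 503.
        requests.post(
            f"http://{host}:{port}/health", json={"status": "Crashed", "msg": msg}
        )


    def sigchld_handler(signum, frame):
        # Reap one dead child without blocking; a non-zero status means it crashed.
        pid, exitcode = os.waitpid(0, os.WNOHANG)
        if exitcode != 0:
            report_crash("127.0.0.1", 30000, f"child {pid} exited with {exitcode}")


    signal.signal(signal.SIGCHLD, sigchld_handler)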
@@ -77,7 +77,6 @@ from sglang.srt.managers.io_struct import (
     ParseFunctionCallReq,
     ProfileReqInput,
     ReleaseMemoryOccupationReqInput,
-    ReportHealthInput,
     ResumeMemoryOccupationReqInput,
     SeparateReasoningReqInput,
     SetInternalStateReq,
@@ -94,7 +93,6 @@ from sglang.srt.metrics.func_timer import enable_func_timer
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    ServerStatus,
     add_api_key_middleware,
     add_prometheus_middleware,
     delete_directory,
@@ -222,31 +220,8 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))

 @app.get("/health")
 async def health() -> Response:
-    """Check the status of the http server."""
-    code = HTTPStatus.SERVICE_UNAVAILABLE.value
-    if _global_state.tokenizer_manager.server_status == ServerStatus.Up:
-        code = HTTPStatus.OK.value
-    return Response(
-        status_code=code,
-        content=json.dumps(
-            {"status": _global_state.tokenizer_manager.server_status.value}
-        ),
-    )
-
-
-@app.post("/health")
-async def health_update(obj: ReportHealthInput, request: Request) -> Response:
-    """Update the status of the http server."""
-    try:
-        server_status = ServerStatus(obj.status)
-        _global_state.tokenizer_manager.server_status = server_status
-        if server_status != ServerStatus.Up:
-            return Response(
-                status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, content=obj.msg
-            )
-    except Exception as e:
-        logger.error(e)
-        return Response(status_code=HTTPStatus.SERVICE_UNAVAILABLE.value)
+    """Check the health of the http server."""
+    return Response(status_code=200)


 @app.get("/health_generate")
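The behavioral change is easiest to see from the client side: the reverted GET /health reported the tracked status as JSON (200 only for Up, 503 otherwise) and POST /health let worker processes overwrite that status, while the restored GET /health is a bare liveness check. A quick probe, assuming a server listening locally on port 30000:

    import requests

    # Restored behavior: always 200 with an empty body while the process is alive.
    # Reverted behavior: e.g. 200 + {"status": "Up"} or 503 + {"status": "Crashed"}.
    resp = requests.get("http://127.0.0.1:30000/health", timeout=5)
    print(resp.status_code, resp.text)

    # Status push; only meaningful before the revert, since the POST route is gone.
    requests.post(
        "http://127.0.0.1:30000/health", json={"status": "UnHealthy", "msg": "probe"}
    )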
@@ -281,7 +256,7 @@ async def health_generate(request: Request) -> Response:
         if _global_state.tokenizer_manager.last_receive_tstamp > tic:
             task.cancel()
             _global_state.tokenizer_manager.rid_to_state.pop(rid, None)
-            _global_state.tokenizer_manager.server_status = ServerStatus.Up
+            _global_state.tokenizer_manager.health_check_failed = False
             return Response(status_code=200)

     task.cancel()
@@ -295,7 +270,7 @@ async def health_generate(request: Request) -> Response:
         f"last_heartbeat time: {last_receive_time}"
     )
     _global_state.tokenizer_manager.rid_to_state.pop(rid, None)
-    _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
+    _global_state.tokenizer_manager.health_check_failed = True
     return Response(status_code=503)

@@ -1047,13 +1022,9 @@ def _execute_server_warmup(
                 headers=headers,
                 timeout=600,
             )
-            if res.status_code == 200:
-                _global_state.tokenizer_manager.server_status = ServerStatus.Up
-            else:
-                _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
-            logger.info(f"{res}")
+            assert res.status_code == 200, f"{res}"
         else:
-            logger.info(f"Start of prefill/decode warmup ...")
+            logger.info(f"Start of prefill warmup ...")
             json_data = {
                 "sampling_params": {
                     "temperature": 0.0,
@@ -1075,25 +1046,15 @@ def _execute_server_warmup(
                 headers=headers,
                 timeout=1800,  # because DeepGEMM precache takes very long when not cached.
             )
-            if res.status_code == 200:
-                logger.info(
-                    f"End of prefill disaggregation mode warmup with status {res.status_code}, resp: {res.json()}"
-                )
-                _global_state.tokenizer_manager.server_status = ServerStatus.Up
-            else:
-                logger.info(
-                    "Prefill disaggregation mode warmup failed, status code: {}".format(
-                        res.status_code
-                    )
-                )
-                _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
+            logger.info(
+                f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
+            )

     except Exception:
         last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
         logger.error(f"Initialization failed. warmup error: {last_traceback}")
-        _global_state.tokenizer_manager.server_status = ServerStatus.Crashed
         kill_process_tree(os.getpid())
         return False
@@ -1083,9 +1083,3 @@ class LoRAUpdateResult:


 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
-
-
-@dataclass
-class ReportHealthInput:
-    status: str
-    msg: Optional[str] = ""
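ReportHealthInput was the request model behind the removed POST /health route; a small sketch of how its payload round-trips, with illustrative values:

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class ReportHealthInput:  # copied from the removed code
        status: str
        msg: Optional[str] = ""


    # The JSON that report_health() sent and the endpoint parsed back.
    wire = {"status": "Crashed", "msg": "scheduler hit an exception"}
    parsed = ReportHealthInput(**wire)
    assert parsed.status == "Crashed"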
@@ -143,7 +143,6 @@ from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
 from sglang.srt.utils import (
     DeepEPMode,
     DynamicGradMode,
-    ServerStatus,
     broadcast_pyobj,
     configure_gc_logger,
     configure_logger,
@@ -155,7 +154,6 @@ from sglang.srt.utils import (
     kill_itself_when_parent_died,
     point_to_point_pyobj,
     pyspy_dump_schedulers,
-    report_health,
     require_mlp_sync,
     require_mlp_tp_gather,
     set_gpu_proc_affinity,
@@ -2966,5 +2964,4 @@ def run_scheduler_process(
     except Exception:
         traceback = get_exception_traceback()
         logger.error(f"Scheduler hit an exception: {traceback}")
-        report_health(ServerStatus.Crashed, server_args.host, server_args.port)
         parent_process.send_signal(signal.SIGQUIT)
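A condensed view of the reverted failure path in run_scheduler_process(), as a hypothetical standalone helper (host, http_port, and parent_pid are placeholders): report the crash over HTTP first, then SIGQUIT the parent, whose handler kills the whole process tree.

    import os
    import signal

    import requests


    def on_scheduler_exception(host: str, http_port: int, parent_pid: int) -> None:
        # Mirrors report_health(ServerStatus.Crashed, ...) followed by
        # parent_process.send_signal(signal.SIGQUIT) in the removed code.
        requests.post(
            f"http://{host}:{http_port}/health",
            json={"status": "Crashed", "msg": "scheduler hit an exception"},
        )
        os.kill(parent_pid, signal.SIGQUIT)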
@@ -116,7 +116,6 @@ from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
-    ServerStatus,
     dataclass_to_string_truncated,
     get_bool_env_var,
     get_zmq_socket,
@@ -174,9 +173,6 @@ class TokenizerManager:
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
-        # Server Status
-        self.server_status = ServerStatus.Starting
-
         # Parse args
         self.server_args = server_args
         self.enable_metrics = server_args.enable_metrics
@@ -255,6 +251,7 @@ class TokenizerManager:
         # Store states
         self.no_create_loop = False
         self.rid_to_state: Dict[str, ReqState] = {}
+        self.health_check_failed = False
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
         self.dump_requests_folder = ""  # By default do not dump
@@ -1335,7 +1332,7 @@ class TokenizerManager:
         while True:
             remain_num_req = len(self.rid_to_state)

-            if not self.server_status.is_healthy():
+            if self.health_check_failed:
                 # if health check failed, we should exit immediately
                 logger.error(
                     "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
@@ -93,22 +93,6 @@ time_infos = {}
 HIP_FP8_E4M3_FNUZ_MAX = 224.0


-class ServerStatus(Enum):
-    Up = "Up"
-    Starting = "Starting"
-    UnHealthy = "UnHealthy"
-    Crashed = "Crashed"
-
-    def is_healthy(self) -> bool:
-        return self == ServerStatus.Up
-
-
-def report_health(status: ServerStatus, host: str, http_port: int, msg: str = ""):
-    requests.post(
-        f"http://{host}:{http_port}/health", json={"status": status.value, "msg": msg}
-    )
-
-
 # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
 def is_hip() -> bool:
     return torch.version.hip is not None
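Taken together, the removed utilities implemented a small status machine; a sketch of the lifecycle the rest of this diff drives, with the enum copied from the removed code and transitions annotated by where they happened:

    from enum import Enum


    class ServerStatus(Enum):
        Up = "Up"
        Starting = "Starting"
        UnHealthy = "UnHealthy"
        Crashed = "Crashed"

        def is_healthy(self) -> bool:
            return self == ServerStatus.Up


    status = ServerStatus.Starting  # set in TokenizerManager.__init__
    status = ServerStatus.Up        # set once warmup succeeds
    assert status.is_healthy()
    status = ServerStatus.Crashed   # pushed via report_health() on failure
    assert not status.is_healthy()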