Revert "[Feature] Simple Improve Health Check Mechanism for Production-Grade Stability" (#8181)
This commit is contained in:
@@ -65,7 +65,6 @@ from sglang.srt.server_args import PortArgs, ServerArgs
|
|||||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
MultiprocessingSerializer,
|
MultiprocessingSerializer,
|
||||||
ServerStatus,
|
|
||||||
assert_pkg_version,
|
assert_pkg_version,
|
||||||
configure_logger,
|
configure_logger,
|
||||||
get_zmq_socket,
|
get_zmq_socket,
|
||||||
@@ -74,7 +73,6 @@ from sglang.srt.utils import (
|
|||||||
launch_dummy_health_check_server,
|
launch_dummy_health_check_server,
|
||||||
maybe_set_triton_cache_manager,
|
maybe_set_triton_cache_manager,
|
||||||
prepare_model_and_tokenizer,
|
prepare_model_and_tokenizer,
|
||||||
report_health,
|
|
||||||
set_prometheus_multiproc_dir,
|
set_prometheus_multiproc_dir,
|
||||||
set_ulimit,
|
set_ulimit,
|
||||||
)
|
)
|
||||||
@@ -663,7 +661,6 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
def sigchld_handler(signum, frame):
|
def sigchld_handler(signum, frame):
|
||||||
pid, exitcode = os.waitpid(0, os.WNOHANG)
|
pid, exitcode = os.waitpid(0, os.WNOHANG)
|
||||||
if exitcode != 0:
|
if exitcode != 0:
|
||||||
report_health(ServerStatus.Crashed, server_args.host, server_args.port)
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Child process unexpectedly failed with {exitcode=}. {pid=}"
|
f"Child process unexpectedly failed with {exitcode=}. {pid=}"
|
||||||
)
|
)
|
||||||
@@ -677,7 +674,6 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
logger.error(
|
logger.error(
|
||||||
"Received sigquit from a child process. It usually means the child failed."
|
"Received sigquit from a child process. It usually means the child failed."
|
||||||
)
|
)
|
||||||
report_health(ServerStatus.Crashed, server_args.host, server_args.port)
|
|
||||||
kill_process_tree(os.getpid())
|
kill_process_tree(os.getpid())
|
||||||
|
|
||||||
signal.signal(signal.SIGQUIT, sigquit_handler)
|
signal.signal(signal.SIGQUIT, sigquit_handler)
|
||||||
|
|||||||
@@ -77,7 +77,6 @@ from sglang.srt.managers.io_struct import (
|
|||||||
ParseFunctionCallReq,
|
ParseFunctionCallReq,
|
||||||
ProfileReqInput,
|
ProfileReqInput,
|
||||||
ReleaseMemoryOccupationReqInput,
|
ReleaseMemoryOccupationReqInput,
|
||||||
ReportHealthInput,
|
|
||||||
ResumeMemoryOccupationReqInput,
|
ResumeMemoryOccupationReqInput,
|
||||||
SeparateReasoningReqInput,
|
SeparateReasoningReqInput,
|
||||||
SetInternalStateReq,
|
SetInternalStateReq,
|
||||||
@@ -94,7 +93,6 @@ from sglang.srt.metrics.func_timer import enable_func_timer
|
|||||||
from sglang.srt.reasoning_parser import ReasoningParser
|
from sglang.srt.reasoning_parser import ReasoningParser
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
ServerStatus,
|
|
||||||
add_api_key_middleware,
|
add_api_key_middleware,
|
||||||
add_prometheus_middleware,
|
add_prometheus_middleware,
|
||||||
delete_directory,
|
delete_directory,
|
||||||
@@ -222,31 +220,8 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
|
|||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health() -> Response:
|
async def health() -> Response:
|
||||||
"""Check the status of the http server."""
|
"""Check the health of the http server."""
|
||||||
code = HTTPStatus.SERVICE_UNAVAILABLE.value
|
return Response(status_code=200)
|
||||||
if _global_state.tokenizer_manager.server_status == ServerStatus.Up:
|
|
||||||
code = HTTPStatus.OK.value
|
|
||||||
return Response(
|
|
||||||
status_code=code,
|
|
||||||
content=json.dumps(
|
|
||||||
{"status": _global_state.tokenizer_manager.server_status.value}
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/health")
|
|
||||||
async def health_update(obj: ReportHealthInput, request: Request) -> Response:
|
|
||||||
"""Update the Status of the http server."""
|
|
||||||
try:
|
|
||||||
server_status = ServerStatus(obj.status)
|
|
||||||
_global_state.tokenizer_manager.server_status = server_status
|
|
||||||
if server_status != ServerStatus.Up:
|
|
||||||
return Response(
|
|
||||||
status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, content=obj.msg
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e)
|
|
||||||
return Response(status_code=HTTPStatus.SERVICE_UNAVAILABLE.value)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health_generate")
|
@app.get("/health_generate")
|
||||||
@@ -281,7 +256,7 @@ async def health_generate(request: Request) -> Response:
|
|||||||
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
|
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
|
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
|
||||||
_global_state.tokenizer_manager.server_status = ServerStatus.Up
|
_global_state.tokenizer_manager.health_check_failed = False
|
||||||
return Response(status_code=200)
|
return Response(status_code=200)
|
||||||
|
|
||||||
task.cancel()
|
task.cancel()
|
||||||
@@ -295,7 +270,7 @@ async def health_generate(request: Request) -> Response:
|
|||||||
f"last_heartbeat time: {last_receive_time}"
|
f"last_heartbeat time: {last_receive_time}"
|
||||||
)
|
)
|
||||||
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
|
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
|
||||||
_global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
|
_global_state.tokenizer_manager.health_check_failed = True
|
||||||
return Response(status_code=503)
|
return Response(status_code=503)
|
||||||
|
|
||||||
|
|
||||||
@@ -1047,13 +1022,9 @@ def _execute_server_warmup(
|
|||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=600,
|
timeout=600,
|
||||||
)
|
)
|
||||||
if res.status_code == 200:
|
assert res.status_code == 200, f"{res}"
|
||||||
_global_state.tokenizer_manager.server_status = ServerStatus.Up
|
|
||||||
else:
|
|
||||||
_global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
|
|
||||||
logger.info(f"{res}")
|
|
||||||
else:
|
else:
|
||||||
logger.info(f"Start of prefill/decode warmup ...")
|
logger.info(f"Start of prefill warmup ...")
|
||||||
json_data = {
|
json_data = {
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"temperature": 0.0,
|
"temperature": 0.0,
|
||||||
@@ -1075,25 +1046,15 @@ def _execute_server_warmup(
|
|||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=1800, # because of deep gemm precache is very long if not precache.
|
timeout=1800, # because of deep gemm precache is very long if not precache.
|
||||||
)
|
)
|
||||||
if res.status_code == 200:
|
logger.info(
|
||||||
logger.info(
|
f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
|
||||||
f"End of prefill disaggregation mode warmup with status {res.status_code}, resp: {res.json()}"
|
)
|
||||||
)
|
|
||||||
_global_state.tokenizer_manager.server_status = ServerStatus.Up
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
"Prefill disaggregation mode warm Up Failed, status code: {}".format(
|
|
||||||
res.status_code
|
|
||||||
)
|
|
||||||
)
|
|
||||||
_global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
|
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
last_traceback = get_exception_traceback()
|
last_traceback = get_exception_traceback()
|
||||||
if pipe_finish_writer is not None:
|
if pipe_finish_writer is not None:
|
||||||
pipe_finish_writer.send(last_traceback)
|
pipe_finish_writer.send(last_traceback)
|
||||||
logger.error(f"Initialization failed. warmup error: {last_traceback}")
|
logger.error(f"Initialization failed. warmup error: {last_traceback}")
|
||||||
_global_state.tokenizer_manager.server_status = ServerStatus.Crashed
|
|
||||||
kill_process_tree(os.getpid())
|
kill_process_tree(os.getpid())
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -1083,9 +1083,3 @@ class LoRAUpdateResult:
|
|||||||
|
|
||||||
|
|
||||||
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
|
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ReportHealthInput:
|
|
||||||
status: str
|
|
||||||
msg: Optional[str] = ""
|
|
||||||
|
|||||||
@@ -143,7 +143,6 @@ from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
|
|||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
DeepEPMode,
|
DeepEPMode,
|
||||||
DynamicGradMode,
|
DynamicGradMode,
|
||||||
ServerStatus,
|
|
||||||
broadcast_pyobj,
|
broadcast_pyobj,
|
||||||
configure_gc_logger,
|
configure_gc_logger,
|
||||||
configure_logger,
|
configure_logger,
|
||||||
@@ -155,7 +154,6 @@ from sglang.srt.utils import (
|
|||||||
kill_itself_when_parent_died,
|
kill_itself_when_parent_died,
|
||||||
point_to_point_pyobj,
|
point_to_point_pyobj,
|
||||||
pyspy_dump_schedulers,
|
pyspy_dump_schedulers,
|
||||||
report_health,
|
|
||||||
require_mlp_sync,
|
require_mlp_sync,
|
||||||
require_mlp_tp_gather,
|
require_mlp_tp_gather,
|
||||||
set_gpu_proc_affinity,
|
set_gpu_proc_affinity,
|
||||||
@@ -2966,5 +2964,4 @@ def run_scheduler_process(
|
|||||||
except Exception:
|
except Exception:
|
||||||
traceback = get_exception_traceback()
|
traceback = get_exception_traceback()
|
||||||
logger.error(f"Scheduler hit an exception: {traceback}")
|
logger.error(f"Scheduler hit an exception: {traceback}")
|
||||||
report_health(ServerStatus.Crashed, server_args.host, ServerArgs.port)
|
|
||||||
parent_process.send_signal(signal.SIGQUIT)
|
parent_process.send_signal(signal.SIGQUIT)
|
||||||
|
|||||||
@@ -116,7 +116,6 @@ from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
|||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
ServerStatus,
|
|
||||||
dataclass_to_string_truncated,
|
dataclass_to_string_truncated,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
get_zmq_socket,
|
get_zmq_socket,
|
||||||
@@ -174,9 +173,6 @@ class TokenizerManager:
|
|||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
port_args: PortArgs,
|
port_args: PortArgs,
|
||||||
):
|
):
|
||||||
# Server Status
|
|
||||||
self.server_status = ServerStatus.Starting
|
|
||||||
|
|
||||||
# Parse args
|
# Parse args
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
self.enable_metrics = server_args.enable_metrics
|
self.enable_metrics = server_args.enable_metrics
|
||||||
@@ -255,6 +251,7 @@ class TokenizerManager:
|
|||||||
# Store states
|
# Store states
|
||||||
self.no_create_loop = False
|
self.no_create_loop = False
|
||||||
self.rid_to_state: Dict[str, ReqState] = {}
|
self.rid_to_state: Dict[str, ReqState] = {}
|
||||||
|
self.health_check_failed = False
|
||||||
self.gracefully_exit = False
|
self.gracefully_exit = False
|
||||||
self.last_receive_tstamp = 0
|
self.last_receive_tstamp = 0
|
||||||
self.dump_requests_folder = "" # By default do not dump
|
self.dump_requests_folder = "" # By default do not dump
|
||||||
@@ -1335,7 +1332,7 @@ class TokenizerManager:
|
|||||||
while True:
|
while True:
|
||||||
remain_num_req = len(self.rid_to_state)
|
remain_num_req = len(self.rid_to_state)
|
||||||
|
|
||||||
if not self.server_status.is_healthy():
|
if self.health_check_failed:
|
||||||
# if health check failed, we should exit immediately
|
# if health check failed, we should exit immediately
|
||||||
logger.error(
|
logger.error(
|
||||||
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
|
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
|
||||||
|
|||||||
@@ -93,22 +93,6 @@ time_infos = {}
|
|||||||
HIP_FP8_E4M3_FNUZ_MAX = 224.0
|
HIP_FP8_E4M3_FNUZ_MAX = 224.0
|
||||||
|
|
||||||
|
|
||||||
class ServerStatus(Enum):
|
|
||||||
Up = "Up"
|
|
||||||
Starting = "Starting"
|
|
||||||
UnHealthy = "UnHealthy"
|
|
||||||
Crashed = "Crashed"
|
|
||||||
|
|
||||||
def is_healthy(self) -> bool:
|
|
||||||
return self == ServerStatus.Up
|
|
||||||
|
|
||||||
|
|
||||||
def report_health(status: ServerStatus, host: str, http_port: int, msg: str = ""):
|
|
||||||
requests.post(
|
|
||||||
f"http://{host}:{http_port}/health", json={"status": status.value, "msg": msg}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
|
# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
|
||||||
def is_hip() -> bool:
|
def is_hip() -> bool:
|
||||||
return torch.version.hip is not None
|
return torch.version.hip is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user