Simplify health check (#9034)
This commit is contained in:
@@ -26,7 +26,7 @@ import os
|
||||
import threading
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
from typing import AsyncIterator, Callable, Dict, Optional
|
||||
from typing import Any, AsyncIterator, Callable, Dict, List, Optional
|
||||
|
||||
# Fix a bug of Python threading
|
||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||
@@ -277,7 +277,7 @@ async def health_generate(request: Request) -> Response:
|
||||
logger.info("Health check request received during shutdown. Returning 503.")
|
||||
return Response(status_code=503)
|
||||
|
||||
if not _global_state.tokenizer_manager.server_status.is_healthy():
|
||||
if _global_state.tokenizer_manager.server_status == ServerStatus.Starting:
|
||||
return Response(status_code=503)
|
||||
|
||||
sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
|
||||
@@ -317,7 +317,7 @@ async def health_generate(request: Request) -> Response:
|
||||
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
|
||||
task.cancel()
|
||||
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
|
||||
_global_state.tokenizer_manager.health_check_failed = False
|
||||
_global_state.tokenizer_manager.server_status = ServerStatus.Up
|
||||
return Response(status_code=200)
|
||||
|
||||
task.cancel()
|
||||
@@ -331,7 +331,7 @@ async def health_generate(request: Request) -> Response:
|
||||
f"last_heartbeat time: {last_receive_time}"
|
||||
)
|
||||
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
|
||||
_global_state.tokenizer_manager.health_check_failed = True
|
||||
_global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
|
||||
return Response(status_code=503)
|
||||
|
||||
|
||||
|
||||
@@ -99,25 +99,24 @@ class GenerateReqInput:
|
||||
stream: bool = False
|
||||
# Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
|
||||
log_metrics: bool = True
|
||||
# Whether to return hidden states
|
||||
return_hidden_states: Union[List[bool], bool] = False
|
||||
|
||||
# The modalities of the image data [image, multi-images, video]
|
||||
modalities: Optional[List[str]] = None
|
||||
# Session info for continual prompting
|
||||
session_params: Optional[Union[List[Dict], Dict]] = None
|
||||
|
||||
# The path to the LoRA adaptors
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
# The uid of LoRA adaptors, should be initialized by tokenizer manager
|
||||
lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
|
||||
# Session info for continual prompting
|
||||
session_params: Optional[Union[List[Dict], Dict]] = None
|
||||
|
||||
# Custom logit processor for advanced sampling control. Must be a serialized instance
|
||||
# of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
|
||||
# Use the processor's `to_str()` method to generate the serialized string.
|
||||
custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
|
||||
|
||||
# Whether to return hidden states
|
||||
return_hidden_states: Union[List[bool], bool] = False
|
||||
|
||||
# For disaggregated inference
|
||||
bootstrap_host: Optional[Union[List[str], str]] = None
|
||||
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
|
||||
|
||||
@@ -269,10 +269,9 @@ class TokenizerManager:
|
||||
self.asyncio_tasks = set()
|
||||
|
||||
# Health check
|
||||
self.health_check_failed = False
|
||||
self.server_status = ServerStatus.Starting
|
||||
self.gracefully_exit = False
|
||||
self.last_receive_tstamp = 0
|
||||
self.server_status = ServerStatus.Starting
|
||||
|
||||
# Dumping
|
||||
self.dump_requests_folder = "" # By default do not dump
|
||||
@@ -291,8 +290,8 @@ class TokenizerManager:
|
||||
self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
|
||||
None
|
||||
)
|
||||
self._is_updating = False
|
||||
self._is_updating_cond = asyncio.Condition()
|
||||
self.is_pause = False
|
||||
self.is_pause_cond = asyncio.Condition()
|
||||
|
||||
# LoRA
|
||||
# Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
|
||||
@@ -476,15 +475,15 @@ class TokenizerManager:
|
||||
self.auto_create_handle_loop()
|
||||
obj.normalize_batch_and_arguments()
|
||||
|
||||
async with self._is_updating_cond:
|
||||
await self._is_updating_cond.wait_for(lambda: not self._is_updating)
|
||||
|
||||
if self.log_requests:
|
||||
max_length, skip_names, _ = self.log_request_metadata
|
||||
logger.info(
|
||||
f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
|
||||
)
|
||||
|
||||
async with self.is_pause_cond:
|
||||
await self.is_pause_cond.wait_for(lambda: not self.is_pause)
|
||||
|
||||
async with self.model_update_lock.reader_lock:
|
||||
if obj.is_single:
|
||||
tokenized_obj = await self._tokenize_one_request(obj)
|
||||
@@ -982,14 +981,14 @@ class TokenizerManager:
|
||||
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
|
||||
|
||||
async def pause_generation(self):
|
||||
async with self._is_updating_cond:
|
||||
self._is_updating = True
|
||||
async with self.is_pause_cond:
|
||||
self.is_pause = True
|
||||
self.abort_request(abort_all=True)
|
||||
|
||||
async def continue_generation(self):
|
||||
async with self._is_updating_cond:
|
||||
self._is_updating = False
|
||||
self._is_updating_cond.notify_all()
|
||||
async with self.is_pause_cond:
|
||||
self.is_pause = False
|
||||
self.is_pause_cond.notify_all()
|
||||
|
||||
async def update_weights_from_disk(
|
||||
self,
|
||||
@@ -1474,7 +1473,7 @@ class TokenizerManager:
|
||||
while True:
|
||||
remain_num_req = len(self.rid_to_state)
|
||||
|
||||
if self.health_check_failed:
|
||||
if self.server_status == ServerStatus.UnHealthy:
|
||||
# if health check failed, we should exit immediately
|
||||
logger.error(
|
||||
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
|
||||
@@ -1965,10 +1964,6 @@ class ServerStatus(Enum):
|
||||
Up = "Up"
|
||||
Starting = "Starting"
|
||||
UnHealthy = "UnHealthy"
|
||||
Crashed = "Crashed"
|
||||
|
||||
def is_healthy(self) -> bool:
|
||||
return self == ServerStatus.Up
|
||||
|
||||
|
||||
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
|
||||
|
||||
Reference in New Issue
Block a user