Simplify health check (#9034)
This commit is contained in:
@@ -26,7 +26,7 @@ import os
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import AsyncIterator, Callable, Dict, Optional
|
from typing import Any, AsyncIterator, Callable, Dict, List, Optional
|
||||||
|
|
||||||
# Fix a bug of Python threading
|
# Fix a bug of Python threading
|
||||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||||
@@ -277,7 +277,7 @@ async def health_generate(request: Request) -> Response:
|
|||||||
logger.info("Health check request received during shutdown. Returning 503.")
|
logger.info("Health check request received during shutdown. Returning 503.")
|
||||||
return Response(status_code=503)
|
return Response(status_code=503)
|
||||||
|
|
||||||
if not _global_state.tokenizer_manager.server_status.is_healthy():
|
if _global_state.tokenizer_manager.server_status == ServerStatus.Starting:
|
||||||
return Response(status_code=503)
|
return Response(status_code=503)
|
||||||
|
|
||||||
sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
|
sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
|
||||||
@@ -317,7 +317,7 @@ async def health_generate(request: Request) -> Response:
|
|||||||
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
|
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
|
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
|
||||||
_global_state.tokenizer_manager.health_check_failed = False
|
_global_state.tokenizer_manager.server_status = ServerStatus.Up
|
||||||
return Response(status_code=200)
|
return Response(status_code=200)
|
||||||
|
|
||||||
task.cancel()
|
task.cancel()
|
||||||
@@ -331,7 +331,7 @@ async def health_generate(request: Request) -> Response:
|
|||||||
f"last_heartbeat time: {last_receive_time}"
|
f"last_heartbeat time: {last_receive_time}"
|
||||||
)
|
)
|
||||||
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
|
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
|
||||||
_global_state.tokenizer_manager.health_check_failed = True
|
_global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
|
||||||
return Response(status_code=503)
|
return Response(status_code=503)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -99,25 +99,24 @@ class GenerateReqInput:
|
|||||||
stream: bool = False
|
stream: bool = False
|
||||||
# Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
|
# Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
|
||||||
log_metrics: bool = True
|
log_metrics: bool = True
|
||||||
|
# Whether to return hidden states
|
||||||
|
return_hidden_states: Union[List[bool], bool] = False
|
||||||
|
|
||||||
# The modalities of the image data [image, multi-images, video]
|
# The modalities of the image data [image, multi-images, video]
|
||||||
modalities: Optional[List[str]] = None
|
modalities: Optional[List[str]] = None
|
||||||
|
# Session info for continual prompting
|
||||||
|
session_params: Optional[Union[List[Dict], Dict]] = None
|
||||||
|
|
||||||
# The path to the LoRA adaptors
|
# The path to the LoRA adaptors
|
||||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||||
# The uid of LoRA adaptors, should be initialized by tokenizer manager
|
# The uid of LoRA adaptors, should be initialized by tokenizer manager
|
||||||
lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||||
|
|
||||||
# Session info for continual prompting
|
|
||||||
session_params: Optional[Union[List[Dict], Dict]] = None
|
|
||||||
|
|
||||||
# Custom logit processor for advanced sampling control. Must be a serialized instance
|
# Custom logit processor for advanced sampling control. Must be a serialized instance
|
||||||
# of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
|
# of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
|
||||||
# Use the processor's `to_str()` method to generate the serialized string.
|
# Use the processor's `to_str()` method to generate the serialized string.
|
||||||
custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
|
custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
|
||||||
|
|
||||||
# Whether to return hidden states
|
|
||||||
return_hidden_states: Union[List[bool], bool] = False
|
|
||||||
|
|
||||||
# For disaggregated inference
|
# For disaggregated inference
|
||||||
bootstrap_host: Optional[Union[List[str], str]] = None
|
bootstrap_host: Optional[Union[List[str], str]] = None
|
||||||
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
|
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
|
||||||
|
|||||||
@@ -269,10 +269,9 @@ class TokenizerManager:
|
|||||||
self.asyncio_tasks = set()
|
self.asyncio_tasks = set()
|
||||||
|
|
||||||
# Health check
|
# Health check
|
||||||
self.health_check_failed = False
|
self.server_status = ServerStatus.Starting
|
||||||
self.gracefully_exit = False
|
self.gracefully_exit = False
|
||||||
self.last_receive_tstamp = 0
|
self.last_receive_tstamp = 0
|
||||||
self.server_status = ServerStatus.Starting
|
|
||||||
|
|
||||||
# Dumping
|
# Dumping
|
||||||
self.dump_requests_folder = "" # By default do not dump
|
self.dump_requests_folder = "" # By default do not dump
|
||||||
@@ -291,8 +290,8 @@ class TokenizerManager:
|
|||||||
self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
|
self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
|
||||||
None
|
None
|
||||||
)
|
)
|
||||||
self._is_updating = False
|
self.is_pause = False
|
||||||
self._is_updating_cond = asyncio.Condition()
|
self.is_pause_cond = asyncio.Condition()
|
||||||
|
|
||||||
# LoRA
|
# LoRA
|
||||||
# Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
|
# Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
|
||||||
@@ -476,15 +475,15 @@ class TokenizerManager:
|
|||||||
self.auto_create_handle_loop()
|
self.auto_create_handle_loop()
|
||||||
obj.normalize_batch_and_arguments()
|
obj.normalize_batch_and_arguments()
|
||||||
|
|
||||||
async with self._is_updating_cond:
|
|
||||||
await self._is_updating_cond.wait_for(lambda: not self._is_updating)
|
|
||||||
|
|
||||||
if self.log_requests:
|
if self.log_requests:
|
||||||
max_length, skip_names, _ = self.log_request_metadata
|
max_length, skip_names, _ = self.log_request_metadata
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
|
f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async with self.is_pause_cond:
|
||||||
|
await self.is_pause_cond.wait_for(lambda: not self.is_pause)
|
||||||
|
|
||||||
async with self.model_update_lock.reader_lock:
|
async with self.model_update_lock.reader_lock:
|
||||||
if obj.is_single:
|
if obj.is_single:
|
||||||
tokenized_obj = await self._tokenize_one_request(obj)
|
tokenized_obj = await self._tokenize_one_request(obj)
|
||||||
@@ -982,14 +981,14 @@ class TokenizerManager:
|
|||||||
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
|
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
|
||||||
|
|
||||||
async def pause_generation(self):
|
async def pause_generation(self):
|
||||||
async with self._is_updating_cond:
|
async with self.is_pause_cond:
|
||||||
self._is_updating = True
|
self.is_pause = True
|
||||||
self.abort_request(abort_all=True)
|
self.abort_request(abort_all=True)
|
||||||
|
|
||||||
async def continue_generation(self):
|
async def continue_generation(self):
|
||||||
async with self._is_updating_cond:
|
async with self.is_pause_cond:
|
||||||
self._is_updating = False
|
self.is_pause = False
|
||||||
self._is_updating_cond.notify_all()
|
self.is_pause_cond.notify_all()
|
||||||
|
|
||||||
async def update_weights_from_disk(
|
async def update_weights_from_disk(
|
||||||
self,
|
self,
|
||||||
@@ -1474,7 +1473,7 @@ class TokenizerManager:
|
|||||||
while True:
|
while True:
|
||||||
remain_num_req = len(self.rid_to_state)
|
remain_num_req = len(self.rid_to_state)
|
||||||
|
|
||||||
if self.health_check_failed:
|
if self.server_status == ServerStatus.UnHealthy:
|
||||||
# if health check failed, we should exit immediately
|
# if health check failed, we should exit immediately
|
||||||
logger.error(
|
logger.error(
|
||||||
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
|
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
|
||||||
@@ -1965,10 +1964,6 @@ class ServerStatus(Enum):
|
|||||||
Up = "Up"
|
Up = "Up"
|
||||||
Starting = "Starting"
|
Starting = "Starting"
|
||||||
UnHealthy = "UnHealthy"
|
UnHealthy = "UnHealthy"
|
||||||
Crashed = "Crashed"
|
|
||||||
|
|
||||||
def is_healthy(self) -> bool:
|
|
||||||
return self == ServerStatus.Up
|
|
||||||
|
|
||||||
|
|
||||||
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
|
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
|
||||||
|
|||||||
Reference in New Issue
Block a user