Simplify health check (#9034)
This commit is contained in:
@@ -269,10 +269,9 @@ class TokenizerManager:
|
||||
self.asyncio_tasks = set()
|
||||
|
||||
# Health check
|
||||
self.health_check_failed = False
|
||||
self.server_status = ServerStatus.Starting
|
||||
self.gracefully_exit = False
|
||||
self.last_receive_tstamp = 0
|
||||
self.server_status = ServerStatus.Starting
|
||||
|
||||
# Dumping
|
||||
self.dump_requests_folder = "" # By default do not dump
|
||||
@@ -291,8 +290,8 @@ class TokenizerManager:
|
||||
self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
|
||||
None
|
||||
)
|
||||
self._is_updating = False
|
||||
self._is_updating_cond = asyncio.Condition()
|
||||
self.is_pause = False
|
||||
self.is_pause_cond = asyncio.Condition()
|
||||
|
||||
# LoRA
|
||||
# Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
|
||||
@@ -476,15 +475,15 @@ class TokenizerManager:
|
||||
self.auto_create_handle_loop()
|
||||
obj.normalize_batch_and_arguments()
|
||||
|
||||
async with self._is_updating_cond:
|
||||
await self._is_updating_cond.wait_for(lambda: not self._is_updating)
|
||||
|
||||
if self.log_requests:
|
||||
max_length, skip_names, _ = self.log_request_metadata
|
||||
logger.info(
|
||||
f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
|
||||
)
|
||||
|
||||
async with self.is_pause_cond:
|
||||
await self.is_pause_cond.wait_for(lambda: not self.is_pause)
|
||||
|
||||
async with self.model_update_lock.reader_lock:
|
||||
if obj.is_single:
|
||||
tokenized_obj = await self._tokenize_one_request(obj)
|
||||
@@ -982,14 +981,14 @@ class TokenizerManager:
|
||||
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
|
||||
|
||||
async def pause_generation(self):
|
||||
async with self._is_updating_cond:
|
||||
self._is_updating = True
|
||||
async with self.is_pause_cond:
|
||||
self.is_pause = True
|
||||
self.abort_request(abort_all=True)
|
||||
|
||||
async def continue_generation(self):
|
||||
async with self._is_updating_cond:
|
||||
self._is_updating = False
|
||||
self._is_updating_cond.notify_all()
|
||||
async with self.is_pause_cond:
|
||||
self.is_pause = False
|
||||
self.is_pause_cond.notify_all()
|
||||
|
||||
async def update_weights_from_disk(
|
||||
self,
|
||||
@@ -1474,7 +1473,7 @@ class TokenizerManager:
|
||||
while True:
|
||||
remain_num_req = len(self.rid_to_state)
|
||||
|
||||
if self.health_check_failed:
|
||||
if self.server_status == ServerStatus.UnHealthy:
|
||||
# if health check failed, we should exit immediately
|
||||
logger.error(
|
||||
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
|
||||
@@ -1965,10 +1964,6 @@ class ServerStatus(Enum):
|
||||
Up = "Up"
|
||||
Starting = "Starting"
|
||||
UnHealthy = "UnHealthy"
|
||||
Crashed = "Crashed"
|
||||
|
||||
def is_healthy(self) -> bool:
|
||||
return self == ServerStatus.Up
|
||||
|
||||
|
||||
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
|
||||
|
||||
Reference in New Issue
Block a user