Simplify health check (#9034)

This commit is contained in:
Lianmin Zheng
2025-08-10 17:35:05 -07:00
committed by GitHub
parent dd949ace23
commit 4ea9d74a3e
3 changed files with 21 additions and 27 deletions

View File

@@ -269,10 +269,9 @@ class TokenizerManager:
self.asyncio_tasks = set()
# Health check
self.health_check_failed = False
self.server_status = ServerStatus.Starting
self.gracefully_exit = False
self.last_receive_tstamp = 0
self.server_status = ServerStatus.Starting
# Dumping
self.dump_requests_folder = "" # By default do not dump
@@ -291,8 +290,8 @@ class TokenizerManager:
self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
None
)
self._is_updating = False
self._is_updating_cond = asyncio.Condition()
self.is_pause = False
self.is_pause_cond = asyncio.Condition()
# LoRA
# Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
@@ -476,15 +475,15 @@ class TokenizerManager:
self.auto_create_handle_loop()
obj.normalize_batch_and_arguments()
async with self._is_updating_cond:
await self._is_updating_cond.wait_for(lambda: not self._is_updating)
if self.log_requests:
max_length, skip_names, _ = self.log_request_metadata
logger.info(
f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
)
async with self.is_pause_cond:
await self.is_pause_cond.wait_for(lambda: not self.is_pause)
async with self.model_update_lock.reader_lock:
if obj.is_single:
tokenized_obj = await self._tokenize_one_request(obj)
@@ -982,14 +981,14 @@ class TokenizerManager:
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
async def pause_generation(self):
async with self._is_updating_cond:
self._is_updating = True
async with self.is_pause_cond:
self.is_pause = True
self.abort_request(abort_all=True)
async def continue_generation(self):
async with self._is_updating_cond:
self._is_updating = False
self._is_updating_cond.notify_all()
async with self.is_pause_cond:
self.is_pause = False
self.is_pause_cond.notify_all()
async def update_weights_from_disk(
self,
@@ -1474,7 +1473,7 @@ class TokenizerManager:
while True:
remain_num_req = len(self.rid_to_state)
if self.health_check_failed:
if self.server_status == ServerStatus.UnHealthy:
# if health check failed, we should exit immediately
logger.error(
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
@@ -1965,10 +1964,6 @@ class ServerStatus(Enum):
Up = "Up"
Starting = "Starting"
UnHealthy = "UnHealthy"
Crashed = "Crashed"
def is_healthy(self) -> bool:
return self == ServerStatus.Up
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode: