Reviewer notes (free text ahead of the first "diff --git" header; GNU patch
documents that it skips leading text like this -- confirm your apply tool
does as well, or strip these lines first).

What the hunks below do:
  * http_server.py: /health_generate returns 503 early only while the server
    status is Starting (previously: whenever is_healthy() was false), and it
    records the probe outcome in tokenizer_manager.server_status (Up on
    success, UnHealthy on timeout) instead of the health_check_failed flag.
  * io_struct.py: moves return_hidden_states and session_params earlier in
    GenerateReqInput; no field is added or removed.
  * tokenizer_manager.py: drops health_check_failed in favour of
    server_status, renames _is_updating/_is_updating_cond to
    is_pause/is_pause_cond, moves the pause wait after request logging, and
    removes ServerStatus.Crashed and ServerStatus.is_healthy().

NOTE(review): no hunk shown here imports ServerStatus into http_server.py --
  confirm it is already imported there.
NOTE(review): the typing import gains Any and List, but no visible hunk uses
  them in http_server.py -- confirm they are needed.
NOTE(review): ServerStatus.Crashed and is_healthy() are deleted -- verify no
  caller outside these hunks still references them.
NOTE(review): indentation and blank context lines were reconstructed from a
  whitespace-collapsed copy so that every hunk's line counts match its
  "@@" header; re-check against the target files before applying.

diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index a6bcb0b5b..ce9936288 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -26,7 +26,7 @@ import os
 import threading
 import time
 from http import HTTPStatus
-from typing import AsyncIterator, Callable, Dict, Optional
+from typing import Any, AsyncIterator, Callable, Dict, List, Optional
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -277,7 +277,7 @@ async def health_generate(request: Request) -> Response:
         logger.info("Health check request received during shutdown. Returning 503.")
         return Response(status_code=503)
 
-    if not _global_state.tokenizer_manager.server_status.is_healthy():
+    if _global_state.tokenizer_manager.server_status == ServerStatus.Starting:
         return Response(status_code=503)
 
     sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
@@ -317,7 +317,7 @@ async def health_generate(request: Request) -> Response:
         if _global_state.tokenizer_manager.last_receive_tstamp > tic:
             task.cancel()
             _global_state.tokenizer_manager.rid_to_state.pop(rid, None)
-            _global_state.tokenizer_manager.health_check_failed = False
+            _global_state.tokenizer_manager.server_status = ServerStatus.Up
             return Response(status_code=200)
 
     task.cancel()
@@ -331,7 +331,7 @@ async def health_generate(request: Request) -> Response:
         f"last_heartbeat time: {last_receive_time}"
     )
     _global_state.tokenizer_manager.rid_to_state.pop(rid, None)
-    _global_state.tokenizer_manager.health_check_failed = True
+    _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
     return Response(status_code=503)
 
 
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 546128212..314339a8b 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -99,25 +99,24 @@ class GenerateReqInput:
     stream: bool = False
     # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
     log_metrics: bool = True
+    # Whether to return hidden states
+    return_hidden_states: Union[List[bool], bool] = False
 
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
+    # Session info for continual prompting
+    session_params: Optional[Union[List[Dict], Dict]] = None
+
     # The path to the LoRA adaptors
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
     # The uid of LoRA adaptors, should be initialized by tokenizer manager
     lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
-    # Session info for continual prompting
-    session_params: Optional[Union[List[Dict], Dict]] = None
-
     # Custom logit processor for advanced sampling control. Must be a serialized instance
     # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
     # Use the processor's `to_str()` method to generate the serialized string.
     custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
 
-    # Whether to return hidden states
-    return_hidden_states: Union[List[bool], bool] = False
-
     # For disaggregated inference
     bootstrap_host: Optional[Union[List[str], str]] = None
     bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 50ac39f88..a1a81a87f 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -269,10 +269,9 @@ class TokenizerManager:
         self.asyncio_tasks = set()
 
         # Health check
-        self.health_check_failed = False
+        self.server_status = ServerStatus.Starting
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
-        self.server_status = ServerStatus.Starting
 
         # Dumping
         self.dump_requests_folder = ""  # By default do not dump
@@ -291,8 +290,8 @@ class TokenizerManager:
         self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
             None
         )
-        self._is_updating = False
-        self._is_updating_cond = asyncio.Condition()
+        self.is_pause = False
+        self.is_pause_cond = asyncio.Condition()
 
         # LoRA
         # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
@@ -476,15 +475,15 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()
 
-        async with self._is_updating_cond:
-            await self._is_updating_cond.wait_for(lambda: not self._is_updating)
-
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
                 f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
             )
 
+        async with self.is_pause_cond:
+            await self.is_pause_cond.wait_for(lambda: not self.is_pause)
+
         async with self.model_update_lock.reader_lock:
             if obj.is_single:
                 tokenized_obj = await self._tokenize_one_request(obj)
@@ -982,14 +981,14 @@ class TokenizerManager:
         await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
 
     async def pause_generation(self):
-        async with self._is_updating_cond:
-            self._is_updating = True
+        async with self.is_pause_cond:
+            self.is_pause = True
             self.abort_request(abort_all=True)
 
     async def continue_generation(self):
-        async with self._is_updating_cond:
-            self._is_updating = False
-            self._is_updating_cond.notify_all()
+        async with self.is_pause_cond:
+            self.is_pause = False
+            self.is_pause_cond.notify_all()
 
     async def update_weights_from_disk(
         self,
@@ -1474,7 +1473,7 @@ class TokenizerManager:
         while True:
             remain_num_req = len(self.rid_to_state)
 
-            if self.health_check_failed:
+            if self.server_status == ServerStatus.UnHealthy:
                 # if health check failed, we should exit immediately
                 logger.error(
                     "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
@@ -1965,10 +1964,6 @@ class ServerStatus(Enum):
     Up = "Up"
     Starting = "Starting"
     UnHealthy = "UnHealthy"
-    Crashed = "Crashed"
-
-    def is_healthy(self) -> bool:
-        return self == ServerStatus.Up
 
 
 def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode: