diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 9a3e90969..9b990f11c 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -20,6 +20,8 @@ import dataclasses
 import json
 import logging
 import os
+import signal
+import sys
 from typing import Dict, List, Optional, Tuple, Union
 
 import fastapi
@@ -58,7 +60,12 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import get_zmq_socket, is_generation_model, is_multimodal_model
+from sglang.srt.utils import (
+    get_zmq_socket,
+    is_generation_model,
+    is_multimodal_model,
+    kill_child_process,
+)
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -142,6 +149,9 @@ class TokenizerManager:
         self.model_update_lock = asyncio.Lock()
         self.model_update_result = None
 
+        # Others
+        self.gracefully_exit = False
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
@@ -629,6 +639,37 @@ class TokenizerManager:
         loop = asyncio.get_event_loop()
         loop.create_task(self.handle_loop())
 
+        signal_handler = SignalHandler(self)
+        loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
+        loop.create_task(self.sigterm_watchdog())
+
+    async def sigterm_watchdog(self):
+        """Wait for SIGTERM, drain in-flight requests, then shut down.
+
+        Polls ``self.gracefully_exit`` (set by ``SignalHandler``), waits until
+        ``self.rid_to_state`` is empty, then calls
+        ``kill_child_process(include_self=True)`` and exits with status 0.
+        """
+        # Poll at the same short interval as the drain loop so draining starts
+        # promptly after SIGTERM instead of up to a minute later.
+        while not self.gracefully_exit:
+            await asyncio.sleep(5)
+
+        # drain requests
+        while True:
+            remain_num_req = len(self.rid_to_state)
+            logger.info(
+                f"gracefully exiting... remaining number of requests {remain_num_req}"
+            )
+            if remain_num_req > 0:
+                await asyncio.sleep(5)
+            else:
+                break
+
+        kill_child_process(include_self=True)
+        # A completed drain is a clean shutdown, so report success.
+        sys.exit(0)
+
     async def handle_loop(self):
         """The event loop that handles requests"""
 
@@ -740,3 +781,21 @@ class TokenizerManager:
                     token_top_logprobs, decode_to_text
                 )
         return top_logprobs
+
+
+class SignalHandler:
+    """Flag a TokenizerManager for graceful shutdown when a signal arrives."""
+
+    def __init__(self, tokenizer_manager):
+        self.tokenizer_manager = tokenizer_manager
+
+    def signal_handler(self, signum=None, frame=None):
+        """Log the signal and set ``gracefully_exit`` on the manager.
+
+        ``loop.add_signal_handler`` is given no extra arguments for this
+        callback, so ``signum`` and ``frame`` keep their ``None`` defaults.
+        """
+        logger.warning(
+            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
+        )
+        self.tokenizer_manager.gracefully_exit = True