[Production] Drain requests before exit when receive SIGTERM (#1838)
This commit is contained in:
@@ -20,6 +20,8 @@ import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import fastapi
|
||||
@@ -58,7 +60,12 @@ from sglang.srt.managers.io_struct import (
|
||||
)
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import get_zmq_socket, is_generation_model, is_multimodal_model
|
||||
from sglang.srt.utils import (
|
||||
get_zmq_socket,
|
||||
is_generation_model,
|
||||
is_multimodal_model,
|
||||
kill_child_process,
|
||||
)
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
@@ -142,6 +149,9 @@ class TokenizerManager:
|
||||
self.model_update_lock = asyncio.Lock()
|
||||
self.model_update_result = None
|
||||
|
||||
# Others
|
||||
self.gracefully_exit = False
|
||||
|
||||
async def generate_request(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
|
||||
@@ -629,6 +639,28 @@ class TokenizerManager:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.create_task(self.handle_loop())
|
||||
|
||||
signal_handler = SignalHandler(self)
|
||||
loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
|
||||
loop.create_task(self.sigterm_watchdog())
|
||||
|
||||
async def sigterm_watchdog(self):
|
||||
while not self.gracefully_exit:
|
||||
await asyncio.sleep(60)
|
||||
|
||||
# drain requests
|
||||
while True:
|
||||
remain_num_req = len(self.rid_to_state)
|
||||
logger.info(
|
||||
f"gracefully exiting... remaining number of requests {remain_num_req}"
|
||||
)
|
||||
if remain_num_req > 0:
|
||||
await asyncio.sleep(5)
|
||||
else:
|
||||
break
|
||||
|
||||
kill_child_process(include_self=True)
|
||||
sys.exit(-1)
|
||||
|
||||
async def handle_loop(self):
|
||||
"""The event loop that handles requests"""
|
||||
|
||||
@@ -740,3 +772,14 @@ class TokenizerManager:
|
||||
token_top_logprobs, decode_to_text
|
||||
)
|
||||
return top_logprobs
|
||||
|
||||
|
||||
class SignalHandler:
|
||||
def __init__(self, tokenizer_manager):
|
||||
self.tokenizer_manager = tokenizer_manager
|
||||
|
||||
def signal_handler(self, signum=None, frame=None):
|
||||
logger.warning(
|
||||
f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
|
||||
)
|
||||
self.tokenizer_manager.gracefully_exit = True
|
||||
|
||||
Reference in New Issue
Block a user