[Production] Drain requests before exit when receive SIGTERM (#1838)
This commit is contained in:
@@ -20,6 +20,8 @@ import dataclasses
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
@@ -58,7 +60,12 @@ from sglang.srt.managers.io_struct import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import get_zmq_socket, is_generation_model, is_multimodal_model
|
from sglang.srt.utils import (
|
||||||
|
get_zmq_socket,
|
||||||
|
is_generation_model,
|
||||||
|
is_multimodal_model,
|
||||||
|
kill_child_process,
|
||||||
|
)
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
@@ -142,6 +149,9 @@ class TokenizerManager:
|
|||||||
self.model_update_lock = asyncio.Lock()
|
self.model_update_lock = asyncio.Lock()
|
||||||
self.model_update_result = None
|
self.model_update_result = None
|
||||||
|
|
||||||
|
# Others
|
||||||
|
self.gracefully_exit = False
|
||||||
|
|
||||||
async def generate_request(
|
async def generate_request(
|
||||||
self,
|
self,
|
||||||
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
|
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
|
||||||
@@ -629,6 +639,28 @@ class TokenizerManager:
|
|||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
loop.create_task(self.handle_loop())
|
loop.create_task(self.handle_loop())
|
||||||
|
|
||||||
|
signal_handler = SignalHandler(self)
|
||||||
|
loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
|
||||||
|
loop.create_task(self.sigterm_watchdog())
|
||||||
|
|
||||||
|
async def sigterm_watchdog(self):
|
||||||
|
while not self.gracefully_exit:
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
|
||||||
|
# drain requests
|
||||||
|
while True:
|
||||||
|
remain_num_req = len(self.rid_to_state)
|
||||||
|
logger.info(
|
||||||
|
f"gracefully exiting... remaining number of requests {remain_num_req}"
|
||||||
|
)
|
||||||
|
if remain_num_req > 0:
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
kill_child_process(include_self=True)
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
async def handle_loop(self):
|
async def handle_loop(self):
|
||||||
"""The event loop that handles requests"""
|
"""The event loop that handles requests"""
|
||||||
|
|
||||||
@@ -740,3 +772,14 @@ class TokenizerManager:
|
|||||||
token_top_logprobs, decode_to_text
|
token_top_logprobs, decode_to_text
|
||||||
)
|
)
|
||||||
return top_logprobs
|
return top_logprobs
|
||||||
|
|
||||||
|
|
||||||
|
class SignalHandler:
|
||||||
|
def __init__(self, tokenizer_manager):
|
||||||
|
self.tokenizer_manager = tokenizer_manager
|
||||||
|
|
||||||
|
def signal_handler(self, signum=None, frame=None):
|
||||||
|
logger.warning(
|
||||||
|
f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
|
||||||
|
)
|
||||||
|
self.tokenizer_manager.gracefully_exit = True
|
||||||
|
|||||||
Reference in New Issue
Block a user