[Production] Drain requests before exiting when SIGTERM is received (#1838)

This commit is contained in:
Ying Sheng
2024-10-30 10:22:56 -07:00
committed by GitHub
parent 3184aa95a7
commit 4e2af03cfa

View File

@@ -20,6 +20,8 @@ import dataclasses
import json import json
import logging import logging
import os import os
import signal
import sys
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import fastapi import fastapi
@@ -58,7 +60,12 @@ from sglang.srt.managers.io_struct import (
) )
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, is_generation_model, is_multimodal_model from sglang.srt.utils import (
get_zmq_socket,
is_generation_model,
is_multimodal_model,
kill_child_process,
)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -142,6 +149,9 @@ class TokenizerManager:
self.model_update_lock = asyncio.Lock() self.model_update_lock = asyncio.Lock()
self.model_update_result = None self.model_update_result = None
# Others
self.gracefully_exit = False
async def generate_request( async def generate_request(
self, self,
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput], obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
@@ -629,6 +639,28 @@ class TokenizerManager:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.create_task(self.handle_loop()) loop.create_task(self.handle_loop())
signal_handler = SignalHandler(self)
loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
loop.create_task(self.sigterm_watchdog())
async def sigterm_watchdog(
    self, *, poll_interval: float = 1.0, drain_interval: float = 5.0
):
    """Wait for a graceful-exit request, drain pending requests, then exit.

    The coroutine idles until ``self.gracefully_exit`` becomes true (the
    SIGTERM handler sets it), then waits until ``self.rid_to_state`` is
    empty before tearing the process down.

    Args:
        poll_interval: Seconds between checks of ``self.gracefully_exit``.
            Kept short on purpose: process supervisors usually follow
            SIGTERM with SIGKILL after a bounded grace period, so a long
            sleep here could consume the whole grace period before the
            drain even begins.
        drain_interval: Seconds between checks of the number of requests
            still tracked in ``self.rid_to_state`` while draining.
    """
    # Idle until the signal handler flips the flag.
    while not self.gracefully_exit:
        await asyncio.sleep(poll_interval)

    # Drain requests: report progress and wait until none are tracked.
    while True:
        remain_num_req = len(self.rid_to_state)
        logger.info(
            f"gracefully exiting... remaining number of requests {remain_num_req}"
        )
        if remain_num_req == 0:
            break
        await asyncio.sleep(drain_interval)

    # NOTE(review): kill_child_process(include_self=True) presumably also
    # terminates this process, making the sys.exit below a fallback -- confirm.
    kill_child_process(include_self=True)
    # Every request was drained, so report a successful (zero) exit status.
    sys.exit(0)
async def handle_loop(self): async def handle_loop(self):
"""The event loop that handles requests""" """The event loop that handles requests"""
@@ -740,3 +772,14 @@ class TokenizerManager:
token_top_logprobs, decode_to_text token_top_logprobs, decode_to_text
) )
return top_logprobs return top_logprobs
class SignalHandler:
    """Signal callback holder that requests a graceful shutdown.

    The instance keeps a reference to the tokenizer manager it was built
    with; invoking ``signal_handler`` sets that manager's
    ``gracefully_exit`` attribute to ``True``.
    """

    def __init__(self, tokenizer_manager):
        # Object whose ``gracefully_exit`` flag is raised by the callback.
        self.tokenizer_manager = tokenizer_manager

    def signal_handler(self, signum=None, frame=None):
        """Log the received signal and flag the manager for graceful exit.

        Args:
            signum: Signal number, if the caller supplies one; only logged.
            frame: Stack frame, if the caller supplies one; only logged.
        """
        message = f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        logger.warning(message)

        manager = self.tokenizer_manager
        manager.gracefully_exit = True