Disable graceful shutdown of tokenizer manager when not in the main thread (#2872)

This commit is contained in:
Cody Yu
2025-01-15 03:29:33 -08:00
committed by GitHub
parent bfbda62c8b
commit b803b395b7

View File

@@ -21,6 +21,7 @@ import os
import pickle import pickle
import signal import signal
import sys import sys
import threading
import time import time
import uuid import uuid
from datetime import datetime from datetime import datetime
@@ -265,10 +266,16 @@ class TokenizerManager:
) )
input_embeds = obj.input_embeds input_embeds = obj.input_embeds
input_ids = obj.input_ids input_ids = obj.input_ids
elif obj.input_ids is None: elif obj.input_ids is not None:
input_ids = self.tokenizer.encode(input_text)
else:
input_ids = obj.input_ids input_ids = obj.input_ids
else:
if self.tokenizer is None:
raise ValueError(
"The engine initialized with skip_tokenizer_init=True cannot "
"accept text prompts. Please provide input_ids or re-initialize "
"the engine with skip_tokenizer_init=False."
)
input_ids = self.tokenizer.encode(input_text)
if self.is_generation: if self.is_generation:
# TODO: also support getting embeddings for multimodal models # TODO: also support getting embeddings for multimodal models
@@ -635,8 +642,17 @@ class TokenizerManager:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
self.asyncio_tasks.add(loop.create_task(self.handle_loop())) self.asyncio_tasks.add(loop.create_task(self.handle_loop()))
# We cannot add signal handler when the tokenizer manager is not in
# the main thread due to the CPython limitation.
if threading.current_thread() is threading.main_thread():
signal_handler = SignalHandler(self) signal_handler = SignalHandler(self)
loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler) loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
else:
logger.warning(
"Signal handler is not added because the tokenizer manager is "
"not in the main thread. This disables graceful shutdown of the "
"tokenizer manager when SIGTERM is received."
)
self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog())) self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))
async def sigterm_watchdog(self): async def sigterm_watchdog(self):