Minor improvements of TokenizerManager / health check (#6327)

This commit is contained in:
Lianmin Zheng
2025-05-15 15:29:25 -07:00
committed by GitHub
parent cd8d4b9dfc
commit e07a6977e7
9 changed files with 136 additions and 33 deletions

View File

@@ -129,7 +129,6 @@ from sglang.srt.utils import (
DynamicGradMode,
broadcast_pyobj,
configure_logger,
crash_on_warnings,
disable_request_logging,
get_bool_env_var,
get_zmq_socket,

View File

@@ -16,6 +16,7 @@
import asyncio
import copy
import dataclasses
import json
import logging
import os
import pickle
@@ -90,6 +91,8 @@ from sglang.srt.managers.io_struct import (
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
SessionParams,
SetInternalStateReq,
SetInternalStateReqOutput,
SlowDownReqInput,
SlowDownReqOutput,
TokenizedEmbeddingReqInput,
@@ -169,6 +172,11 @@ class TokenizerManager:
self.enable_metrics = server_args.enable_metrics
self.log_requests = server_args.log_requests
self.log_requests_level = server_args.log_requests_level
self.preferred_sampling_params = (
json.loads(server_args.preferred_sampling_params)
if server_args.preferred_sampling_params
else None
)
# Init inter-process communication
context = zmq.asyncio.Context(2)
@@ -228,6 +236,7 @@ class TokenizerManager:
# Store states
self.no_create_loop = False
self.rid_to_state: Dict[str, ReqState] = {}
self.health_check_failed = False
self.gracefully_exit = False
self.last_receive_tstamp = 0
self.dump_requests_folder = "" # By default do not dump
@@ -255,6 +264,10 @@ class TokenizerManager:
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
},
bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
collect_tokens_histogram=self.server_args.collect_tokens_histogram,
)
# Communicators
@@ -285,9 +298,13 @@ class TokenizerManager:
self.start_profile_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
self.get_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.set_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.expert_distribution_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -349,6 +366,10 @@ class TokenizerManager:
GetInternalStateReqOutput,
self.get_internal_state_communicator.handle_recv,
),
(
SetInternalStateReqOutput,
self.set_internal_state_communicator.handle_recv,
),
(
ExpertDistributionReqOutput,
self.expert_distribution_communicator.handle_recv,
@@ -508,7 +529,14 @@ class TokenizerManager:
"Please set `--enable-custom-logits-processor` to enable this feature."
)
sampling_params = SamplingParams(**obj.sampling_params)
# Parse sampling parameters
# Note: if there are preferred sampling params, we use them if they are not
# explicitly passed in sampling_params
if self.preferred_sampling_params:
sampling_kwargs = {**self.preferred_sampling_params, **obj.sampling_params}
else:
sampling_kwargs = obj.sampling_params
sampling_params = SamplingParams(**sampling_kwargs)
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
@@ -667,7 +695,6 @@ class TokenizerManager:
generators = []
rids = []
if getattr(obj, "parallel_sample_num", 1) == 1:
if self.server_args.enable_tokenizer_batch_encode:
# Validate batch tokenization constraints
@@ -857,7 +884,7 @@ class TokenizerManager:
self.auto_create_handle_loop()
assert (
self.server_args.dp_size == 1
), "dp_size must be for update weights from distributed"
), "dp_size must be 1 for update weights from distributed"
# This means that weight sync
# cannot run while requests are in progress.
@@ -946,6 +973,14 @@ class TokenizerManager:
# Many DP ranks
return [res.internal_state for res in responses]
async def set_internal_state(
self, obj: SetInternalStateReq
) -> SetInternalStateReqOutput:
responses: List[SetInternalStateReqOutput] = (
await self.set_internal_state_communicator(obj)
)
return [res.internal_state for res in responses]
def get_log_request_metadata(self):
max_length = None
skip_names = None
@@ -1015,11 +1050,17 @@ class TokenizerManager:
loop.create_task(print_exception_wrapper(self.handle_loop))
)
self.event_loop = loop
# We cannot add signal handler when the tokenizer manager is not in
# the main thread due to the CPython limitation.
if threading.current_thread() is threading.main_thread():
signal_handler = SignalHandler(self)
loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
# Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
loop.add_signal_handler(
signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
)
else:
logger.warning(
"Signal handler is not added because the tokenizer manager is "
@@ -1037,6 +1078,15 @@ class TokenizerManager:
# Drain requests
while True:
remain_num_req = len(self.rid_to_state)
if self.health_check_failed:
# if health check failed, we should exit immediately
logger.error(
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
remain_num_req,
)
break
logger.info(
f"Gracefully exiting... remaining number of requests {remain_num_req}"
)
@@ -1120,7 +1170,16 @@ class TokenizerManager:
"meta_info": meta_info,
}
elif isinstance(recv_obj, BatchMultimodalOut):
raise NotImplementedError()
if isinstance(recv_obj.outputs[i], str):
out_dict = {
"text": recv_obj.outputs[i],
"meta_info": meta_info,
}
else:
out_dict = {
"outputs": json.dumps(recv_obj.outputs[i]),
"meta_info": meta_info,
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
@@ -1366,12 +1425,18 @@ class SignalHandler:
def __init__(self, tokenizer_manager: TokenizerManager):
self.tokenizer_manager = tokenizer_manager
def signal_handler(self, signum=None, frame=None):
def sigterm_handler(self, signum=None, frame=None):
logger.warning(
f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
)
self.tokenizer_manager.gracefully_exit = True
def running_phase_sigquit_handler(self, signum=None, frame=None):
logger.error(
"Received sigquit from a child process. It usually means the child failed."
)
kill_process_tree(os.getpid())
T = TypeVar("T")