Minor improvements of TokenizerManager / health check (#6327)
This commit is contained in:
@@ -129,7 +129,6 @@ from sglang.srt.utils import (
|
||||
DynamicGradMode,
|
||||
broadcast_pyobj,
|
||||
configure_logger,
|
||||
crash_on_warnings,
|
||||
disable_request_logging,
|
||||
get_bool_env_var,
|
||||
get_zmq_socket,
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
@@ -90,6 +91,8 @@ from sglang.srt.managers.io_struct import (
|
||||
ResumeMemoryOccupationReqInput,
|
||||
ResumeMemoryOccupationReqOutput,
|
||||
SessionParams,
|
||||
SetInternalStateReq,
|
||||
SetInternalStateReqOutput,
|
||||
SlowDownReqInput,
|
||||
SlowDownReqOutput,
|
||||
TokenizedEmbeddingReqInput,
|
||||
@@ -169,6 +172,11 @@ class TokenizerManager:
|
||||
self.enable_metrics = server_args.enable_metrics
|
||||
self.log_requests = server_args.log_requests
|
||||
self.log_requests_level = server_args.log_requests_level
|
||||
self.preferred_sampling_params = (
|
||||
json.loads(server_args.preferred_sampling_params)
|
||||
if server_args.preferred_sampling_params
|
||||
else None
|
||||
)
|
||||
|
||||
# Init inter-process communication
|
||||
context = zmq.asyncio.Context(2)
|
||||
@@ -228,6 +236,7 @@ class TokenizerManager:
|
||||
# Store states
|
||||
self.no_create_loop = False
|
||||
self.rid_to_state: Dict[str, ReqState] = {}
|
||||
self.health_check_failed = False
|
||||
self.gracefully_exit = False
|
||||
self.last_receive_tstamp = 0
|
||||
self.dump_requests_folder = "" # By default do not dump
|
||||
@@ -255,6 +264,10 @@ class TokenizerManager:
|
||||
"model_name": self.server_args.served_model_name,
|
||||
# TODO: Add lora name/path in the future,
|
||||
},
|
||||
bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
|
||||
bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
|
||||
bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
|
||||
collect_tokens_histogram=self.server_args.collect_tokens_histogram,
|
||||
)
|
||||
|
||||
# Communicators
|
||||
@@ -285,9 +298,13 @@ class TokenizerManager:
|
||||
self.start_profile_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
|
||||
self.get_internal_state_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.set_internal_state_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.expert_distribution_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
@@ -349,6 +366,10 @@ class TokenizerManager:
|
||||
GetInternalStateReqOutput,
|
||||
self.get_internal_state_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
SetInternalStateReqOutput,
|
||||
self.set_internal_state_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
ExpertDistributionReqOutput,
|
||||
self.expert_distribution_communicator.handle_recv,
|
||||
@@ -508,7 +529,14 @@ class TokenizerManager:
|
||||
"Please set `--enable-custom-logits-processor` to enable this feature."
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(**obj.sampling_params)
|
||||
# Parse sampling parameters
|
||||
# Note: if there are preferred sampling params, we use them if they are not
|
||||
# explicitly passed in sampling_params
|
||||
if self.preferred_sampling_params:
|
||||
sampling_kwargs = {**self.preferred_sampling_params, **obj.sampling_params}
|
||||
else:
|
||||
sampling_kwargs = obj.sampling_params
|
||||
sampling_params = SamplingParams(**sampling_kwargs)
|
||||
sampling_params.normalize(self.tokenizer)
|
||||
sampling_params.verify()
|
||||
|
||||
@@ -667,7 +695,6 @@ class TokenizerManager:
|
||||
|
||||
generators = []
|
||||
rids = []
|
||||
|
||||
if getattr(obj, "parallel_sample_num", 1) == 1:
|
||||
if self.server_args.enable_tokenizer_batch_encode:
|
||||
# Validate batch tokenization constraints
|
||||
@@ -857,7 +884,7 @@ class TokenizerManager:
|
||||
self.auto_create_handle_loop()
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be for update weights from distributed"
|
||||
), "dp_size must be 1 for update weights from distributed"
|
||||
|
||||
# This means that weight sync
|
||||
# cannot run while requests are in progress.
|
||||
@@ -946,6 +973,14 @@ class TokenizerManager:
|
||||
# Many DP ranks
|
||||
return [res.internal_state for res in responses]
|
||||
|
||||
async def set_internal_state(
|
||||
self, obj: SetInternalStateReq
|
||||
) -> SetInternalStateReqOutput:
|
||||
responses: List[SetInternalStateReqOutput] = (
|
||||
await self.set_internal_state_communicator(obj)
|
||||
)
|
||||
return [res.internal_state for res in responses]
|
||||
|
||||
def get_log_request_metadata(self):
|
||||
max_length = None
|
||||
skip_names = None
|
||||
@@ -1015,11 +1050,17 @@ class TokenizerManager:
|
||||
loop.create_task(print_exception_wrapper(self.handle_loop))
|
||||
)
|
||||
|
||||
self.event_loop = loop
|
||||
|
||||
# We cannot add signal handler when the tokenizer manager is not in
|
||||
# the main thread due to the CPython limitation.
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
signal_handler = SignalHandler(self)
|
||||
loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
|
||||
loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
|
||||
# Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
|
||||
loop.add_signal_handler(
|
||||
signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Signal handler is not added because the tokenizer manager is "
|
||||
@@ -1037,6 +1078,15 @@ class TokenizerManager:
|
||||
# Drain requests
|
||||
while True:
|
||||
remain_num_req = len(self.rid_to_state)
|
||||
|
||||
if self.health_check_failed:
|
||||
# if health check failed, we should exit immediately
|
||||
logger.error(
|
||||
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
|
||||
remain_num_req,
|
||||
)
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Gracefully exiting... remaining number of requests {remain_num_req}"
|
||||
)
|
||||
@@ -1120,7 +1170,16 @@ class TokenizerManager:
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
elif isinstance(recv_obj, BatchMultimodalOut):
|
||||
raise NotImplementedError()
|
||||
if isinstance(recv_obj.outputs[i], str):
|
||||
out_dict = {
|
||||
"text": recv_obj.outputs[i],
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
else:
|
||||
out_dict = {
|
||||
"outputs": json.dumps(recv_obj.outputs[i]),
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
else:
|
||||
assert isinstance(recv_obj, BatchEmbeddingOut)
|
||||
out_dict = {
|
||||
@@ -1366,12 +1425,18 @@ class SignalHandler:
|
||||
def __init__(self, tokenizer_manager: TokenizerManager):
|
||||
self.tokenizer_manager = tokenizer_manager
|
||||
|
||||
def signal_handler(self, signum=None, frame=None):
|
||||
def sigterm_handler(self, signum=None, frame=None):
|
||||
logger.warning(
|
||||
f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
|
||||
)
|
||||
self.tokenizer_manager.gracefully_exit = True
|
||||
|
||||
def running_phase_sigquit_handler(self, signum=None, frame=None):
|
||||
logger.error(
|
||||
"Received sigquit from a child process. It usually means the child failed."
|
||||
)
|
||||
kill_process_tree(os.getpid())
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user