diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py
index ddc405c48..febb827fa 100644
--- a/python/sglang/srt/disaggregation/decode.py
+++ b/python/sglang/srt/disaggregation/decode.py
@@ -694,10 +694,7 @@ class SchedulerDisaggregationDecodeMixin:
                 + len(self.disagg_decode_prealloc_queue.queue) == 0
             ):
-                # When the server is idle, do self-check and re-init some states
-                self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
-                self.maybe_sleep_on_idle()
+                self.self_check_during_idle()

         self.last_batch = batch

@@ -771,10 +768,7 @@ class SchedulerDisaggregationDecodeMixin:
                 + len(self.disagg_decode_prealloc_queue.queue) == 0
             ):
-                # When the server is idle, do self-check and re-init some states
-                self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
-                self.maybe_sleep_on_idle()
+                self.self_check_during_idle()

         self.last_batch = batch
         self.last_batch_in_queue = last_batch_in_queue

diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py
index 8217bd44c..462727fff 100644
--- a/python/sglang/srt/disaggregation/prefill.py
+++ b/python/sglang/srt/disaggregation/prefill.py
@@ -287,9 +287,7 @@ class SchedulerDisaggregationPrefillMixin:
             self.process_disagg_prefill_inflight_queue()

         if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
-            self.check_memory()
-            self.new_token_ratio = self.init_new_token_ratio
-            self.maybe_sleep_on_idle()
+            self.self_check_during_idle()

         self.last_batch = batch
         # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
@@ -337,9 +335,7 @@ class SchedulerDisaggregationPrefillMixin:
             self.process_disagg_prefill_inflight_queue()

         if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
-            self.check_memory()
-            self.new_token_ratio = self.init_new_token_ratio
-            self.maybe_sleep_on_idle()
+            self.self_check_during_idle()

         self.last_batch = batch
         # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it

diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index c038e87fc..e52c546a0 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -652,25 +652,19 @@ def _set_envs_and_config(server_args: ServerArgs):
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )

-    def sigchld_handler(signum, frame):
-        pid, exitcode = os.waitpid(0, os.WNOHANG)
-        if exitcode != 0:
-            logger.warning(
-                f"Child process unexpectedly failed with {exitcode=}. {pid=}"
+    if True:  # Keep this check for internal code compatibility
+        # Register the signal handler.
+        # The child processes will send SIGQUIT to this process when any error happens.
+        # This process then cleans up the whole process tree.
+        # Note: This sigquit handler is used in the launch phase, and may be replaced by
+        # the running_phase_sigquit_handler in the tokenizer manager after the grpc server is launched.
+        def launch_phase_sigquit_handler(signum, frame):
+            logger.error(
+                "Received sigquit from a child process. It usually means the child failed."
             )
+            kill_process_tree(os.getpid())

-    signal.signal(signal.SIGCHLD, sigchld_handler)
-
-    # Register the signal handler.
-    # The child processes will send SIGQUIT to this process when any error happens
-    def sigquit_handler(signum, frame):
-        logger.error(
-            "Received sigquit from a child process. It usually means the child failed."
-        )
-        kill_process_tree(os.getpid())
-
-    signal.signal(signal.SIGQUIT, sigquit_handler)
+        signal.signal(signal.SIGQUIT, launch_phase_sigquit_handler)

     # Set mp start method
     mp.set_start_method("spawn", force=True)

diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index 586a26495..b58987bcb 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -238,6 +238,9 @@ async def health() -> Response:
 @app.get("/health_generate")
 async def health_generate(request: Request) -> Response:
     """Check the health of the inference server by generating one token."""
+    if _global_state.tokenizer_manager.gracefully_exit:
+        logger.info("Health check request received during shutdown. Returning 503.")
+        return Response(status_code=503)

     sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
     rid = f"HEALTH_CHECK_{time.time()}"
@@ -260,9 +263,14 @@ async def health_generate(request: Request) -> Response:
         async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
             break

-    tic = time.perf_counter()
+    # This request is a special request.
+    # If the server already has something running, this request will be ignored, so it creates zero overhead.
+    # If the server is not running, this request will be run, so we know whether the server is healthy.
     task = asyncio.create_task(gen())
+
+    # As long as we receive any response from the detokenizer/scheduler, we consider the server healthy.
+ tic = time.time() + while time.time() < tic + HEALTH_CHECK_TIMEOUT: await asyncio.sleep(1) if _global_state.tokenizer_manager.last_receive_tstamp > tic: task.cancel() diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 773e0c57d..c8d325f9e 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -152,8 +152,6 @@ class GenerateReqInput: else: self._normalize_batch_inputs() - self._validate_session_params() - def _validate_inputs(self): """Validate that the input configuration is valid.""" if ( diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 656bf7684..38db5313a 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -13,7 +13,6 @@ # ============================================================================== """A scheduler that manages a tensor parallel GPU worker.""" -import datetime import faulthandler import logging import os @@ -21,11 +20,10 @@ import signal import sys import threading import time -from collections import defaultdict, deque +from collections import deque from concurrent import futures from dataclasses import dataclass from http import HTTPStatus -from pathlib import Path from types import SimpleNamespace from typing import Dict, List, Optional, Tuple, Union @@ -37,7 +35,6 @@ from torch.distributed import barrier from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS from sglang.srt.constrained.base_grammar_backend import ( INVALID_GRAMMAR_OBJ, create_grammar_backend, @@ -47,7 +44,6 @@ from sglang.srt.disaggregation.decode import ( DecodeTransferQueue, SchedulerDisaggregationDecodeMixin, ) -from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch from sglang.srt.disaggregation.prefill import ( PrefillBootstrapQueue, SchedulerDisaggregationPrefillMixin, @@ -78,21 +74,15 @@ from sglang.srt.managers.io_struct import ( GetInternalStateReq, GetInternalStateReqOutput, GetWeightsByNameReqInput, - GetWeightsByNameReqOutput, HealthCheckOutput, InitWeightsUpdateGroupReqInput, - InitWeightsUpdateGroupReqOutput, LoadLoRAAdapterReqInput, LoadLoRAAdapterReqOutput, OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, - ProfileReqOutput, - ProfileReqType, ReleaseMemoryOccupationReqInput, - ReleaseMemoryOccupationReqOutput, ResumeMemoryOccupationReqInput, - ResumeMemoryOccupationReqOutput, RpcReqInput, RpcReqOutput, SetInternalStateReq, @@ -104,11 +94,8 @@ from sglang.srt.managers.io_struct import ( UnloadLoRAAdapterReqInput, UnloadLoRAAdapterReqOutput, UpdateWeightFromDiskReqInput, - UpdateWeightFromDiskReqOutput, UpdateWeightsFromDistributedReqInput, - UpdateWeightsFromDistributedReqOutput, UpdateWeightsFromTensorReqInput, - UpdateWeightsFromTensorReqOutput, ) from sglang.srt.managers.mm_utils import init_embedding_cache from sglang.srt.managers.schedule_batch import ( @@ -124,9 +111,17 @@ from sglang.srt.managers.schedule_policy import ( SchedulePolicy, ) from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker +from sglang.srt.managers.scheduler_metrics_mixin import ( + RECORD_STEP_TIME, + SchedulerMetricsMixin, +) from sglang.srt.managers.scheduler_output_processor_mixin import ( SchedulerOutputProcessorMixin, ) +from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin +from 
sglang.srt.managers.scheduler_update_weights_mixin import ( + SchedulerUpdateWeightsMixin, +) from sglang.srt.managers.session_controller import Session from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient @@ -135,7 +130,6 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.hiradix_cache import HiRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache -from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.server_args import PortArgs, ServerArgs @@ -168,7 +162,6 @@ logger = logging.getLogger(__name__) # Test retract decode for debugging purposes TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT") -RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME") GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300)) _is_cpu = is_cpu() @@ -191,41 +184,11 @@ class EmbeddingBatchResult: bid: int -class KvMetrics: - def __init__(self): - self.request_active_slots = None - self.request_total_slots = None - self.kv_active_blocks = None - self.kv_total_blocks = None - self.num_requests_waiting = None - self.gpu_cache_usage_perc = None - self.gpu_prefix_cache_hit_rate = None - self.data_parallel_rank = None - - -class IdleSleeper: - """ - In setups which have long inactivity periods it is desirable to reduce - system power consumption when sglang does nothing. This would lead not only - to power savings, but also to more CPU thermal headroom when a request - eventually comes. This is important in cases when multiple GPUs are connected - as each GPU would otherwise pin one thread at 100% CPU usage. - - The simplest solution is to use zmq.Poller on all sockets that may receive - data that needs handling immediately. 
- """ - - def __init__(self, sockets): - self.poller = zmq.Poller() - for s in sockets: - self.poller.register(s, zmq.POLLIN) - - def maybe_sleep(self): - self.poller.poll(1000) - - class Scheduler( SchedulerOutputProcessorMixin, + SchedulerUpdateWeightsMixin, + SchedulerProfilerMixin, + SchedulerMetricsMixin, SchedulerDisaggregationDecodeMixin, SchedulerDisaggregationPrefillMixin, ): @@ -266,7 +229,7 @@ class Scheduler( self.enable_hierarchical_cache = server_args.enable_hierarchical_cache self.enable_hicache_storage = server_args.hicache_storage_backend is not None self.page_size = server_args.page_size - self.dp_size = server_args.dp_size + self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = ( compute_dp_attention_world_info( server_args.enable_dp_attention, @@ -284,10 +247,13 @@ class Scheduler( self.recv_from_tokenizer = get_zmq_socket( context, zmq.PULL, port_args.scheduler_input_ipc_name, False ) + self.recv_from_rpc = get_zmq_socket( + context, zmq.DEALER, port_args.rpc_ipc_name, False + ) + self.send_to_tokenizer = get_zmq_socket( context, zmq.PUSH, port_args.tokenizer_ipc_name, False ) - if server_args.skip_tokenizer_init: # Directly send to the TokenizerManager self.send_to_detokenizer = get_zmq_socket( @@ -299,9 +265,6 @@ class Scheduler( context, zmq.PUSH, port_args.detokenizer_ipc_name, False ) - self.recv_from_rpc = get_zmq_socket( - context, zmq.DEALER, port_args.rpc_ipc_name, False - ) if self.server_args.sleep_on_idle: self.idle_sleeper = IdleSleeper( [ @@ -398,7 +361,7 @@ class Scheduler( global_server_args_dict.update(worker_global_server_args_dict) set_random_seed(self.random_seed) - # Hybrid + # Hybrid memory pool self.is_hybrid = self.tp_worker.is_hybrid if self.is_hybrid: self.sliding_window_size = self.tp_worker.sliding_window_size @@ -515,6 +478,15 @@ class Scheduler( self.init_metrics(tp_rank, pp_rank, dp_rank) self.init_kv_events(server_args.kv_events_config) + # Init disaggregation + self.disaggregation_mode = DisaggregationMode( + self.server_args.disaggregation_mode + ) + self.init_disaggregation() + + if get_bool_env_var("SGLANG_GC_LOG"): + configure_gc_logger() + # Init request dispatcher self._request_dispatcher = TypeBasedDispatcher( [ @@ -545,22 +517,6 @@ class Scheduler( ] ) - # Init disaggregation - self.disaggregation_mode = DisaggregationMode( - self.server_args.disaggregation_mode - ) - self.init_disaggregation() - - if get_bool_env_var("SGLANG_GC_LOG"): - configure_gc_logger() - - def current_scheduler_metrics_enabled(self): - return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers - - def maybe_sleep_on_idle(self): - if self.idle_sleeper is not None: - self.idle_sleeper.maybe_sleep() - def init_tokenizer(self): server_args = self.server_args @@ -668,50 +624,6 @@ class Scheduler( embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100")) init_embedding_cache(embedding_cache_size * 1024 * 1024) - def init_profier(self): - self.torch_profiler = None - self.torch_profiler_output_dir: Optional[str] = None - self.profiler_activities: Optional[List[str]] = None - self.profile_id: Optional[str] = None - self.profiler_start_forward_ct: Optional[int] = None - self.profiler_target_forward_ct: Optional[int] = None - self.profiler_target_prefill_ct: Optional[int] = None - self.profiler_target_decode_ct: Optional[int] = None - self.profiler_prefill_ct: Optional[int] = None - self.profiler_decode_ct: Optional[int] = None - self.profile_by_stage: bool = False - self.profile_steps: Optional[int] = None - 
self.profile_in_progress: bool = False - self.rpd_profiler = None - - def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]): - self.last_gen_throughput: float = 0.0 - self.last_input_throughput: float = 0.0 - self.step_time_dict = defaultdict(list) # Dict[batch size -> step time] - self.spec_num_total_accepted_tokens = 0 - self.spec_num_total_forward_ct = 0 - self.cum_spec_accept_length = 0 - self.cum_spec_accept_count = 0 - self.total_retracted_reqs = 0 - self.stats = SchedulerStats() - if self.enable_metrics: - engine_type = "unified" - labels = { - "model_name": self.server_args.served_model_name, - "engine_type": engine_type, - "tp_rank": tp_rank, - "pp_rank": pp_rank, - } - if dp_rank is not None: - labels["dp_rank"] = dp_rank - self.metrics_collector = SchedulerMetricsCollector(labels=labels) - - def init_kv_events(self, kv_events_config: Optional[str]): - if self.enable_kv_cache_events: - self.kv_event_publisher = EventPublisherFactory.create( - kv_events_config, self.attn_dp_rank - ) - def init_disaggregation(self): self.transfer_backend = TransferBackend( self.server_args.disaggregation_transfer_backend @@ -820,10 +732,7 @@ class Scheduler( self.process_batch_result(batch, result) else: # When the server is idle, do self-check and re-init some states - self.check_memory() - self.check_tree_cache() - self.new_token_ratio = self.init_new_token_ratio - self.maybe_sleep_on_idle() + self.self_check_during_idle() self.last_batch = batch @@ -866,10 +775,7 @@ class Scheduler( ) elif batch is None: # When the server is idle, do self-check and re-init some states - self.check_memory() - self.check_tree_cache() - self.new_token_ratio = self.init_new_token_ratio - self.maybe_sleep_on_idle() + self.self_check_during_idle() self.last_batch = batch @@ -1003,10 +909,8 @@ class Scheduler( # When the server is idle, self-check and re-init some states if server_is_idle: - self.check_memory() - self.check_tree_cache() - self.new_token_ratio = self.init_new_token_ratio - self.maybe_sleep_on_idle() + # When the server is idle, do self-check and re-init some states + self.self_check_during_idle() def recv_requests(self) -> List[Req]: """Receive results at tp_rank = 0 and broadcast it to all other TP ranks.""" @@ -1355,170 +1259,11 @@ class Scheduler( req.logprob_start_len = len(req.origin_input_ids) - 1 self._add_request_to_queue(req) - def _emit_kv_metrics(self): - kv_metrics = KvMetrics() - kv_metrics.request_active_slots = self.stats.num_running_reqs - kv_metrics.request_total_slots = self.max_running_requests - kv_metrics.kv_active_blocks = int( - self.stats.token_usage * self.max_total_num_tokens - ) - kv_metrics.kv_total_blocks = self.max_total_num_tokens - kv_metrics.num_requests_waiting = self.stats.num_queue_reqs - kv_metrics.gpu_cache_usage_perc = self.stats.token_usage - kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate - kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0 - - if not self.send_metrics_from_scheduler.closed: - self.send_metrics_from_scheduler.send_pyobj(kv_metrics) - - def log_prefill_stats( - self, - adder: PrefillAdder, - can_run_list: List[Req], - running_bs: int, - ): - gap_latency = time.perf_counter() - self.last_prefill_stats_tic - self.last_prefill_stats_tic = time.perf_counter() - self.last_input_throughput = self.last_prefill_tokens / gap_latency - self.last_prefill_tokens = adder.log_input_tokens - - if self.is_hybrid: - ( - full_num_used, - swa_num_used, - full_token_usage, - swa_token_usage, - _, - _, 
- _, - _, - ) = self._get_swa_token_info() - num_used = max(full_num_used, swa_num_used) - token_usage = max(full_token_usage, swa_token_usage) - token_msg = ( - f"full token usage: {full_token_usage:.2f}, " - f"swa token usage: {swa_token_usage:.2f}, " - ) - else: - num_used, token_usage, _, _ = self._get_token_info() - token_msg = f"token usage: {token_usage:.2f}, " - - num_new_seq = len(can_run_list) - f = ( - f"Prefill batch. " - f"#new-seq: {num_new_seq}, " - f"#new-token: {adder.log_input_tokens}, " - f"#cached-token: {adder.log_hit_tokens}, " - f"{token_msg}" - ) - - if self.disaggregation_mode == DisaggregationMode.PREFILL: - f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, " - f += f"#queue-req: {len(self.waiting_queue)}, " - f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, " - f += f"input throughput (token/s): {self.last_input_throughput:.2f}, " - else: - f += f"#running-req: {running_bs}, " - f += f"#queue-req: {len(self.waiting_queue)}, " - - logger.info(f) - - if self.enable_metrics: - total_tokens = adder.log_input_tokens + adder.log_hit_tokens - - cache_hit_rate = ( - adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0 - ) - self.stats.num_running_reqs = running_bs - self.stats.num_used_tokens = num_used - self.stats.token_usage = round(token_usage, 2) - self.stats.num_queue_reqs = len(self.waiting_queue) - self.stats.cache_hit_rate = cache_hit_rate - - total_queue_latency = 0 - for req in can_run_list: - total_queue_latency += req.queue_time_end - req.queue_time_start - self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq - - self.metrics_collector.log_stats(self.stats) - self._emit_kv_metrics() - self._publish_kv_events() - - def log_decode_stats( - self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None - ): - batch = running_batch or self.running_batch - - gap_latency = time.perf_counter() - self.last_decode_stats_tic - self.last_decode_stats_tic = time.perf_counter() - self.last_gen_throughput = self.num_generated_tokens / gap_latency - self.num_generated_tokens = 0 - num_running_reqs = len(batch.reqs) - if self.is_hybrid: - ( - full_num_used, - swa_num_used, - full_token_usage, - swa_token_usage, - _, - _, - _, - _, - ) = self._get_swa_token_info() - num_used = max(full_num_used, swa_num_used) - token_usage = max(full_token_usage, swa_token_usage) - token_msg = ( - f"#full token: {full_num_used}, " - f"full token usage: {full_token_usage:.2f}, " - f"#swa token: {swa_num_used}, " - f"swa token usage: {swa_token_usage:.2f}, " - ) - else: - num_used, token_usage, _, _ = self._get_token_info() - token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, " - - if RECORD_STEP_TIME: - self.step_time_dict[num_running_reqs].append( - gap_latency / self.server_args.decode_log_interval - ) - - msg = f"Decode batch. 
#running-req: {num_running_reqs}, {token_msg}" - - if self.spec_algorithm.is_none(): - spec_accept_length = 0 - else: - spec_accept_length = ( - self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct - ) - self.cum_spec_accept_length += self.spec_num_total_accepted_tokens - self.cum_spec_accept_count += self.spec_num_total_forward_ct - self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0 - msg += f"accept len: {spec_accept_length:.2f}, " - - if self.disaggregation_mode == DisaggregationMode.DECODE: - msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, " - msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, " - - msg += ( - f"cuda graph: {can_run_cuda_graph}, " - f"gen throughput (token/s): {self.last_gen_throughput:.2f}, " - f"#queue-req: {len(self.waiting_queue)}, " - ) - - logger.info(msg) - if self.enable_metrics: - self.stats.num_running_reqs = num_running_reqs - self.stats.num_used_tokens = num_used - self.stats.token_usage = round(token_usage, 2) - self.stats.cache_hit_rate = 0.0 - self.stats.gen_throughput = self.last_gen_throughput - self.stats.num_queue_reqs = len(self.waiting_queue) - self.stats.num_grammar_queue_reqs = len(self.grammar_queue) - self.stats.spec_accept_length = spec_accept_length - self.stats.total_retracted_reqs = self.total_retracted_reqs - self.metrics_collector.log_stats(self.stats) - self._emit_kv_metrics() - self._publish_kv_events() + def self_check_during_idle(self): + self.check_memory() + self.check_tree_cache() + self.new_token_ratio = self.init_new_token_ratio + self.maybe_sleep_on_idle() def check_memory(self): if self.is_hybrid: @@ -2422,22 +2167,6 @@ class Scheduler( barrier() return RpcReqOutput(success, "" if not exec else str(exec)) - def save_remote_model(self, params): - url = params["url"] - - worker = self.tp_worker.worker - - worker.model_runner.save_remote_model(url) - - def save_sharded_model(self, params): - worker = self.tp_worker.worker - - worker.model_runner.save_sharded_model( - path=params["path"], - pattern=params["pattern"], - max_size=params["max_size"], - ) - def abort_request(self, recv_req: AbortReq): # Delete requests in the waiting queue to_del = [] @@ -2515,16 +2244,6 @@ class Scheduler( def _pause_engine(self) -> Tuple[List[Req], int]: raise NotImplementedError() - def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): - """In-place update of the weights from disk.""" - success, message = self.tp_worker.update_weights_from_disk(recv_req) - if success: - flush_cache_success = self.flush_cache() - assert flush_cache_success, "Cache flush failed after updating weights" - else: - logger.error(message) - return UpdateWeightFromDiskReqOutput(success, message, 0) - def load_lora_adapter( self, recv_req: LoadLoRAAdapterReqInput ) -> LoadLoRAAdapterReqOutput: @@ -2541,81 +2260,6 @@ class Scheduler( result = self.tp_worker.unload_lora_adapter(recv_req) return result - def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput): - """Initialize the online model parameter update group.""" - success, message = self.tp_worker.init_weights_update_group(recv_req) - return InitWeightsUpdateGroupReqOutput(success, message) - - def update_weights_from_distributed( - self, - recv_req: UpdateWeightsFromDistributedReqInput, - ) -> Tuple[bool, str]: - """Update the online model parameter.""" - success, message = self.tp_worker.update_weights_from_distributed(recv_req) - if 
success: - if recv_req.flush_cache: - flush_cache_success = self.flush_cache() - assert flush_cache_success, "Cache flush failed after updating weights" - else: - logger.error(message) - return UpdateWeightsFromDistributedReqOutput(success, message) - - def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): - """Update the online model parameter from tensors.""" - success, message = self.tp_worker.update_weights_from_tensor(recv_req) - # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later - if success: - if recv_req.flush_cache: - flush_cache_success = self.flush_cache() - assert flush_cache_success, "Cache flush failed after updating weights" - else: - logger.error(message) - barrier(group=self.tp_cpu_group) - return UpdateWeightsFromTensorReqOutput(success, message) - - def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): - parameter = self.tp_worker.get_weights_by_name(recv_req) - return GetWeightsByNameReqOutput(parameter) - - def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput): - tags = recv_req.tags - - if tags is None or len(tags) == 0: - tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE] - - if GPU_MEMORY_TYPE_KV_CACHE in tags: - self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE) - self.flush_cache() - - if GPU_MEMORY_TYPE_WEIGHTS in tags: - self.stashed_model_static_state = _export_static_state( - self.tp_worker.worker.model_runner.model - ) - torch.distributed.barrier(self.tp_cpu_group) - self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS) - - return ReleaseMemoryOccupationReqOutput() - - def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput): - tags = recv_req.tags - - if tags is None or len(tags) == 0: - tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE] - - if GPU_MEMORY_TYPE_WEIGHTS in tags: - self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS) - torch.distributed.barrier(self.tp_cpu_group) - _import_static_state( - self.tp_worker.worker.model_runner.model, - self.stashed_model_static_state, - ) - del self.stashed_model_static_state - - if GPU_MEMORY_TYPE_KV_CACHE in tags: - self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE) - - return ResumeMemoryOccupationReqOutput() - def slow_down(self, recv_req: SlowDownReqInput): t = recv_req.forward_sleep_time if t is not None and t <= 0: @@ -2623,254 +2267,6 @@ class Scheduler( self.forward_sleep_time = t return SlowDownReqOutput() - def profile(self, recv_req: ProfileReq): - if recv_req.type == ProfileReqType.START_PROFILE: - if recv_req.profile_by_stage or recv_req.start_step: - return self.init_profile( - recv_req.output_dir, - recv_req.start_step, - recv_req.num_steps, - recv_req.activities, - recv_req.with_stack, - recv_req.record_shapes, - recv_req.profile_by_stage, - recv_req.profile_id, - ) - else: - self.init_profile( - recv_req.output_dir, - recv_req.start_step, - recv_req.num_steps, - recv_req.activities, - recv_req.with_stack, - recv_req.record_shapes, - recv_req.profile_by_stage, - recv_req.profile_id, - ) - return self.start_profile(True) - else: - return self.stop_profile() - - def init_profile( - self, - output_dir: Optional[str], - start_step: Optional[int], - num_steps: Optional[int], - activities: Optional[List[str]], - with_stack: Optional[bool], - record_shapes: Optional[bool], - profile_by_stage: bool, - profile_id: str, - ) -> ProfileReqOutput: - if self.profile_in_progress: - return ProfileReqOutput( - success=False, - message="Profiling 
is already in progress. Call /stop_profile first.", - ) - - self.profile_by_stage = profile_by_stage - - if output_dir is None: - output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp") - if activities is None: - activities = ["CPU", "GPU"] - - self.torch_profiler_output_dir = output_dir - self.torch_profiler_with_stack = with_stack - self.torch_profiler_record_shapes = record_shapes - self.profiler_activities = activities - self.profile_id = profile_id - - if start_step: - self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1) - - if num_steps: - self.profile_steps = num_steps - if self.profile_by_stage: - self.profiler_target_prefill_ct = num_steps - self.profiler_target_decode_ct = num_steps - self.profiler_prefill_ct = 0 - self.profiler_decode_ct = 0 - elif start_step: - self.profiler_target_forward_ct = ( - self.profiler_start_forward_ct + num_steps - ) - else: - self.profiler_target_forward_ct = self.forward_ct + num_steps - # The caller will be notified when reaching profiler_target_forward_ct - else: - self.profiler_target_forward_ct = None - - return ProfileReqOutput(success=True, message="Succeeded") - - def start_profile( - self, stage: Optional[ForwardMode] = None - ) -> ProfileReqOutput | None: - stage_str = f" for {stage.__str__()}" if stage else "" - logger.info( - f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})", - ) - - activities = self.profiler_activities - with_stack = self.torch_profiler_with_stack - record_shapes = self.torch_profiler_record_shapes - - activity_map = { - "CPU": torch.profiler.ProfilerActivity.CPU, - "GPU": torch.profiler.ProfilerActivity.CUDA, - } - torchprof_activities = [ - activity_map[a] for a in activities if a in activity_map - ] - - if "RPD" in activities: - from rpdTracerControl import rpdTracerControl - - rpdTracerControl.skipCreate() - - self.rpd_profile_path = os.path.join( - self.torch_profiler_output_dir, - "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz", - ) - - if self.tp_rank == 0: - import sqlite3 - - from rocpd.schema import RocpdSchema - - if os.path.exists("trace.rpd"): - os.unlink("trace.rpd") - schema = RocpdSchema() - connection = sqlite3.connect("trace.rpd") - schema.writeSchema(connection) - connection.commit() - del connection - torch.distributed.barrier(self.tp_cpu_group) - - self.rpd_profiler = rpdTracerControl() - self.rpd_profiler.setPythonTrace(True) - self.rpd_profiler.start() - self.rpd_profiler.rangePush("", "rpd profile range", "") - self.profile_in_progress = True - elif torchprof_activities: - self.torch_profiler = torch.profiler.profile( - activities=torchprof_activities, - with_stack=with_stack if with_stack is not None else True, - record_shapes=record_shapes if record_shapes is not None else False, - ) - self.torch_profiler.start() - self.profile_in_progress = True - - if "MEM" in activities: - torch.cuda.memory._record_memory_history(max_entries=100000) - self.profile_in_progress = True - - if "CUDA_PROFILER" in activities: - torch.cuda.cudart().cudaProfilerStart() - self.profile_in_progress = True - - return ProfileReqOutput(success=True, message="Succeeded") - - def stop_profile( - self, stage: Optional[ForwardMode] = None - ) -> ProfileReqOutput | None: - if not self.profile_in_progress: - return ProfileReqOutput( - success=False, - message="Profiling is not in progress. 
Call /start_profile first.", - ) - - if not Path(self.torch_profiler_output_dir).exists(): - Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True) - - stage_suffix = f"-{stage.__str__()}" if stage else "" - logger.info("Stop profiling" + stage_suffix + "...") - if self.torch_profiler is not None: - self.torch_profiler.stop() - self.torch_profiler.export_chrome_trace( - os.path.join( - self.torch_profiler_output_dir, - self.profile_id - + f"-TP-{self.tp_rank}" - + stage_suffix - + ".trace.json.gz", - ) - ) - torch.distributed.barrier(self.tp_cpu_group) - - if self.rpd_profiler is not None: - self.rpd_profiler.rangePop() - self.rpd_profiler.stop() - self.rpd_profiler.flush() - - torch.distributed.barrier(self.tp_cpu_group) - if self.tp_rank == 0: - from sglang.srt.utils import rpd_to_chrome_trace - - rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path) - self.rpd_profiler = None - self.rpd_profiler_path = None - - if self.profiler_activities is not None and "MEM" in self.profiler_activities: - memory_profile_path = os.path.join( - self.torch_profiler_output_dir, - str(time.time()) - + f"-TP-{self.tp_rank}-memory" - + stage_suffix - + ".pickle", - ) - torch.cuda.memory._dump_snapshot(memory_profile_path) - torch.cuda.memory._record_memory_history(enabled=None) - - if "CUDA_PROFILER" in self.profiler_activities: - torch.cuda.cudart().cudaProfilerStop() - - logger.info( - "Profiling done. Traces are saved to: %s", - self.torch_profiler_output_dir, - ) - self.torch_profiler = None - self.profile_in_progress = False - self.profiler_start_forward_ct = None - - return ProfileReqOutput(success=True, message="Succeeded.") - - def _profile_batch_predicate(self, batch): - if self.profile_by_stage: - if batch.forward_mode.is_prefill(): - if self.profiler_prefill_ct == 0: - self.start_profile(batch.forward_mode) - self.profiler_prefill_ct += 1 - if self.profiler_prefill_ct > self.profiler_target_prefill_ct: - if self.profile_in_progress: - self.stop_profile(stage=ForwardMode.EXTEND) - elif batch.forward_mode.is_decode(): - if self.profiler_decode_ct == 0: - if self.profile_in_progress: - # force trace flush - self.stop_profile(ForwardMode.EXTEND) - self.start_profile(batch.forward_mode) - self.profiler_decode_ct += 1 - if self.profiler_decode_ct > self.profiler_target_decode_ct: - if self.profile_in_progress: - self.stop_profile(stage=ForwardMode.DECODE) - elif batch.forward_mode.is_idle(): - pass - else: - raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}") - else: - # Check profiler - if ( - self.profiler_target_forward_ct - and self.profiler_target_forward_ct <= self.forward_ct - ): - self.stop_profile() - if ( - self.profiler_start_forward_ct - and self.profiler_start_forward_ct == self.forward_ct - ): - self.start_profile() - def expert_distribution_handle(self, recv_req: ExpertDistributionReq): if recv_req == ExpertDistributionReq.START_RECORD: get_global_expert_distribution_recorder().start_record() @@ -2879,7 +2275,7 @@ class Scheduler( elif recv_req == ExpertDistributionReq.DUMP_RECORD: get_global_expert_distribution_recorder().dump_record() else: - raise ValueError("Unrecognized ExpertDistributionReq value") + raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}") return ExpertDistributionReqOutput() def open_session(self, recv_req: OpenSessionReqInput): @@ -2915,12 +2311,33 @@ class Scheduler( prefix += f" PP{self.pp_rank}" return prefix - def _publish_kv_events(self): - if self.enable_kv_cache_events: - events = 
self.tree_cache.take_events() - if events: - batch = KVEventBatch(ts=time.time(), events=events) - self.kv_event_publisher.publish(batch) + def current_scheduler_metrics_enabled(self): + return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers + + def maybe_sleep_on_idle(self): + if self.idle_sleeper is not None: + self.idle_sleeper.maybe_sleep() + + +class IdleSleeper: + """ + In setups which have long inactivity periods it is desirable to reduce + system power consumption when sglang does nothing. This would lead not only + to power savings, but also to more CPU thermal headroom when a request + eventually comes. This is important in cases when multiple GPUs are connected + as each GPU would otherwise pin one thread at 100% CPU usage. + + The simplest solution is to use zmq.Poller on all sockets that may receive + data that needs handling immediately. + """ + + def __init__(self, sockets): + self.poller = zmq.Poller() + for s in sockets: + self.poller.register(s, zmq.POLLIN) + + def maybe_sleep(self): + self.poller.poll(1000) def is_health_check_generate_req(recv_req): @@ -2931,20 +2348,6 @@ def is_work_request(recv_req): return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)) -def _export_static_state(model): - return dict( - buffers=[ - (name, buffer.detach().clone()) for name, buffer in model.named_buffers() - ] - ) - - -def _import_static_state(model, static_params): - self_named_buffers = dict(model.named_buffers()) - for name, tensor in static_params["buffers"]: - self_named_buffers[name][...] = tensor - - def run_scheduler_process( server_args: ServerArgs, port_args: PortArgs, diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py new file mode 100644 index 000000000..a6497ffde --- /dev/null +++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py @@ -0,0 +1,229 @@ +import logging +import time +from collections import defaultdict +from typing import List, Optional + +from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch +from sglang.srt.disaggregation.utils import DisaggregationMode +from sglang.srt.managers.schedule_policy import PrefillAdder +from sglang.srt.managers.scheduler import Req, ScheduleBatch +from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats +from sglang.srt.utils import get_bool_env_var + +logger = logging.getLogger(__name__) + +RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME") + + +class KvMetrics: + def __init__(self): + self.request_active_slots = None + self.request_total_slots = None + self.kv_active_blocks = None + self.kv_total_blocks = None + self.num_requests_waiting = None + self.gpu_cache_usage_perc = None + self.gpu_prefix_cache_hit_rate = None + self.data_parallel_rank = None + + +class SchedulerMetricsMixin: + def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]): + self.last_gen_throughput: float = 0.0 + self.last_input_throughput: float = 0.0 + self.step_time_dict = defaultdict(list) # Dict[batch size -> step time] + self.spec_num_total_accepted_tokens = 0 + self.spec_num_total_forward_ct = 0 + self.cum_spec_accept_length = 0 + self.cum_spec_accept_count = 0 + self.total_retracted_reqs = 0 + self.stats = SchedulerStats() + if self.enable_metrics: + engine_type = "unified" + labels = { + "model_name": self.server_args.served_model_name, + "engine_type": engine_type, + "tp_rank": tp_rank, + "pp_rank": pp_rank, + } + if dp_rank is not None: + 
labels["dp_rank"] = dp_rank + self.metrics_collector = SchedulerMetricsCollector(labels=labels) + + def init_kv_events(self, kv_events_config: Optional[str]): + if self.enable_kv_cache_events: + self.kv_event_publisher = EventPublisherFactory.create( + kv_events_config, self.attn_dp_rank + ) + + def log_prefill_stats( + self, + adder: PrefillAdder, + can_run_list: List[Req], + running_bs: int, + ): + gap_latency = time.perf_counter() - self.last_prefill_stats_tic + self.last_prefill_stats_tic = time.perf_counter() + self.last_input_throughput = self.last_prefill_tokens / gap_latency + self.last_prefill_tokens = adder.log_input_tokens + + if self.is_hybrid: + ( + full_num_used, + swa_num_used, + full_token_usage, + swa_token_usage, + _, + _, + _, + _, + ) = self._get_swa_token_info() + num_used = max(full_num_used, swa_num_used) + token_usage = max(full_token_usage, swa_token_usage) + token_msg = ( + f"full token usage: {full_token_usage:.2f}, " + f"swa token usage: {swa_token_usage:.2f}, " + ) + else: + num_used, token_usage, _, _ = self._get_token_info() + token_msg = f"token usage: {token_usage:.2f}, " + + num_new_seq = len(can_run_list) + f = ( + f"Prefill batch. " + f"#new-seq: {num_new_seq}, " + f"#new-token: {adder.log_input_tokens}, " + f"#cached-token: {adder.log_hit_tokens}, " + f"{token_msg}" + ) + + if self.disaggregation_mode == DisaggregationMode.PREFILL: + f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, " + f += f"#queue-req: {len(self.waiting_queue)}, " + f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, " + f += f"input throughput (token/s): {self.last_input_throughput:.2f}, " + else: + f += f"#running-req: {running_bs}, " + f += f"#queue-req: {len(self.waiting_queue)}, " + + logger.info(f) + + if self.enable_metrics: + total_tokens = adder.log_input_tokens + adder.log_hit_tokens + + cache_hit_rate = ( + adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0 + ) + self.stats.num_running_reqs = running_bs + self.stats.num_used_tokens = num_used + self.stats.token_usage = round(token_usage, 2) + self.stats.num_queue_reqs = len(self.waiting_queue) + self.stats.cache_hit_rate = cache_hit_rate + + total_queue_latency = 0 + for req in can_run_list: + total_queue_latency += req.queue_time_end - req.queue_time_start + self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq + + self.metrics_collector.log_stats(self.stats) + self._emit_kv_metrics() + self._publish_kv_events() + + def log_decode_stats( + self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None + ): + batch = running_batch or self.running_batch + + gap_latency = time.perf_counter() - self.last_decode_stats_tic + self.last_decode_stats_tic = time.perf_counter() + self.last_gen_throughput = self.num_generated_tokens / gap_latency + self.num_generated_tokens = 0 + num_running_reqs = len(batch.reqs) + if self.is_hybrid: + ( + full_num_used, + swa_num_used, + full_token_usage, + swa_token_usage, + _, + _, + _, + _, + ) = self._get_swa_token_info() + num_used = max(full_num_used, swa_num_used) + token_usage = max(full_token_usage, swa_token_usage) + token_msg = ( + f"#full token: {full_num_used}, " + f"full token usage: {full_token_usage:.2f}, " + f"#swa token: {swa_num_used}, " + f"swa token usage: {swa_token_usage:.2f}, " + ) + else: + num_used, token_usage, _, _ = self._get_token_info() + token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, " + + if RECORD_STEP_TIME: + self.step_time_dict[num_running_reqs].append( + 
gap_latency / self.server_args.decode_log_interval + ) + + msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}" + + if self.spec_algorithm.is_none(): + spec_accept_length = 0 + else: + spec_accept_length = ( + self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct + ) + self.cum_spec_accept_length += self.spec_num_total_accepted_tokens + self.cum_spec_accept_count += self.spec_num_total_forward_ct + self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0 + msg += f"accept len: {spec_accept_length:.2f}, " + + if self.disaggregation_mode == DisaggregationMode.DECODE: + msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, " + msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, " + + msg += ( + f"cuda graph: {can_run_cuda_graph}, " + f"gen throughput (token/s): {self.last_gen_throughput:.2f}, " + f"#queue-req: {len(self.waiting_queue)}, " + ) + + logger.info(msg) + if self.enable_metrics: + self.stats.num_running_reqs = num_running_reqs + self.stats.num_used_tokens = num_used + self.stats.token_usage = round(token_usage, 2) + self.stats.cache_hit_rate = 0.0 + self.stats.gen_throughput = self.last_gen_throughput + self.stats.num_queue_reqs = len(self.waiting_queue) + self.stats.num_grammar_queue_reqs = len(self.grammar_queue) + self.stats.spec_accept_length = spec_accept_length + self.stats.total_retracted_reqs = self.total_retracted_reqs + self.metrics_collector.log_stats(self.stats) + self._emit_kv_metrics() + self._publish_kv_events() + + def _emit_kv_metrics(self): + kv_metrics = KvMetrics() + kv_metrics.request_active_slots = self.stats.num_running_reqs + kv_metrics.request_total_slots = self.max_running_requests + kv_metrics.kv_active_blocks = int( + self.stats.token_usage * self.max_total_num_tokens + ) + kv_metrics.kv_total_blocks = self.max_total_num_tokens + kv_metrics.num_requests_waiting = self.stats.num_queue_reqs + kv_metrics.gpu_cache_usage_perc = self.stats.token_usage + kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate + kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0 + + if not self.send_metrics_from_scheduler.closed: + self.send_metrics_from_scheduler.send_pyobj(kv_metrics) + + def _publish_kv_events(self): + if self.enable_kv_cache_events: + events = self.tree_cache.take_events() + if events: + batch = KVEventBatch(ts=time.time(), events=events) + self.kv_event_publisher.publish(batch) diff --git a/python/sglang/srt/managers/scheduler_profiler_mixin.py b/python/sglang/srt/managers/scheduler_profiler_mixin.py new file mode 100644 index 000000000..3d061a8fe --- /dev/null +++ b/python/sglang/srt/managers/scheduler_profiler_mixin.py @@ -0,0 +1,279 @@ +import logging +import os +import time +from pathlib import Path +from typing import List, Optional + +import torch + +from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType +from sglang.srt.model_executor.forward_batch_info import ForwardMode + +logger = logging.getLogger(__name__) + + +class SchedulerProfilerMixin: + + def init_profier(self): + self.torch_profiler = None + self.torch_profiler_output_dir: Optional[str] = None + self.profiler_activities: Optional[List[str]] = None + self.profile_id: Optional[str] = None + self.profiler_start_forward_ct: Optional[int] = None + self.profiler_target_forward_ct: Optional[int] = None + self.profiler_target_prefill_ct: Optional[int] = None + 
self.profiler_target_decode_ct: Optional[int] = None + self.profiler_prefill_ct: Optional[int] = None + self.profiler_decode_ct: Optional[int] = None + self.profile_by_stage: bool = False + self.profile_steps: Optional[int] = None + self.profile_in_progress: bool = False + self.rpd_profiler = None + + def init_profile( + self, + output_dir: Optional[str], + start_step: Optional[int], + num_steps: Optional[int], + activities: Optional[List[str]], + with_stack: Optional[bool], + record_shapes: Optional[bool], + profile_by_stage: bool, + profile_id: str, + ) -> ProfileReqOutput: + if self.profile_in_progress: + return ProfileReqOutput( + success=False, + message="Profiling is already in progress. Call /stop_profile first.", + ) + + self.profile_by_stage = profile_by_stage + + if output_dir is None: + output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp") + if activities is None: + activities = ["CPU", "GPU"] + + self.torch_profiler_output_dir = output_dir + self.torch_profiler_with_stack = with_stack + self.torch_profiler_record_shapes = record_shapes + self.profiler_activities = activities + self.profile_id = profile_id + + if start_step: + self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1) + + if num_steps: + self.profile_steps = num_steps + if self.profile_by_stage: + self.profiler_target_prefill_ct = num_steps + self.profiler_target_decode_ct = num_steps + self.profiler_prefill_ct = 0 + self.profiler_decode_ct = 0 + elif start_step: + self.profiler_target_forward_ct = ( + self.profiler_start_forward_ct + num_steps + ) + else: + self.profiler_target_forward_ct = self.forward_ct + num_steps + # The caller will be notified when reaching profiler_target_forward_ct + else: + self.profiler_target_forward_ct = None + + return ProfileReqOutput(success=True, message="Succeeded") + + def start_profile( + self, stage: Optional[ForwardMode] = None + ) -> ProfileReqOutput | None: + stage_str = f" for {stage.__str__()}" if stage else "" + logger.info( + f"Profiling starts{stage_str}. 
Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})", + ) + + activities = self.profiler_activities + with_stack = self.torch_profiler_with_stack + record_shapes = self.torch_profiler_record_shapes + + activity_map = { + "CPU": torch.profiler.ProfilerActivity.CPU, + "GPU": torch.profiler.ProfilerActivity.CUDA, + } + torchprof_activities = [ + activity_map[a] for a in activities if a in activity_map + ] + + if "RPD" in activities: + from rpdTracerControl import rpdTracerControl + + rpdTracerControl.skipCreate() + + self.rpd_profile_path = os.path.join( + self.torch_profiler_output_dir, + "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz", + ) + + if self.tp_rank == 0: + import sqlite3 + + from rocpd.schema import RocpdSchema + + if os.path.exists("trace.rpd"): + os.unlink("trace.rpd") + schema = RocpdSchema() + connection = sqlite3.connect("trace.rpd") + schema.writeSchema(connection) + connection.commit() + del connection + torch.distributed.barrier(self.tp_cpu_group) + + self.rpd_profiler = rpdTracerControl() + self.rpd_profiler.setPythonTrace(True) + self.rpd_profiler.start() + self.rpd_profiler.rangePush("", "rpd profile range", "") + self.profile_in_progress = True + elif torchprof_activities: + self.torch_profiler = torch.profiler.profile( + activities=torchprof_activities, + with_stack=with_stack if with_stack is not None else True, + record_shapes=record_shapes if record_shapes is not None else False, + ) + self.torch_profiler.start() + self.profile_in_progress = True + + if "MEM" in activities: + torch.cuda.memory._record_memory_history(max_entries=100000) + self.profile_in_progress = True + + if "CUDA_PROFILER" in activities: + torch.cuda.cudart().cudaProfilerStart() + self.profile_in_progress = True + + return ProfileReqOutput(success=True, message="Succeeded") + + def stop_profile( + self, stage: Optional[ForwardMode] = None + ) -> ProfileReqOutput | None: + if not self.profile_in_progress: + return ProfileReqOutput( + success=False, + message="Profiling is not in progress. Call /start_profile first.", + ) + + if not Path(self.torch_profiler_output_dir).exists(): + Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True) + + stage_suffix = f"-{stage.__str__()}" if stage else "" + logger.info("Stop profiling" + stage_suffix + "...") + if self.torch_profiler is not None: + self.torch_profiler.stop() + self.torch_profiler.export_chrome_trace( + os.path.join( + self.torch_profiler_output_dir, + self.profile_id + + f"-TP-{self.tp_rank}" + + stage_suffix + + ".trace.json.gz", + ) + ) + torch.distributed.barrier(self.tp_cpu_group) + + if self.rpd_profiler is not None: + self.rpd_profiler.rangePop() + self.rpd_profiler.stop() + self.rpd_profiler.flush() + + torch.distributed.barrier(self.tp_cpu_group) + if self.tp_rank == 0: + from sglang.srt.utils import rpd_to_chrome_trace + + rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path) + self.rpd_profiler = None + self.rpd_profiler_path = None + + if self.profiler_activities is not None and "MEM" in self.profiler_activities: + memory_profile_path = os.path.join( + self.torch_profiler_output_dir, + str(time.time()) + + f"-TP-{self.tp_rank}-memory" + + stage_suffix + + ".pickle", + ) + torch.cuda.memory._dump_snapshot(memory_profile_path) + torch.cuda.memory._record_memory_history(enabled=None) + + if "CUDA_PROFILER" in self.profiler_activities: + torch.cuda.cudart().cudaProfilerStop() + + logger.info( + "Profiling done. 
Traces are saved to: %s", + self.torch_profiler_output_dir, + ) + self.torch_profiler = None + self.profile_in_progress = False + self.profiler_start_forward_ct = None + + return ProfileReqOutput(success=True, message="Succeeded.") + + def _profile_batch_predicate(self, batch): + if self.profile_by_stage: + if batch.forward_mode.is_prefill(): + if self.profiler_prefill_ct == 0: + self.start_profile(batch.forward_mode) + self.profiler_prefill_ct += 1 + if self.profiler_prefill_ct > self.profiler_target_prefill_ct: + if self.profile_in_progress: + self.stop_profile(stage=ForwardMode.EXTEND) + elif batch.forward_mode.is_decode(): + if self.profiler_decode_ct == 0: + if self.profile_in_progress: + # force trace flush + self.stop_profile(ForwardMode.EXTEND) + self.start_profile(batch.forward_mode) + self.profiler_decode_ct += 1 + if self.profiler_decode_ct > self.profiler_target_decode_ct: + if self.profile_in_progress: + self.stop_profile(stage=ForwardMode.DECODE) + elif batch.forward_mode.is_idle(): + pass + else: + raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}") + else: + # Check profiler + if ( + self.profiler_target_forward_ct + and self.profiler_target_forward_ct <= self.forward_ct + ): + self.stop_profile() + if ( + self.profiler_start_forward_ct + and self.profiler_start_forward_ct == self.forward_ct + ): + self.start_profile() + + def profile(self, recv_req: ProfileReq): + if recv_req.type == ProfileReqType.START_PROFILE: + if recv_req.profile_by_stage or recv_req.start_step: + return self.init_profile( + recv_req.output_dir, + recv_req.start_step, + recv_req.num_steps, + recv_req.activities, + recv_req.with_stack, + recv_req.record_shapes, + recv_req.profile_by_stage, + recv_req.profile_id, + ) + else: + self.init_profile( + recv_req.output_dir, + recv_req.start_step, + recv_req.num_steps, + recv_req.activities, + recv_req.with_stack, + recv_req.record_shapes, + recv_req.profile_by_stage, + recv_req.profile_id, + ) + return self.start_profile(True) + else: + return self.stop_profile() diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py new file mode 100644 index 000000000..eba92a2e0 --- /dev/null +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -0,0 +1,142 @@ +import logging +from typing import Tuple + +import torch + +from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS +from sglang.srt.managers.io_struct import ( + GetWeightsByNameReqInput, + GetWeightsByNameReqOutput, + InitWeightsUpdateGroupReqInput, + InitWeightsUpdateGroupReqOutput, + ReleaseMemoryOccupationReqInput, + ReleaseMemoryOccupationReqOutput, + ResumeMemoryOccupationReqInput, + ResumeMemoryOccupationReqOutput, + UpdateWeightFromDiskReqInput, + UpdateWeightFromDiskReqOutput, + UpdateWeightsFromDistributedReqInput, + UpdateWeightsFromDistributedReqOutput, + UpdateWeightsFromTensorReqInput, + UpdateWeightsFromTensorReqOutput, +) + +logger = logging.getLogger(__name__) + + +class SchedulerUpdateWeightsMixin: + + def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): + """In-place update of the weights from disk.""" + success, message = self.tp_worker.update_weights_from_disk(recv_req) + if success: + flush_cache_success = self.flush_cache() + assert flush_cache_success, "Cache flush failed after updating weights" + else: + logger.error(message) + return UpdateWeightFromDiskReqOutput(success, message, 0) + + def init_weights_update_group(self, 
recv_req: InitWeightsUpdateGroupReqInput): + """Initialize the online model parameter update group.""" + success, message = self.tp_worker.init_weights_update_group(recv_req) + return InitWeightsUpdateGroupReqOutput(success, message) + + def update_weights_from_distributed( + self, + recv_req: UpdateWeightsFromDistributedReqInput, + ) -> Tuple[bool, str]: + """Update the online model parameter.""" + success, message = self.tp_worker.update_weights_from_distributed(recv_req) + if success: + if recv_req.flush_cache: + flush_cache_success = self.flush_cache() + assert flush_cache_success, "Cache flush failed after updating weights" + else: + logger.error(message) + return UpdateWeightsFromDistributedReqOutput(success, message) + + def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): + """Update the online model parameter from tensors.""" + success, message = self.tp_worker.update_weights_from_tensor(recv_req) + # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later + if success: + if recv_req.flush_cache: + flush_cache_success = self.flush_cache() + assert flush_cache_success, "Cache flush failed after updating weights" + else: + logger.error(message) + torch.distributed.barrier(group=self.tp_cpu_group) + return UpdateWeightsFromTensorReqOutput(success, message) + + def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): + parameter = self.tp_worker.get_weights_by_name(recv_req) + return GetWeightsByNameReqOutput(parameter) + + def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput): + tags = recv_req.tags + + if tags is None or len(tags) == 0: + tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE] + + if GPU_MEMORY_TYPE_KV_CACHE in tags: + self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE) + self.flush_cache() + + if GPU_MEMORY_TYPE_WEIGHTS in tags: + self.stashed_model_static_state = _export_static_state( + self.tp_worker.worker.model_runner.model + ) + torch.distributed.barrier(self.tp_cpu_group) + self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS) + + return ReleaseMemoryOccupationReqOutput() + + def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput): + tags = recv_req.tags + + if tags is None or len(tags) == 0: + tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE] + + if GPU_MEMORY_TYPE_WEIGHTS in tags: + self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS) + torch.distributed.barrier(self.tp_cpu_group) + _import_static_state( + self.tp_worker.worker.model_runner.model, + self.stashed_model_static_state, + ) + del self.stashed_model_static_state + + if GPU_MEMORY_TYPE_KV_CACHE in tags: + self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE) + + return ResumeMemoryOccupationReqOutput() + + def save_remote_model(self, params): + url = params["url"] + + worker = self.tp_worker.worker + + worker.model_runner.save_remote_model(url) + + def save_sharded_model(self, params): + worker = self.tp_worker.worker + + worker.model_runner.save_sharded_model( + path=params["path"], + pattern=params["pattern"], + max_size=params["max_size"], + ) + + +def _export_static_state(model): + return dict( + buffers=[ + (name, buffer.detach().clone()) for name, buffer in model.named_buffers() + ] + ) + + +def _import_static_state(model, static_params): + self_named_buffers = dict(model.named_buffers()) + for name, tensor in static_params["buffers"]: + self_named_buffers[name][...] 
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 700e62ed4..9250c6866 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -170,16 +170,6 @@ class ReqState:
 output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
 
 
-def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
- is_cross_node = server_args.dist_init_addr
-
- if is_cross_node:
- # Fallback to default CPU transport for multi-node
- return "default"
- else:
- return "cuda_ipc"
-
-
 class TokenizerManager:
 """TokenizerManager is a process that tokenizes the text."""
 
@@ -199,16 +189,6 @@ class TokenizerManager:
 else None
 )
 self.crash_dump_folder = server_args.crash_dump_folder
- self.crash_dump_performed = False # Flag to ensure dump is only called once
-
- # Init inter-process communication
- context = zmq.asyncio.Context(2)
- self.recv_from_detokenizer = get_zmq_socket(
- context, zmq.PULL, port_args.tokenizer_ipc_name, True
- )
- self.send_to_scheduler = get_zmq_socket(
- context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
- )
 
 # Read model args
 self.model_path = server_args.model_path
@@ -218,8 +198,7 @@ class TokenizerManager:
 self.is_image_gen = self.model_config.is_image_gen
 self.context_len = self.model_config.context_len
 self.image_token_id = self.model_config.image_token_id
- self._updating = False
- self._cond = asyncio.Condition()
+ self.max_req_input_len = None # Will be set later in engine.py
 
 if self.model_config.is_multimodal:
 import_processors()
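The `__init__` reorganization in the next hunk re-adds the ZMQ sockets after the model args are read and groups the remaining state under section comments. For context, a minimal sketch of the asyncio PUSH/PULL pairing the tokenizer manager relies on; the IPC endpoint here is a placeholder, not the one SGLang generates:

```python
import asyncio
import zmq
import zmq.asyncio


async def main():
    ctx = zmq.asyncio.Context()

    # Scheduler side pulls work items; tokenizer side pushes them.
    pull = ctx.socket(zmq.PULL)
    pull.bind("ipc:///tmp/example_scheduler_input")  # placeholder endpoint

    push = ctx.socket(zmq.PUSH)
    push.connect("ipc:///tmp/example_scheduler_input")

    await push.send_pyobj({"rid": "req-0", "text": "hello"})
    obj = await pull.recv_pyobj()
    print(obj)  # {'rid': 'req-0', 'text': 'hello'}

    push.close()
    pull.close()
    ctx.term()


asyncio.run(main())
```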
@@ -258,39 +237,57 @@ class TokenizerManager:
 revision=server_args.revision,
 )
 
- # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
- # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
- # serves as the source of truth for available adapters and maps user-friendly LoRA names
- # to internally used unique LoRA IDs.
- self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
+ # Init inter-process communication
+ context = zmq.asyncio.Context(2)
+ self.recv_from_detokenizer = get_zmq_socket(
+ context, zmq.PULL, port_args.tokenizer_ipc_name, True
+ )
+ self.send_to_scheduler = get_zmq_socket(
+ context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+ )
 
- # Store states
+ # Request states
 self.no_create_loop = False
 self.rid_to_state: Dict[str, ReqState] = {}
+ self.asyncio_tasks = set()
+
+ # Health check
 self.health_check_failed = False
 self.gracefully_exit = False
 self.last_receive_tstamp = 0
+
+ # Dumping
 self.dump_requests_folder = "" # By default do not dump
 self.dump_requests_threshold = 1000
 self.dump_request_list: List[Tuple] = []
- self.crash_dump_request_list: deque[Tuple] = deque()
 self.log_request_metadata = self.get_log_request_metadata()
- self.session_futures = {} # session_id -> asyncio event
- self.max_req_input_len = None
- self.asyncio_tasks = set()
+ self.crash_dump_request_list: deque[Tuple] = deque()
+ self.crash_dump_performed = False # Flag to ensure dump is only called once
 
+ # Session
+ self.session_futures = {} # session_id -> asyncio event
+
+ # Weight updates
 # The event to notify the weight sync is finished.
 self.model_update_lock = RWLock()
 self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
 None
 )
+ self._is_updating = False
+ self._is_updating_cond = asyncio.Condition()
 
+ # LoRA
+ # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
+ # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
+ # serves as the source of truth for available adapters and maps user-friendly LoRA names
+ # to internally used unique LoRA IDs.
+ self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
 # Lock to serialize LoRA update operations.
 # Please note that, unlike `model_update_lock`, this does not block inference, allowing
 # LoRA updates and inference to overlap.
 self.lora_update_lock = asyncio.Lock()
 
- # For pd disaggregtion
+ # For PD disaggregation
 self.disaggregation_mode = DisaggregationMode(
 self.server_args.disaggregation_mode
 )
@@ -458,17 +455,11 @@ class TokenizerManager:
 request: Optional[fastapi.Request] = None,
 ):
 created_time = time.time()
- async with self._cond:
- await self._cond.wait_for(lambda: not self._updating)
-
 self.auto_create_handle_loop()
 obj.normalize_batch_and_arguments()
 
- if isinstance(obj, EmbeddingReqInput) and self.is_generation:
- raise ValueError(
- "This model does not appear to be an embedding model by default. "
- "Please add `--is-embedding` when launching the server or try another model."
- )
+ async with self._is_updating_cond:
+ await self._is_updating_cond.wait_for(lambda: not self._is_updating)
 
 if self.log_requests:
 max_length, skip_names, _ = self.log_request_metadata
@@ -567,6 +558,12 @@ class TokenizerManager:
 f"model's context length ({self.context_len} tokens)."
 )
 
+ if isinstance(obj, EmbeddingReqInput) and self.is_generation:
+ raise ValueError(
+ "This model does not appear to be an embedding model by default. "
+ "Please add `--is-embedding` when launching the server or try another model."
+ )
+
 # Check total tokens (input + max_new_tokens)
 max_new_tokens = obj.sampling_params.get("max_new_tokens")
 if (
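The `_is_updating_cond` gate added above pairs with `pause_generation` / `continue_generation` in the next hunk: incoming requests wait on the condition until the update finishes, and `notify_all` releases all of them at once. A standalone sketch of that gating pattern (names are illustrative):

```python
import asyncio


class Gate:
    def __init__(self):
        self._is_updating = False
        self._cond = asyncio.Condition()

    async def handle_request(self, rid: str):
        # New requests park here while an update is in flight.
        async with self._cond:
            await self._cond.wait_for(lambda: not self._is_updating)
        print(f"{rid}: running")

    async def pause(self):
        async with self._cond:
            self._is_updating = True

    async def resume(self):
        async with self._cond:
            self._is_updating = False
            self._cond.notify_all()  # wake every waiting request


async def main():
    gate = Gate()
    await gate.pause()
    task = asyncio.create_task(gate.handle_request("req-0"))
    await asyncio.sleep(0.1)   # req-0 is blocked on the condition
    await gate.resume()        # req-0 proceeds
    await task


asyncio.run(main())
```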
@@ -959,14 +956,14 @@ class TokenizerManager:
 await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
 
 async def pause_generation(self):
- async with self._cond:
- self._updating = True
+ async with self._is_updating_cond:
+ self._is_updating = True
 self.abort_request(abort_all=True)
 
 async def continue_generation(self):
- async with self._cond:
- self._updating = False
- self._cond.notify_all()
+ async with self._is_updating_cond:
+ self._is_updating = False
+ self._is_updating_cond.notify_all()
 
 async def update_weights_from_disk(
 self,
@@ -1208,14 +1205,6 @@ class TokenizerManager:
 # Many DP ranks
 return [res.internal_state for res in responses]
 
- async def get_load(self) -> dict:
- # TODO(lsyin): fake load report server
- if not self.current_load_lock.locked():
- async with self.current_load_lock:
- internal_state = await self.get_internal_state()
- self.current_load = internal_state[0]["load"]
- return {"load": self.current_load}
-
 async def set_internal_state(
 self, obj: SetInternalStateReq
 ) -> SetInternalStateReqOutput:
@@ -1224,6 +1213,14 @@ class TokenizerManager:
 )
 return [res.internal_state for res in responses]
 
+ async def get_load(self) -> dict:
+ # TODO(lsyin): fake load report server
+ if not self.current_load_lock.locked():
+ async with self.current_load_lock:
+ internal_state = await self.get_internal_state()
+ self.current_load = internal_state[0]["load"]
+ return {"load": self.current_load}
+
 def get_log_request_metadata(self):
 max_length = None
 skip_names = None
@@ -1343,11 +1340,24 @@ class TokenizerManager:
 "SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
 )
 return
- logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
- self.crash_dump_performed = True
+
 if not self.crash_dump_folder:
 return
+
+ logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
+ self.crash_dump_performed = True
+
+ # Check if NFS directory is available
+ # expected_nfs_dir = "/" + self.crash_dump_folder.lstrip("/").split("/")[0]
+ # use_nfs_dir = os.path.isdir(expected_nfs_dir) and os.access(
+ # expected_nfs_dir, os.W_OK
+ # )
+ use_nfs_dir = False
+ if not use_nfs_dir:
+ logger.error(
+ "Expected NFS directory is not available or writable. Uploading to GCS."
+ )
+
 data_to_dump = []
 if self.crash_dump_request_list:
 data_to_dump.extend(self.crash_dump_request_list)
@@ -1357,7 +1367,12 @@
 for rid, state in self.rid_to_state.items():
 if not state.finished:
 unfinished_requests.append(
- (state.obj, {}, state.created_time, time.time())
+ (
+ state.obj,
+ state.out_list[-1] if state.out_list else {},
+ state.created_time,
+ time.time(),
+ )
 )
 if unfinished_requests:
 data_to_dump.extend(unfinished_requests)
@@ -1365,10 +1380,11 @@
 if not data_to_dump:
 return
 
+ object_name = f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl'
 filename = os.path.join(
 self.crash_dump_folder,
 os.getenv("HOSTNAME", None),
- f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
+ object_name,
 )
 os.makedirs(os.path.dirname(filename), exist_ok=True)
@@ -1383,6 +1399,24 @@
 f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
 )
 
+ def _upload_file_to_gcs(bucket_name, source_file_path, object_name):
+ from google.cloud import storage
+
+ client = storage.Client()
+ bucket = client.bucket(bucket_name)
+ blob = bucket.blob(object_name)
+ blob.upload_from_filename(source_file_path, if_generation_match=0)
+ logger.error(
+ f"Successfully uploaded {source_file_path} to gs://{bucket_name}/{object_name}"
+ )
+
+ if not use_nfs_dir:
+ _upload_file_to_gcs(
+ "sglang_crash_dump",
+ filename,
+ os.getenv("HOSTNAME", "unknown") + "/" + object_name,
+ )
+
 async def sigterm_watchdog(self):
 while not self.gracefully_exit:
 await asyncio.sleep(5)
@@ -1426,7 +1460,7 @@
 while True:
 recv_obj = await self.recv_from_detokenizer.recv_pyobj()
 self._result_dispatcher(recv_obj)
- self.last_receive_tstamp = time.perf_counter()
+ self.last_receive_tstamp = time.time()
 
 def _handle_batch_output(
 self,
@@ -1697,24 +1731,13 @@
 self.dump_requests_folder,
 datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
 )
- logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
-
- to_dump = self.dump_request_list
+ self._dump_data_to_file(
+ data_list=self.dump_request_list,
+ filename=filename,
+ log_message=f"Dump {len(self.dump_request_list)} requests to {filename}",
+ )
 self.dump_request_list = []
 
- to_dump_with_server_args = {
- "server_args": self.server_args,
- "requests": to_dump,
- }
-
- def background_task():
- os.makedirs(self.dump_requests_folder, exist_ok=True)
- with open(filename, "wb") as f:
- pickle.dump(to_dump_with_server_args, f)
-
- # Schedule the task to run in the background without awaiting it
- asyncio.create_task(asyncio.to_thread(background_task))
-
 def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
 current_time = time.time()
 self.crash_dump_request_list.append(
@@ -1727,6 +1750,22 @@
 ):
 self.crash_dump_request_list.popleft()
 
+ def _dump_data_to_file(
+ self, data_list: List[Tuple], filename: str, log_message: str
+ ):
+ logger.info(log_message)
+ to_dump_with_server_args = {
+ "server_args": self.server_args,
+ "requests": data_list.copy(),
+ }
+
+ def background_task():
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ with open(filename, "wb") as f:
+ pickle.dump(to_dump_with_server_args, f)
+
+ asyncio.create_task(asyncio.to_thread(background_task))
+
 def _handle_abort_req(self, recv_obj):
 state = self.rid_to_state[recv_obj.rid]
 state.finished = True
@@ -1862,6 +1901,16 @@
 return scores
+
+def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
+ is_cross_node = server_args.dist_init_addr
+
+ if is_cross_node:
+ # Fallback to default CPU transport for multi-node
+ return "default"
+ else:
+ return "cuda_ipc"
+
+
 async def print_exception_wrapper(func):
 """
 Sometimes an asyncio function does not print exception.
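`_upload_file_to_gcs` above uses the standard `google-cloud-storage` client; passing `if_generation_match=0` makes the upload succeed only when the object does not already exist, so a retried crash dump cannot overwrite an earlier one (it fails with `PreconditionFailed` instead). A self-contained sketch, assuming the `google-cloud-storage` package and ambient credentials; the bucket and paths are placeholders:

```python
from google.cloud import storage


def upload_once(bucket_name: str, source_file_path: str, object_name: str) -> None:
    client = storage.Client()  # picks up GOOGLE_APPLICATION_CREDENTIALS / ADC
    blob = client.bucket(bucket_name).blob(object_name)
    # if_generation_match=0 asserts "object must not exist yet", so a second
    # upload with the same name fails instead of silently overwriting.
    blob.upload_from_filename(source_file_path, if_generation_match=0)
    print(f"uploaded {source_file_path} to gs://{bucket_name}/{object_name}")


# Placeholder names for illustration only:
# upload_once("my-crash-dump-bucket", "/tmp/crash_dump.pkl", "host-0/crash_dump.pkl")
```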
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index dc0c6cd1a..856d68138 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -2071,6 +2071,9 @@ class PortArgs:
 dist_init_host, dist_init_port = dist_init_addr
 port_base = int(dist_init_port) + 1
+ detokenizer_port = port_base + 1
+ rpc_port = port_base + 2
+ metrics_port = port_base + 3
 if dp_rank is None:
 # TokenizerManager to DataParallelController
 scheduler_input_port = port_base + 4
@@ -2080,10 +2083,10 @@
 return PortArgs(
 tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
 scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
- detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
+ detokenizer_ipc_name=f"tcp://{dist_init_host}:{detokenizer_port}",
 nccl_port=nccl_port,
- rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
- metrics_ipc_name=f"tcp://{dist_init_host}:{port_base + 3}",
+ rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
+ metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_port}",
 )
diff --git a/python/sglang/utils.py b/python/sglang/utils.py
index b7600b1a6..0ba6d46c3 100644
--- a/python/sglang/utils.py
+++ b/python/sglang/utils.py
@@ -291,17 +291,6 @@ def find_printable_text(text: str):
 return text[: text.rfind(" ") + 1]
 
 
-def graceful_registry(sub_module_name: str):
- def graceful_shutdown(signum, frame):
- logger.info(
- f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
- )
- if signum == signal.SIGTERM:
- logger.info(f"{sub_module_name} receive sigterm")
-
- signal.signal(signal.SIGTERM, graceful_shutdown)
-
-
class LazyImport:
 """Lazy import to make `import sglang` run faster."""
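After this change, `PortArgs.init_new` derives every cross-node IPC port as a fixed offset of `port_base`. A small sketch of that layout as a standalone helper; the `+ 4` scheduler-input offset assumes the `dp_rank is None` branch shown above:

```python
def derive_ports(dist_init_addr: str) -> dict:
    # "host:port" -> the per-service endpoints used for cross-node IPC.
    host, dist_init_port = dist_init_addr.rsplit(":", 1)
    port_base = int(dist_init_port) + 1
    return {
        "tokenizer_ipc_name": f"tcp://{host}:{port_base}",
        "detokenizer_ipc_name": f"tcp://{host}:{port_base + 1}",
        "rpc_ipc_name": f"tcp://{host}:{port_base + 2}",
        "metrics_ipc_name": f"tcp://{host}:{port_base + 3}",
        # TokenizerManager -> DataParallelController (dp_rank is None case)
        "scheduler_input_ipc_name": f"tcp://{host}:{port_base + 4}",
    }


assert derive_ports("10.0.0.1:5000")["metrics_ipc_name"] == "tcp://10.0.0.1:5004"
```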