Split the scheduler into multiple mixin classes to reduce the file size (#8483)

This commit is contained in:
Lianmin Zheng
2025-07-29 12:46:50 -07:00
committed by GitHub
parent 5973675bc3
commit a4c3b121d8
12 changed files with 869 additions and 785 deletions

View File

@@ -13,7 +13,6 @@
# ==============================================================================
"""A scheduler that manages a tensor parallel GPU worker."""
import datetime
import faulthandler
import logging
import os
@@ -21,11 +20,10 @@ import signal
import sys
import threading
import time
from collections import defaultdict, deque
from collections import deque
from concurrent import futures
from dataclasses import dataclass
from http import HTTPStatus
from pathlib import Path
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple, Union
@@ -37,7 +35,6 @@ from torch.distributed import barrier
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.constrained.base_grammar_backend import (
INVALID_GRAMMAR_OBJ,
create_grammar_backend,
@@ -47,7 +44,6 @@ from sglang.srt.disaggregation.decode import (
DecodeTransferQueue,
SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
from sglang.srt.disaggregation.prefill import (
PrefillBootstrapQueue,
SchedulerDisaggregationPrefillMixin,
@@ -78,21 +74,15 @@ from sglang.srt.managers.io_struct import (
GetInternalStateReq,
GetInternalStateReqOutput,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
ProfileReqOutput,
ProfileReqType,
ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
RpcReqInput,
RpcReqOutput,
SetInternalStateReq,
@@ -104,11 +94,8 @@ from sglang.srt.managers.io_struct import (
UnloadLoRAAdapterReqInput,
UnloadLoRAAdapterReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.mm_utils import init_embedding_cache
from sglang.srt.managers.schedule_batch import (
@@ -124,9 +111,17 @@ from sglang.srt.managers.schedule_policy import (
SchedulePolicy,
)
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
from sglang.srt.managers.scheduler_metrics_mixin import (
RECORD_STEP_TIME,
SchedulerMetricsMixin,
)
from sglang.srt.managers.scheduler_output_processor_mixin import (
SchedulerOutputProcessorMixin,
)
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
from sglang.srt.managers.scheduler_update_weights_mixin import (
SchedulerUpdateWeightsMixin,
)
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
@@ -135,7 +130,6 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs
@@ -168,7 +162,6 @@ logger = logging.getLogger(__name__)
# Test retract decode for debugging purposes
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
_is_cpu = is_cpu()
@@ -191,41 +184,11 @@ class EmbeddingBatchResult:
bid: int
class KvMetrics:
def __init__(self):
self.request_active_slots = None
self.request_total_slots = None
self.kv_active_blocks = None
self.kv_total_blocks = None
self.num_requests_waiting = None
self.gpu_cache_usage_perc = None
self.gpu_prefix_cache_hit_rate = None
self.data_parallel_rank = None
class IdleSleeper:
"""
In setups which have long inactivity periods it is desirable to reduce
system power consumption when sglang does nothing. This would lead not only
to power savings, but also to more CPU thermal headroom when a request
eventually comes. This is important in cases when multiple GPUs are connected
as each GPU would otherwise pin one thread at 100% CPU usage.
The simplest solution is to use zmq.Poller on all sockets that may receive
data that needs handling immediately.
"""
def __init__(self, sockets):
self.poller = zmq.Poller()
for s in sockets:
self.poller.register(s, zmq.POLLIN)
def maybe_sleep(self):
self.poller.poll(1000)
class Scheduler(
SchedulerOutputProcessorMixin,
SchedulerUpdateWeightsMixin,
SchedulerProfilerMixin,
SchedulerMetricsMixin,
SchedulerDisaggregationDecodeMixin,
SchedulerDisaggregationPrefillMixin,
):
@@ -266,7 +229,7 @@ class Scheduler(
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
self.enable_hicache_storage = server_args.hicache_storage_backend is not None
self.page_size = server_args.page_size
self.dp_size = server_args.dp_size
self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
compute_dp_attention_world_info(
server_args.enable_dp_attention,
@@ -284,10 +247,13 @@ class Scheduler(
self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
)
self.recv_from_rpc = get_zmq_socket(
context, zmq.DEALER, port_args.rpc_ipc_name, False
)
self.send_to_tokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
)
if server_args.skip_tokenizer_init:
# Directly send to the TokenizerManager
self.send_to_detokenizer = get_zmq_socket(
@@ -299,9 +265,6 @@ class Scheduler(
context, zmq.PUSH, port_args.detokenizer_ipc_name, False
)
self.recv_from_rpc = get_zmq_socket(
context, zmq.DEALER, port_args.rpc_ipc_name, False
)
if self.server_args.sleep_on_idle:
self.idle_sleeper = IdleSleeper(
[
@@ -398,7 +361,7 @@ class Scheduler(
global_server_args_dict.update(worker_global_server_args_dict)
set_random_seed(self.random_seed)
# Hybrid
# Hybrid memory pool
self.is_hybrid = self.tp_worker.is_hybrid
if self.is_hybrid:
self.sliding_window_size = self.tp_worker.sliding_window_size
@@ -515,6 +478,15 @@ class Scheduler(
self.init_metrics(tp_rank, pp_rank, dp_rank)
self.init_kv_events(server_args.kv_events_config)
# Init disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.init_disaggregation()
if get_bool_env_var("SGLANG_GC_LOG"):
configure_gc_logger()
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
@@ -545,22 +517,6 @@ class Scheduler(
]
)
# Init disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.init_disaggregation()
if get_bool_env_var("SGLANG_GC_LOG"):
configure_gc_logger()
def current_scheduler_metrics_enabled(self):
return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers
def maybe_sleep_on_idle(self):
if self.idle_sleeper is not None:
self.idle_sleeper.maybe_sleep()
def init_tokenizer(self):
server_args = self.server_args
@@ -668,50 +624,6 @@ class Scheduler(
embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
init_embedding_cache(embedding_cache_size * 1024 * 1024)
def init_profier(self):
self.torch_profiler = None
self.torch_profiler_output_dir: Optional[str] = None
self.profiler_activities: Optional[List[str]] = None
self.profile_id: Optional[str] = None
self.profiler_start_forward_ct: Optional[int] = None
self.profiler_target_forward_ct: Optional[int] = None
self.profiler_target_prefill_ct: Optional[int] = None
self.profiler_target_decode_ct: Optional[int] = None
self.profiler_prefill_ct: Optional[int] = None
self.profiler_decode_ct: Optional[int] = None
self.profile_by_stage: bool = False
self.profile_steps: Optional[int] = None
self.profile_in_progress: bool = False
self.rpd_profiler = None
def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
self.last_gen_throughput: float = 0.0
self.last_input_throughput: float = 0.0
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
self.spec_num_total_accepted_tokens = 0
self.spec_num_total_forward_ct = 0
self.cum_spec_accept_length = 0
self.cum_spec_accept_count = 0
self.total_retracted_reqs = 0
self.stats = SchedulerStats()
if self.enable_metrics:
engine_type = "unified"
labels = {
"model_name": self.server_args.served_model_name,
"engine_type": engine_type,
"tp_rank": tp_rank,
"pp_rank": pp_rank,
}
if dp_rank is not None:
labels["dp_rank"] = dp_rank
self.metrics_collector = SchedulerMetricsCollector(labels=labels)
def init_kv_events(self, kv_events_config: Optional[str]):
if self.enable_kv_cache_events:
self.kv_event_publisher = EventPublisherFactory.create(
kv_events_config, self.attn_dp_rank
)
def init_disaggregation(self):
self.transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
@@ -820,10 +732,7 @@ class Scheduler(
self.process_batch_result(batch, result)
else:
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
self.self_check_during_idle()
self.last_batch = batch
@@ -866,10 +775,7 @@ class Scheduler(
)
elif batch is None:
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
self.self_check_during_idle()
self.last_batch = batch
@@ -1003,10 +909,8 @@ class Scheduler(
# When the server is idle, self-check and re-init some states
if server_is_idle:
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
# When the server is idle, do self-check and re-init some states
self.self_check_during_idle()
def recv_requests(self) -> List[Req]:
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
@@ -1355,170 +1259,11 @@ class Scheduler(
req.logprob_start_len = len(req.origin_input_ids) - 1
self._add_request_to_queue(req)
def _emit_kv_metrics(self):
kv_metrics = KvMetrics()
kv_metrics.request_active_slots = self.stats.num_running_reqs
kv_metrics.request_total_slots = self.max_running_requests
kv_metrics.kv_active_blocks = int(
self.stats.token_usage * self.max_total_num_tokens
)
kv_metrics.kv_total_blocks = self.max_total_num_tokens
kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
if not self.send_metrics_from_scheduler.closed:
self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
def log_prefill_stats(
self,
adder: PrefillAdder,
can_run_list: List[Req],
running_bs: int,
):
gap_latency = time.perf_counter() - self.last_prefill_stats_tic
self.last_prefill_stats_tic = time.perf_counter()
self.last_input_throughput = self.last_prefill_tokens / gap_latency
self.last_prefill_tokens = adder.log_input_tokens
if self.is_hybrid:
(
full_num_used,
swa_num_used,
full_token_usage,
swa_token_usage,
_,
_,
_,
_,
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
token_msg = (
f"full token usage: {full_token_usage:.2f}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_msg = f"token usage: {token_usage:.2f}, "
num_new_seq = len(can_run_list)
f = (
f"Prefill batch. "
f"#new-seq: {num_new_seq}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"{token_msg}"
)
if self.disaggregation_mode == DisaggregationMode.PREFILL:
f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
f += f"#queue-req: {len(self.waiting_queue)}, "
f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
else:
f += f"#running-req: {running_bs}, "
f += f"#queue-req: {len(self.waiting_queue)}, "
logger.info(f)
if self.enable_metrics:
total_tokens = adder.log_input_tokens + adder.log_hit_tokens
cache_hit_rate = (
adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
)
self.stats.num_running_reqs = running_bs
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(token_usage, 2)
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.cache_hit_rate = cache_hit_rate
total_queue_latency = 0
for req in can_run_list:
total_queue_latency += req.queue_time_end - req.queue_time_start
self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
self.metrics_collector.log_stats(self.stats)
self._emit_kv_metrics()
self._publish_kv_events()
def log_decode_stats(
self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
):
batch = running_batch or self.running_batch
gap_latency = time.perf_counter() - self.last_decode_stats_tic
self.last_decode_stats_tic = time.perf_counter()
self.last_gen_throughput = self.num_generated_tokens / gap_latency
self.num_generated_tokens = 0
num_running_reqs = len(batch.reqs)
if self.is_hybrid:
(
full_num_used,
swa_num_used,
full_token_usage,
swa_token_usage,
_,
_,
_,
_,
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
token_msg = (
f"#full token: {full_num_used}, "
f"full token usage: {full_token_usage:.2f}, "
f"#swa token: {swa_num_used}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
if RECORD_STEP_TIME:
self.step_time_dict[num_running_reqs].append(
gap_latency / self.server_args.decode_log_interval
)
msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
if self.spec_algorithm.is_none():
spec_accept_length = 0
else:
spec_accept_length = (
self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
)
self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
self.cum_spec_accept_count += self.spec_num_total_forward_ct
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
msg += f"accept len: {spec_accept_length:.2f}, "
if self.disaggregation_mode == DisaggregationMode.DECODE:
msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
msg += (
f"cuda graph: {can_run_cuda_graph}, "
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}, "
)
logger.info(msg)
if self.enable_metrics:
self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(token_usage, 2)
self.stats.cache_hit_rate = 0.0
self.stats.gen_throughput = self.last_gen_throughput
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
self.stats.spec_accept_length = spec_accept_length
self.stats.total_retracted_reqs = self.total_retracted_reqs
self.metrics_collector.log_stats(self.stats)
self._emit_kv_metrics()
self._publish_kv_events()
def self_check_during_idle(self):
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
def check_memory(self):
if self.is_hybrid:
@@ -2422,22 +2167,6 @@ class Scheduler(
barrier()
return RpcReqOutput(success, "" if not exec else str(exec))
def save_remote_model(self, params):
url = params["url"]
worker = self.tp_worker.worker
worker.model_runner.save_remote_model(url)
def save_sharded_model(self, params):
worker = self.tp_worker.worker
worker.model_runner.save_sharded_model(
path=params["path"],
pattern=params["pattern"],
max_size=params["max_size"],
)
def abort_request(self, recv_req: AbortReq):
# Delete requests in the waiting queue
to_del = []
@@ -2515,16 +2244,6 @@ class Scheduler(
def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError()
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
if success:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
return UpdateWeightFromDiskReqOutput(success, message, 0)
def load_lora_adapter(
self, recv_req: LoadLoRAAdapterReqInput
) -> LoadLoRAAdapterReqOutput:
@@ -2541,81 +2260,6 @@ class Scheduler(
result = self.tp_worker.unload_lora_adapter(recv_req)
return result
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
"""Initialize the online model parameter update group."""
success, message = self.tp_worker.init_weights_update_group(recv_req)
return InitWeightsUpdateGroupReqOutput(success, message)
def update_weights_from_distributed(
self,
recv_req: UpdateWeightsFromDistributedReqInput,
) -> Tuple[bool, str]:
"""Update the online model parameter."""
success, message = self.tp_worker.update_weights_from_distributed(recv_req)
if success:
if recv_req.flush_cache:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
return UpdateWeightsFromDistributedReqOutput(success, message)
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
"""Update the online model parameter from tensors."""
success, message = self.tp_worker.update_weights_from_tensor(recv_req)
# TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
if success:
if recv_req.flush_cache:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
barrier(group=self.tp_cpu_group)
return UpdateWeightsFromTensorReqOutput(success, message)
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.tp_worker.get_weights_by_name(recv_req)
return GetWeightsByNameReqOutput(parameter)
def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
tags = recv_req.tags
if tags is None or len(tags) == 0:
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
if GPU_MEMORY_TYPE_KV_CACHE in tags:
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
self.flush_cache()
if GPU_MEMORY_TYPE_WEIGHTS in tags:
self.stashed_model_static_state = _export_static_state(
self.tp_worker.worker.model_runner.model
)
torch.distributed.barrier(self.tp_cpu_group)
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
return ReleaseMemoryOccupationReqOutput()
def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
tags = recv_req.tags
if tags is None or len(tags) == 0:
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
if GPU_MEMORY_TYPE_WEIGHTS in tags:
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
torch.distributed.barrier(self.tp_cpu_group)
_import_static_state(
self.tp_worker.worker.model_runner.model,
self.stashed_model_static_state,
)
del self.stashed_model_static_state
if GPU_MEMORY_TYPE_KV_CACHE in tags:
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
return ResumeMemoryOccupationReqOutput()
def slow_down(self, recv_req: SlowDownReqInput):
t = recv_req.forward_sleep_time
if t is not None and t <= 0:
@@ -2623,254 +2267,6 @@ class Scheduler(
self.forward_sleep_time = t
return SlowDownReqOutput()
def profile(self, recv_req: ProfileReq):
if recv_req.type == ProfileReqType.START_PROFILE:
if recv_req.profile_by_stage or recv_req.start_step:
return self.init_profile(
recv_req.output_dir,
recv_req.start_step,
recv_req.num_steps,
recv_req.activities,
recv_req.with_stack,
recv_req.record_shapes,
recv_req.profile_by_stage,
recv_req.profile_id,
)
else:
self.init_profile(
recv_req.output_dir,
recv_req.start_step,
recv_req.num_steps,
recv_req.activities,
recv_req.with_stack,
recv_req.record_shapes,
recv_req.profile_by_stage,
recv_req.profile_id,
)
return self.start_profile(True)
else:
return self.stop_profile()
def init_profile(
self,
output_dir: Optional[str],
start_step: Optional[int],
num_steps: Optional[int],
activities: Optional[List[str]],
with_stack: Optional[bool],
record_shapes: Optional[bool],
profile_by_stage: bool,
profile_id: str,
) -> ProfileReqOutput:
if self.profile_in_progress:
return ProfileReqOutput(
success=False,
message="Profiling is already in progress. Call /stop_profile first.",
)
self.profile_by_stage = profile_by_stage
if output_dir is None:
output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
if activities is None:
activities = ["CPU", "GPU"]
self.torch_profiler_output_dir = output_dir
self.torch_profiler_with_stack = with_stack
self.torch_profiler_record_shapes = record_shapes
self.profiler_activities = activities
self.profile_id = profile_id
if start_step:
self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)
if num_steps:
self.profile_steps = num_steps
if self.profile_by_stage:
self.profiler_target_prefill_ct = num_steps
self.profiler_target_decode_ct = num_steps
self.profiler_prefill_ct = 0
self.profiler_decode_ct = 0
elif start_step:
self.profiler_target_forward_ct = (
self.profiler_start_forward_ct + num_steps
)
else:
self.profiler_target_forward_ct = self.forward_ct + num_steps
# The caller will be notified when reaching profiler_target_forward_ct
else:
self.profiler_target_forward_ct = None
return ProfileReqOutput(success=True, message="Succeeded")
def start_profile(
self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None:
stage_str = f" for {stage.__str__()}" if stage else ""
logger.info(
f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
)
activities = self.profiler_activities
with_stack = self.torch_profiler_with_stack
record_shapes = self.torch_profiler_record_shapes
activity_map = {
"CPU": torch.profiler.ProfilerActivity.CPU,
"GPU": torch.profiler.ProfilerActivity.CUDA,
}
torchprof_activities = [
activity_map[a] for a in activities if a in activity_map
]
if "RPD" in activities:
from rpdTracerControl import rpdTracerControl
rpdTracerControl.skipCreate()
self.rpd_profile_path = os.path.join(
self.torch_profiler_output_dir,
"rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
)
if self.tp_rank == 0:
import sqlite3
from rocpd.schema import RocpdSchema
if os.path.exists("trace.rpd"):
os.unlink("trace.rpd")
schema = RocpdSchema()
connection = sqlite3.connect("trace.rpd")
schema.writeSchema(connection)
connection.commit()
del connection
torch.distributed.barrier(self.tp_cpu_group)
self.rpd_profiler = rpdTracerControl()
self.rpd_profiler.setPythonTrace(True)
self.rpd_profiler.start()
self.rpd_profiler.rangePush("", "rpd profile range", "")
self.profile_in_progress = True
elif torchprof_activities:
self.torch_profiler = torch.profiler.profile(
activities=torchprof_activities,
with_stack=with_stack if with_stack is not None else True,
record_shapes=record_shapes if record_shapes is not None else False,
)
self.torch_profiler.start()
self.profile_in_progress = True
if "MEM" in activities:
torch.cuda.memory._record_memory_history(max_entries=100000)
self.profile_in_progress = True
if "CUDA_PROFILER" in activities:
torch.cuda.cudart().cudaProfilerStart()
self.profile_in_progress = True
return ProfileReqOutput(success=True, message="Succeeded")
def stop_profile(
self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None:
if not self.profile_in_progress:
return ProfileReqOutput(
success=False,
message="Profiling is not in progress. Call /start_profile first.",
)
if not Path(self.torch_profiler_output_dir).exists():
Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
stage_suffix = f"-{stage.__str__()}" if stage else ""
logger.info("Stop profiling" + stage_suffix + "...")
if self.torch_profiler is not None:
self.torch_profiler.stop()
self.torch_profiler.export_chrome_trace(
os.path.join(
self.torch_profiler_output_dir,
self.profile_id
+ f"-TP-{self.tp_rank}"
+ stage_suffix
+ ".trace.json.gz",
)
)
torch.distributed.barrier(self.tp_cpu_group)
if self.rpd_profiler is not None:
self.rpd_profiler.rangePop()
self.rpd_profiler.stop()
self.rpd_profiler.flush()
torch.distributed.barrier(self.tp_cpu_group)
if self.tp_rank == 0:
from sglang.srt.utils import rpd_to_chrome_trace
rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
self.rpd_profiler = None
self.rpd_profiler_path = None
if self.profiler_activities is not None and "MEM" in self.profiler_activities:
memory_profile_path = os.path.join(
self.torch_profiler_output_dir,
str(time.time())
+ f"-TP-{self.tp_rank}-memory"
+ stage_suffix
+ ".pickle",
)
torch.cuda.memory._dump_snapshot(memory_profile_path)
torch.cuda.memory._record_memory_history(enabled=None)
if "CUDA_PROFILER" in self.profiler_activities:
torch.cuda.cudart().cudaProfilerStop()
logger.info(
"Profiling done. Traces are saved to: %s",
self.torch_profiler_output_dir,
)
self.torch_profiler = None
self.profile_in_progress = False
self.profiler_start_forward_ct = None
return ProfileReqOutput(success=True, message="Succeeded.")
def _profile_batch_predicate(self, batch):
if self.profile_by_stage:
if batch.forward_mode.is_prefill():
if self.profiler_prefill_ct == 0:
self.start_profile(batch.forward_mode)
self.profiler_prefill_ct += 1
if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
if self.profile_in_progress:
self.stop_profile(stage=ForwardMode.EXTEND)
elif batch.forward_mode.is_decode():
if self.profiler_decode_ct == 0:
if self.profile_in_progress:
# force trace flush
self.stop_profile(ForwardMode.EXTEND)
self.start_profile(batch.forward_mode)
self.profiler_decode_ct += 1
if self.profiler_decode_ct > self.profiler_target_decode_ct:
if self.profile_in_progress:
self.stop_profile(stage=ForwardMode.DECODE)
elif batch.forward_mode.is_idle():
pass
else:
raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
else:
# Check profiler
if (
self.profiler_target_forward_ct
and self.profiler_target_forward_ct <= self.forward_ct
):
self.stop_profile()
if (
self.profiler_start_forward_ct
and self.profiler_start_forward_ct == self.forward_ct
):
self.start_profile()
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
if recv_req == ExpertDistributionReq.START_RECORD:
get_global_expert_distribution_recorder().start_record()
@@ -2879,7 +2275,7 @@ class Scheduler(
elif recv_req == ExpertDistributionReq.DUMP_RECORD:
get_global_expert_distribution_recorder().dump_record()
else:
raise ValueError("Unrecognized ExpertDistributionReq value")
raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
return ExpertDistributionReqOutput()
def open_session(self, recv_req: OpenSessionReqInput):
@@ -2915,12 +2311,33 @@ class Scheduler(
prefix += f" PP{self.pp_rank}"
return prefix
def _publish_kv_events(self):
if self.enable_kv_cache_events:
events = self.tree_cache.take_events()
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
def current_scheduler_metrics_enabled(self):
return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers
def maybe_sleep_on_idle(self):
if self.idle_sleeper is not None:
self.idle_sleeper.maybe_sleep()
class IdleSleeper:
"""
In setups which have long inactivity periods it is desirable to reduce
system power consumption when sglang does nothing. This would lead not only
to power savings, but also to more CPU thermal headroom when a request
eventually comes. This is important in cases when multiple GPUs are connected
as each GPU would otherwise pin one thread at 100% CPU usage.
The simplest solution is to use zmq.Poller on all sockets that may receive
data that needs handling immediately.
"""
def __init__(self, sockets):
self.poller = zmq.Poller()
for s in sockets:
self.poller.register(s, zmq.POLLIN)
def maybe_sleep(self):
self.poller.poll(1000)
def is_health_check_generate_req(recv_req):
@@ -2931,20 +2348,6 @@ def is_work_request(recv_req):
return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
def _export_static_state(model):
return dict(
buffers=[
(name, buffer.detach().clone()) for name, buffer in model.named_buffers()
]
)
def _import_static_state(model, static_params):
self_named_buffers = dict(model.named_buffers())
for name, tensor in static_params["buffers"]:
self_named_buffers[name][...] = tensor
def run_scheduler_process(
server_args: ServerArgs,
port_args: PortArgs,