diff --git a/python/pyproject.toml b/python/pyproject.toml index eac224443..e708f00d4 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -14,18 +14,17 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", ] dependencies = [ - "aiohttp", - "requests", - "tqdm", - "numpy", "IPython", - "setproctitle", + "aiohttp", + "anthropic>=0.20.0", "blobfile==3.0.0", "build", "compressed-tensors", + "cuda-python", "datasets", "einops", "fastapi", + "flashinfer_python==0.4.0rc3", "hf_transfer", "huggingface_hub", "interegular", @@ -33,8 +32,10 @@ dependencies = [ "modelscope", "msgspec", "ninja", - "openai==1.99.1", + "numpy", + "nvidia-cutlass-dsl==4.2.1", "openai-harmony==0.0.4", + "openai==1.99.1", "orjson", "outlines==0.1.11", "packaging", @@ -42,32 +43,30 @@ dependencies = [ "pillow", "prometheus-client>=0.20.0", "psutil", + "py-spy", "pybase64", "pydantic", "pynvml", "python-multipart", "pyzmq>=25.1.2", + "requests", "scipy", "sentencepiece", + "setproctitle", + "sgl-kernel==0.3.13", "soundfile==0.13.1", - "timm==1.0.16", "tiktoken", + "timm==1.0.16", + "torch==2.8.0", + "torch_memory_saver==0.0.8", "torchao==0.9.0", + "torchaudio==2.8.0", + "torchvision", + "tqdm", "transformers==4.56.1", "uvicorn", "uvloop", - "xgrammar==0.1.24", - "sgl-kernel==0.3.13", - "torch==2.8.0", - "torchaudio==2.8.0", - "torchvision", - "cuda-python", - "flashinfer_python==0.4.0rc3", - "openai==1.99.1", - "tiktoken", - "anthropic>=0.20.0", - "torch_memory_saver==0.0.8", - "nvidia-cutlass-dsl==4.2.1", + "xgrammar==0.1.24" ] [project.optional-dependencies] @@ -79,15 +78,15 @@ test = [ "matplotlib", "pandas", "peft", - "sentence_transformers", "pytest", + "sentence_transformers", "tabulate", ] tracing = [ - "opentelemetry-sdk", "opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-exporter-otlp-proto-grpc", + "opentelemetry-sdk", ] all = ["sglang[test]", "sglang[decord]"] blackwell = ["sglang[test]", "sglang[decord]"] diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 1db475f15..de26d351f 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -21,6 +21,7 @@ Life cycle of a request in the decode server from __future__ import annotations import logging +import time from collections import deque from dataclasses import dataclass from http import HTTPStatus @@ -422,9 +423,13 @@ class DecodePreallocQueue: kv_indices, self.token_to_kv_pool_allocator.page_size ) decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index) - decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP) + preallocated_reqs.append(decode_req) indices_to_remove.add(i) + decode_req.req.time_stats.decode_transfer_queue_entry_time = ( + time.perf_counter() + ) + decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP) self.queue = [ entry for i, entry in enumerate(self.queue) if i not in indices_to_remove @@ -625,6 +630,7 @@ class DecodeTransferQueue: decode_req.req.output_topk_p = output_topk_p decode_req.req.output_topk_index = output_topk_index decode_req.req.hidden_states_tensor = output_hidden_states + if decode_req.req.return_logprob: decode_req.req.output_token_logprobs_val.append( output_token_logprobs_val[0].item() @@ -645,10 +651,17 @@ class DecodeTransferQueue: if hasattr(decode_req.kv_receiver, "clear"): decode_req.kv_receiver.clear() + decode_req.kv_receiver = None + + indices_to_remove.add(i) + decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter() # special 
handling for sampling_params.max_new_tokens == 1 if decode_req.req.sampling_params.max_new_tokens == 1: # finish immediately + decode_req.req.time_stats.forward_entry_time = ( + decode_req.req.time_stats.completion_time + ) = time.perf_counter() decode_req.req.check_finished() self.scheduler.stream_output( [decode_req.req], decode_req.req.return_logprob @@ -656,8 +669,6 @@ class DecodeTransferQueue: self.tree_cache.cache_finished_req(decode_req.req) else: transferred_reqs.append(decode_req.req) - - indices_to_remove.add(i) elif poll in [ KVPoll.Bootstrapping, KVPoll.WaitingForInput, @@ -877,6 +888,9 @@ class SchedulerDisaggregationDecodeMixin: if len(can_run_list) == 0: return None + for req in can_run_list: + req.time_stats.forward_entry_time = time.perf_counter() + # construct a schedule batch with those requests and mark as decode new_batch = ScheduleBatch.init_new( can_run_list, diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 3f794ea3a..f31c5eeea 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -21,6 +21,7 @@ from __future__ import annotations import logging import threading +import time from collections import deque from http import HTTPStatus from typing import TYPE_CHECKING, List, Optional, Type @@ -263,9 +264,10 @@ class PrefillBootstrapQueue: num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size) req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index) - req.add_latency(RequestStage.PREFILL_BOOTSTRAP) bootstrapped_reqs.append(req) indices_to_remove.add(i) + req.time_stats.wait_queue_entry_time = time.perf_counter() + req.add_latency(RequestStage.PREFILL_BOOTSTRAP) self.queue = [ entry for i, entry in enumerate(self.queue) if i not in indices_to_remove @@ -407,7 +409,6 @@ class SchedulerDisaggregationPrefillMixin: for i, (req, next_token_id) in enumerate( zip(batch.reqs, next_token_ids, strict=True) ): - req: Req if req.is_chunked <= 0: # There is no output_ids for prefill req.output_ids.append(next_token_id) @@ -450,6 +451,7 @@ class SchedulerDisaggregationPrefillMixin: ) logprob_pt += num_input_logprobs self.send_kv_chunk(req, last_chunk=True) + req.time_stats.prefill_transfer_queue_entry_time = time.perf_counter() if req.grammar is not None: # FIXME: this try-except block is for handling unexpected xgrammar issue. 
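Note on the timing changes in the two files above: stage boundaries are now recorded as entry timestamps on `req.time_stats` via `time.perf_counter()` (monotonic), replacing the old `queue_time_start`/`queue_time_end` pair, and durations are derived later by subtracting adjacent timestamps. A minimal sketch of the pattern, using a simplified stand-in for the diff's `TimeStats` dataclass (field names follow the diff; the class itself is illustrative, not the actual sglang implementation):

```python
import time
from dataclasses import dataclass


@dataclass
class StageTimes:
    # Entry timestamps from time.perf_counter(); 0.0 means "never reached".
    decode_prealloc_queue_entry_time: float = 0.0
    decode_transfer_queue_entry_time: float = 0.0
    wait_queue_entry_time: float = 0.0
    forward_entry_time: float = 0.0
    completion_time: float = 0.0

    def durations_ms(self) -> dict:
        # perf_counter() is monotonic, so adjacent entry timestamps can be
        # subtracted safely; time.time() may jump under clock adjustments.
        stages = [
            ("prealloc", self.decode_prealloc_queue_entry_time,
             self.decode_transfer_queue_entry_time),
            ("transfer", self.decode_transfer_queue_entry_time,
             self.wait_queue_entry_time),
            ("queue", self.wait_queue_entry_time, self.forward_entry_time),
            ("forward", self.forward_entry_time, self.completion_time),
        ]
        return {name: (end - start) * 1e3 for name, start, end in stages}


ts = StageTimes(decode_prealloc_queue_entry_time=time.perf_counter())
time.sleep(0.01)  # stand-in for prealloc work
ts.decode_transfer_queue_entry_time = time.perf_counter()
time.sleep(0.01)  # stand-in for the KV transfer
ts.wait_queue_entry_time = time.perf_counter()
time.sleep(0.01)  # stand-in for scheduling delay
ts.forward_entry_time = time.perf_counter()
time.sleep(0.01)  # stand-in for the forward pass
ts.completion_time = time.perf_counter()
print({name: f"{ms:.2f}ms" for name, ms in ts.durations_ms().items()})
```

A zero timestamp marks a stage the request never reached, which is why the collector changes later in this patch only assert non-negative durations when `wait_queue_entry_time > 0`.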
@@ -547,6 +549,9 @@ class SchedulerDisaggregationPrefillMixin: else: assert False, f"Unexpected polling state {poll=}" + for req in done_reqs: + req.time_stats.completion_time = time.perf_counter() + # Stream requests which have finished transfer self.stream_output( done_reqs, diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index fe4e7fb9f..d660172de 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -5,7 +5,7 @@ import random from collections import deque from contextlib import nullcontext from enum import Enum -from typing import TYPE_CHECKING, List, Optional, Type, Union +from typing import TYPE_CHECKING, Optional, Type import numpy as np import torch diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index adabae9d7..32df8e26e 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -41,7 +41,7 @@ import time from enum import Enum, auto from http import HTTPStatus from itertools import chain -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union import numpy as np import torch @@ -54,6 +54,7 @@ from sglang.srt.disaggregation.base import BaseKVSender from sglang.srt.disaggregation.decode_schedule_batch_mixin import ( ScheduleBatchDisaggregationDecodeMixin, ) +from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank from sglang.srt.mem_cache.allocator import ( BaseTokenToKVPoolAllocator, @@ -452,6 +453,7 @@ class Req: bootstrap_host: Optional[str] = None, bootstrap_port: Optional[int] = None, bootstrap_room: Optional[int] = None, + disagg_mode: Optional[DisaggregationMode] = None, data_parallel_rank: Optional[int] = None, vocab_size: Optional[int] = None, priority: Optional[int] = None, @@ -628,10 +630,8 @@ class Req: # For metrics self.metrics_collector = metrics_collector - self.time_stats: TimeStats = TimeStats() + self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode) self.has_log_time_stats: bool = False - self.queue_time_start = None - self.queue_time_end = None self.last_tic = time.monotonic() # For disaggregation @@ -668,9 +668,9 @@ class Req: def add_latency(self, stage: RequestStage): if self.metrics_collector is None: return - assert stage.name in RequestStage.__members__, f"{stage=} is invalid" + now = time.monotonic() - self.metrics_collector.observe_request_latency_seconds( + self.metrics_collector.observe_per_stage_req_latency( stage.value, now - self.last_tic ) self.last_tic = now @@ -834,10 +834,10 @@ class Req: return if self.bootstrap_room is not None: - prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})" + prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})" else: - prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})" - logger.info(f"{prefix}: {self.time_stats}") + prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output 
len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})" + logger.info(f"{prefix}: {self.time_stats.convert_to_duration()}") self.has_log_time_stats = True def set_finish_with_abort(self, error_msg: str): @@ -1544,7 +1544,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ) / total_max_new_tokens new_estimate_ratio = min(1.0, new_estimate_ratio) - return retracted_reqs, new_estimate_ratio + return retracted_reqs, new_estimate_ratio, [] def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs): req = self.reqs[idx] diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 755ac29c8..60633552b 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -276,9 +276,13 @@ class SchedulePolicy: ) -> None: """Sorts the waiting queue based on the request priority then received titmestamp.""" if schedule_low_priority_values_first: - waiting_queue.sort(key=lambda x: (x.priority, x.queue_time_start)) + waiting_queue.sort( + key=lambda x: (x.priority, x.time_stats.wait_queue_entry_time) + ) else: - waiting_queue.sort(key=lambda x: (-x.priority, x.queue_time_start)) + waiting_queue.sort( + key=lambda x: (-x.priority, x.time_stats.wait_queue_entry_time) + ) @staticmethod def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None: @@ -642,12 +646,12 @@ class PrefillAdder: if server_args.schedule_low_priority_values_first: sorted_running_reqs = sorted( self.running_batch.reqs, - key=lambda x: (-x.priority, -x.queue_time_start), + key=lambda x: (-x.priority, -x.time_stats.wait_queue_entry_time), ) else: sorted_running_reqs = sorted( self.running_batch.reqs, - key=lambda x: (x.priority, -x.queue_time_start), + key=lambda x: (x.priority, -x.time_stats.wait_queue_entry_time), ) preemptible_reqs = [] min_tokens_to_remove = ( diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index d62c7f01c..c71e937f7 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -157,10 +157,9 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.tracing.trace import ( process_tracing_init, - trace_event, trace_set_proc_propagate_context, trace_set_thread_info, - trace_slice, + trace_slice_batch, trace_slice_end, trace_slice_start, ) @@ -263,6 +262,7 @@ class Scheduler( server_args.enable_metrics_for_all_schedulers ) self.enable_kv_cache_events = server_args.kv_events_config and tp_rank == 0 + self.enable_trace = server_args.enable_trace self.stream_interval = server_args.stream_interval self.spec_algorithm = SpeculativeAlgorithm.from_string( server_args.speculative_algorithm @@ -899,10 +899,6 @@ class Scheduler( batch = self.get_next_batch_to_run() self.cur_batch = batch - if batch: - for req in batch.reqs: - trace_event("schedule", req.rid) - if batch: result = self.run_batch(batch) self.process_batch_result(batch, result) @@ -924,10 +920,6 @@ class Scheduler( batch = self.get_next_batch_to_run() self.cur_batch = batch - if batch: - for req in batch.reqs: - trace_event("schedule", req.rid) - if batch: batch.launch_done = threading.Event() result = self.run_batch(batch) @@ -1192,10 +1184,13 @@ class Scheduler( src=self.tp_group.ranks[0], ) - for req in recv_reqs: - if isinstance(req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)): - 
trace_set_proc_propagate_context(req.rid, req.trace_context) - trace_slice_start("", req.rid, anonymous=True) + if self.enable_trace: + for req in recv_reqs: + if isinstance( + req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput) + ): + trace_set_proc_propagate_context(req.rid, req.trace_context) + trace_slice_start("", req.rid, anonymous=True) return recv_reqs @@ -1277,6 +1272,7 @@ class Scheduler( bootstrap_host=recv_req.bootstrap_host, bootstrap_port=recv_req.bootstrap_port, bootstrap_room=recv_req.bootstrap_room, + disagg_mode=self.disaggregation_mode, data_parallel_rank=recv_req.data_parallel_rank, vocab_size=self.model_config.vocab_size, priority=recv_req.priority, @@ -1403,7 +1399,6 @@ class Scheduler( req.set_finish_with_abort(error_msg) if add_to_grammar_queue: - req.queue_time_start = time.perf_counter() self.grammar_queue.append(req) else: self._add_request_to_queue(req) @@ -1419,23 +1414,6 @@ class Scheduler( for tokenized_req in recv_req: self.handle_generate_request(tokenized_req) - def _add_request_to_queue(self, req: Req): - req.queue_time_start = time.perf_counter() - if self.disaggregation_mode == DisaggregationMode.PREFILL: - self._prefetch_kvcache(req) - self.disagg_prefill_bootstrap_queue.add( - req, self.model_config.num_key_value_heads - ) - elif self.disaggregation_mode == DisaggregationMode.DECODE: - self.disagg_decode_prealloc_queue.add(req) - else: - self._set_or_validate_priority(req) - if self._abort_on_queued_limit(req): - return - self._prefetch_kvcache(req) - self.waiting_queue.append(req) - trace_slice_end("process req", req.rid, auto_next_anon=True) - def _prefetch_kvcache(self, req: Req): if self.enable_hicache_storage: req.init_next_round_input(self.tree_cache) @@ -1449,19 +1427,27 @@ class Scheduler( req.rid, req.last_host_node, new_input_tokens, last_hash ) - def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False): - if self.disaggregation_mode == DisaggregationMode.PREFILL: - self.disagg_prefill_bootstrap_queue.extend( - reqs, self.model_config.num_key_value_heads + def _add_request_to_queue(self, req: Req, is_retracted: bool = False): + if self.disaggregation_mode == DisaggregationMode.NULL: + self._set_or_validate_priority(req) + if self._abort_on_queued_limit(req): + return + self._prefetch_kvcache(req) + self.waiting_queue.append(req) + req.time_stats.wait_queue_entry_time = time.perf_counter() + trace_slice_end("process req", req.rid, auto_next_anon=True) + elif self.disaggregation_mode == DisaggregationMode.PREFILL: + self._prefetch_kvcache(req) + self.disagg_prefill_bootstrap_queue.add( + req, self.model_config.num_key_value_heads ) + req.time_stats.prefill_bootstrap_queue_entry_time = time.perf_counter() elif self.disaggregation_mode == DisaggregationMode.DECODE: - # If this is a decode server, we put the request to the decode pending prealloc queue - self.disagg_decode_prealloc_queue.extend(reqs, is_retracted) + self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted) + if not is_retracted: + req.time_stats.decode_prealloc_queue_entry_time = time.perf_counter() else: - for req in reqs: - self._set_or_validate_priority(req) - if not self._abort_on_queued_limit(req): - self.waiting_queue.append(req) + raise ValueError(f"Invalid {self.disaggregation_mode=}") def _set_or_validate_priority(self, req: Req): """Set the default priority value, or abort the request based on the priority scheduling mode.""" @@ -1500,7 +1486,7 @@ class Scheduler( direction = 1 if self.schedule_low_priority_values_first else 
-1 key_fn = lambda item: ( direction * item[1].priority, - item[1].queue_time_start, + item[1].time_stats.wait_queue_entry_time, ) idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn) abort_existing_req = ( @@ -1902,14 +1888,14 @@ class Scheduler( if self.enable_metrics: # only record queue time when enable_metrics is True to avoid overhead for req in can_run_list: - req.queue_time_end = time.perf_counter() req.add_latency(RequestStage.PREFILL_WAITING) self.waiting_queue = [ x for x in self.waiting_queue if x not in set(can_run_list) ] if adder.preempt_list: - self._extend_requests_to_queue(adder.preempt_list) + for req in adder.preempt_list: + self._add_request_to_queue(req) if adder.new_chunked_req is not None: assert self.chunked_req is None @@ -1920,7 +1906,16 @@ class Scheduler( # Print stats if self.current_scheduler_metrics_enabled(): - self.log_prefill_stats(adder, can_run_list, running_bs) + self.log_prefill_stats(adder, can_run_list, running_bs, 0) + + for req in can_run_list: + if req.time_stats.forward_entry_time == 0: + # Avoid update chunked request many times + req.time_stats.forward_entry_time = time.perf_counter() + if self.enable_metrics: + self.metrics_collector.observe_queue_time( + req.time_stats.get_queueing_time(), + ) # Create a new batch new_batch = ScheduleBatch.init_new( @@ -1975,19 +1970,25 @@ class Scheduler( TEST_RETRACT and batch.batch_size() > 10 ): old_ratio = self.new_token_ratio - - retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args) - num_retracted_reqs = len(retracted_reqs) + retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode( + self.server_args + ) + self.num_retracted_reqs = len(retracted_reqs) self.new_token_ratio = new_token_ratio + for req in reqs_to_abort: + self.send_to_tokenizer.send_pyobj( + AbortReq(req.rid, abort_reason=req.to_abort_message) + ) logger.info( "KV cache pool is full. Retract requests. 
" - f"#retracted_reqs: {num_retracted_reqs}, " - f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}" + f"#retracted_reqs: {len(retracted_reqs)}, " + f"#aborted_retracted_reqs: {len(reqs_to_abort)}, " + f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}" ) - self._extend_requests_to_queue(retracted_reqs, is_retracted=True) - self.total_retracted_reqs += num_retracted_reqs + for req in retracted_reqs: + self._add_request_to_queue(req, is_retracted=True) else: self.new_token_ratio = max( self.new_token_ratio - self.new_token_ratio_decay, @@ -2086,23 +2087,14 @@ class Scheduler( ): if batch.forward_mode.is_decode(): self.process_batch_result_decode(batch, result, launch_done) - for req in batch.reqs: - trace_slice( - "decode loop", - req.rid, - auto_next_anon=not req.finished(), - thread_finish_flag=req.finished(), - ) + if self.enable_trace: + trace_slice_batch("decode loop", batch.reqs) elif batch.forward_mode.is_extend(): self.process_batch_result_prefill(batch, result, launch_done) - for req in batch.reqs: - trace_slice( - "prefill", - req.rid, - auto_next_anon=not req.finished(), - thread_finish_flag=req.finished(), - ) + if self.enable_trace: + trace_slice_batch("prefill", batch.reqs) + elif batch.forward_mode.is_idle(): if self.enable_overlap: self.tp_worker.resolve_last_batch_result(launch_done) @@ -2261,12 +2253,13 @@ class Scheduler( if req.finished(): # It is aborted by AbortReq num_ready_reqs += 1 continue + req.grammar = req.grammar.result(timeout=0.03) self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy()) if req.grammar is INVALID_GRAMMAR_OBJ: - req.set_finish_with_abort( - f"Invalid grammar request: {req.grammar_key=}" - ) + error_msg = f"Invalid grammar request: {req.grammar_key=}" + req.set_finish_with_abort(error_msg) + num_ready_reqs += 1 except futures._base.TimeoutError: req.grammar_wait_ct += 1 @@ -2298,9 +2291,8 @@ class Scheduler( req.grammar = req.grammar.result() self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy()) if req.grammar is INVALID_GRAMMAR_OBJ: - req.set_finish_with_abort( - f"Invalid grammar request: {req.grammar_key=}" - ) + error_msg = f"Invalid grammar request: {req.grammar_key=}" + req.set_finish_with_abort(error_msg) else: num_ready_reqs_max = num_ready_reqs num_timeout_reqs_max = num_timeout_reqs @@ -2308,12 +2300,14 @@ class Scheduler( for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max): req = self.grammar_queue[i] req.grammar.cancel() + self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ) error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}" req.set_finish_with_abort(error_msg) - self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ) + num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max - self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs]) + for req in self.grammar_queue[:num_ready_reqs]: + self._add_request_to_queue(req) self.grammar_queue = self.grammar_queue[num_ready_reqs:] def set_next_batch_sampling_info_done(self, batch: ScheduleBatch): @@ -2795,17 +2789,11 @@ def run_scheduler_process( pipe_writer, balance_meta: Optional[DPBalanceMeta] = None, ): - if server_args.enable_trace: - process_tracing_init(server_args.oltp_traces_endpoint, "sglang") - if server_args.disaggregation_mode == "null": - thread_label = "Scheduler" - trace_set_thread_info(thread_label, tp_rank, dp_rank) - - if (numa_node := server_args.numa_node) is not None: - numa_bind_to_node(numa_node[gpu_id]) - - # Generate the prefix + # 
Generate the logger prefix prefix = "" + if dp_rank is None and "SGLANG_DP_RANK" in os.environ: + # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var + dp_rank = int(os.environ["SGLANG_DP_RANK"]) if dp_rank is not None: prefix += f" DP{dp_rank}" if server_args.tp_size > 1: @@ -2821,10 +2809,6 @@ def run_scheduler_process( kill_itself_when_parent_died() parent_process = psutil.Process().parent() - # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var - if dp_rank is None and "SGLANG_DP_RANK" in os.environ: - dp_rank = int(os.environ["SGLANG_DP_RANK"]) - # Configure the logger configure_logger(server_args, prefix=prefix) suppress_other_loggers() @@ -2832,6 +2816,15 @@ def run_scheduler_process( # Set cpu affinity to this gpu process if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"): set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id) + if (numa_node := server_args.numa_node) is not None: + numa_bind_to_node(numa_node[gpu_id]) + + # Set up tracing + if server_args.enable_trace: + process_tracing_init(server_args.oltp_traces_endpoint, "sglang") + if server_args.disaggregation_mode == "null": + thread_label = "Scheduler" + trace_set_thread_info(thread_label, tp_rank, dp_rank) # Create a scheduler and run the event loop try: diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py index 66cdc95bb..5966925df 100644 --- a/python/sglang/srt/managers/scheduler_metrics_mixin.py +++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py @@ -47,8 +47,11 @@ class SchedulerMetricsMixin: self.spec_num_total_forward_ct = 0 self.cum_spec_accept_length = 0 self.cum_spec_accept_count = 0 - self.total_retracted_reqs = 0 + self.kv_transfer_speed_gb_s: float = 0.0 + self.kv_transfer_latency_ms: float = 0.0 + self.stats = SchedulerStats() + if self.enable_metrics: engine_type = "unified" labels = { @@ -82,12 +85,14 @@ class SchedulerMetricsMixin: adder: PrefillAdder, can_run_list: List[Req], running_bs: int, + running_bs_offline_batch: int, ): gap_latency = time.perf_counter() - self.last_prefill_stats_tic self.last_prefill_stats_tic = time.perf_counter() self.last_input_throughput = self.last_prefill_tokens / gap_latency self.last_prefill_tokens = adder.log_input_tokens + # TODO: generalize this for various memory pools if self.is_hybrid: ( full_num_used, @@ -101,51 +106,53 @@ class SchedulerMetricsMixin: ) = self._get_swa_token_info() num_used = max(full_num_used, swa_num_used) token_usage = max(full_token_usage, swa_token_usage) - token_msg = ( + token_usage_msg = ( f"full token usage: {full_token_usage:.2f}, " f"swa token usage: {swa_token_usage:.2f}, " ) else: num_used, token_usage, _, _ = self._get_token_info() - token_msg = f"token usage: {token_usage:.2f}, " + token_usage_msg = f"token usage: {token_usage:.2f}, " - num_new_seq = len(can_run_list) f = ( f"Prefill batch. 
" - f"#new-seq: {num_new_seq}, " + f"#new-seq: {len(can_run_list)}, " f"#new-token: {adder.log_input_tokens}, " f"#cached-token: {adder.log_hit_tokens}, " - f"{token_msg}" + f"{token_usage_msg}" + f"#running-req: {running_bs}, " + f"#queue-req: {len(self.waiting_queue)}, " ) if self.disaggregation_mode == DisaggregationMode.PREFILL: - f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, " - f += f"#queue-req: {len(self.waiting_queue)}, " - f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, " - f += f"input throughput (token/s): {self.last_input_throughput:.2f}, " - else: - f += f"#running-req: {running_bs}, " - f += f"#queue-req: {len(self.waiting_queue)}, " + f += f"#prealloc-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, " + f += f"#inflight-req: {len(self.disagg_prefill_inflight_queue)}, " logger.info(f) if self.enable_metrics: + # Basics total_tokens = adder.log_input_tokens + adder.log_hit_tokens - cache_hit_rate = ( adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0 ) + self.stats.num_running_reqs = running_bs + self.stats.num_running_reqs_offline_batch = running_bs_offline_batch self.stats.num_used_tokens = num_used - self.stats.token_usage = round(token_usage, 2) + self.stats.token_usage = token_usage + if self.is_hybrid: + self.stats.swa_token_usage = swa_token_usage self.stats.num_queue_reqs = len(self.waiting_queue) + self.stats.num_grammar_queue_reqs = len(self.grammar_queue) self.stats.cache_hit_rate = cache_hit_rate - total_queue_latency = 0 - for req in can_run_list: - total_queue_latency += req.queue_time_end - req.queue_time_start - self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq + # Retract + self.stats.num_retracted_reqs = self.num_retracted_reqs + self.stats.num_paused_reqs = self.num_paused_reqs + self.num_retracted_reqs = self.num_paused_reqs = 0 + # PD disaggregation if self.disaggregation_mode == DisaggregationMode.PREFILL: self.stats.num_prefill_prealloc_queue_reqs = len( self.disagg_prefill_bootstrap_queue.queue @@ -153,7 +160,18 @@ class SchedulerMetricsMixin: self.stats.num_prefill_inflight_queue_reqs = len( self.disagg_prefill_inflight_queue ) + self.stats.kv_transfer_speed_gb_s = self.kv_transfer_speed_gb_s + self.stats.kv_transfer_latency_ms = self.kv_transfer_latency_ms + elif self.disaggregation_mode == DisaggregationMode.DECODE: + self.stats.num_decode_prealloc_queue_reqs = len( + self.disagg_decode_prealloc_queue.queue + ) + self.stats.num_decode_transfer_queue_reqs = len( + self.disagg_decode_transfer_queue.queue + ) + # Others + self.calculate_utilization() self.metrics_collector.log_stats(self.stats) self._emit_kv_metrics() self._publish_kv_events() @@ -166,8 +184,12 @@ class SchedulerMetricsMixin: gap_latency = time.perf_counter() - self.last_decode_stats_tic self.last_decode_stats_tic = time.perf_counter() self.last_gen_throughput = self.num_generated_tokens / gap_latency + self.num_generated_tokens = 0 num_running_reqs = len(batch.reqs) + num_running_reqs_offline_batch = 0 + + # TODO: generalize this for various memory pools if self.is_hybrid: ( full_num_used, @@ -181,7 +203,7 @@ class SchedulerMetricsMixin: ) = self._get_swa_token_info() num_used = max(full_num_used, swa_num_used) token_usage = max(full_token_usage, swa_token_usage) - token_msg = ( + token_usage_msg = ( f"#full token: {full_num_used}, " f"full token usage: {full_token_usage:.2f}, " f"#swa token: {swa_num_used}, " @@ -189,14 +211,14 @@ class SchedulerMetricsMixin: ) else: num_used, token_usage, 
_, _ = self._get_token_info() - token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, " + token_usage_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, " if RECORD_STEP_TIME: self.step_time_dict[num_running_reqs].append( gap_latency / self.server_args.decode_log_interval ) - msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}" + msg = f"Decode batch. #running-req: {num_running_reqs}, {token_usage_msg}" if self.spec_algorithm.is_none(): spec_accept_length = 0 @@ -208,41 +230,66 @@ class SchedulerMetricsMixin: self.cum_spec_accept_count += self.spec_num_total_forward_ct self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0 msg += f"accept len: {spec_accept_length:.2f}, " + cache_hit_rate = 0.0 if self.disaggregation_mode == DisaggregationMode.DECODE: msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, " + msg += f"#prealloc-req: {len(self.disagg_decode_prealloc_queue.queue)}, " + msg += f"#transfer-req: {len(self.disagg_decode_transfer_queue.queue)}, " msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, " msg += ( - f"{'cpu graph' if self.device == 'cpu' else 'cuda graph'}: {can_run_cuda_graph}, " + f"{'cuda graph' if self.device == 'cuda' else 'cpu graph'}: {can_run_cuda_graph}, " f"gen throughput (token/s): {self.last_gen_throughput:.2f}, " f"#queue-req: {len(self.waiting_queue)}, " ) logger.info(msg) if self.enable_metrics: + # Basics self.stats.num_running_reqs = num_running_reqs + self.stats.num_running_reqs_offline_batch = num_running_reqs_offline_batch self.stats.num_used_tokens = num_used - self.stats.token_usage = round(token_usage, 2) - self.stats.cache_hit_rate = 0.0 + self.stats.token_usage = token_usage + if self.is_hybrid: + self.stats.swa_token_usage = swa_token_usage self.stats.gen_throughput = self.last_gen_throughput self.stats.num_queue_reqs = len(self.waiting_queue) self.stats.num_grammar_queue_reqs = len(self.grammar_queue) + self.stats.cache_hit_rate = cache_hit_rate self.stats.spec_accept_length = spec_accept_length - self.stats.total_retracted_reqs = self.total_retracted_reqs - self.stats.avg_request_queue_latency = 0.0 - if self.disaggregation_mode == DisaggregationMode.DECODE: + + # Retract + self.stats.num_retracted_reqs = self.num_retracted_reqs + self.stats.num_paused_reqs = self.num_paused_reqs + self.num_retracted_reqs = self.num_paused_reqs = 0 + + # PD disaggregation + if self.disaggregation_mode == DisaggregationMode.PREFILL: + self.stats.num_prefill_prealloc_queue_reqs = len( + self.disagg_prefill_bootstrap_queue.queue + ) + self.stats.num_prefill_inflight_queue_reqs = len( + self.disagg_prefill_inflight_queue + ) + elif self.disaggregation_mode == DisaggregationMode.DECODE: self.stats.num_decode_prealloc_queue_reqs = len( self.disagg_decode_prealloc_queue.queue ) self.stats.num_decode_transfer_queue_reqs = len( self.disagg_decode_transfer_queue.queue ) + + # Others + self.calculate_utilization() self.metrics_collector.log_stats(self.stats) self._emit_kv_metrics() self._publish_kv_events() def _emit_kv_metrics(self: Scheduler): + if not self.enable_kv_cache_events: + return + kv_metrics = KvMetrics() kv_metrics.request_active_slots = self.stats.num_running_reqs kv_metrics.request_total_slots = self.max_running_requests @@ -259,11 +306,13 @@ class SchedulerMetricsMixin: self.send_metrics_from_scheduler.send_pyobj(kv_metrics) def _publish_kv_events(self: Scheduler): - if 
self.enable_kv_cache_events:
-            events = self.tree_cache.take_events()
-            if events:
-                batch = KVEventBatch(ts=time.time(), events=events)
-                self.kv_event_publisher.publish(batch)
+        if not self.enable_kv_cache_events:
+            return
+
+        events = self.tree_cache.take_events()
+        if events:
+            batch = KVEventBatch(ts=time.time(), events=events)
+            self.kv_event_publisher.publish(batch)
 
     def maybe_update_dp_balance_data(
         self: Scheduler, recv_req: TokenizedGenerateReqInput
@@ -349,3 +398,17 @@ class SchedulerMetricsMixin:
         # 2. Atomically write local_tokens and onfly into shm under the mutex
         meta.set_shared_onfly_info(onfly_list)
         meta.set_shared_local_tokens(local_tokens)
+
+    def calculate_utilization(self):
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.stats.utilization = -1
+        else:
+            if (
+                self.stats.max_running_requests_under_SLO is not None
+                and self.stats.max_running_requests_under_SLO > 0
+            ):
+                self.stats.utilization = max(
+                    self.stats.num_running_reqs
+                    / self.stats.max_running_requests_under_SLO,
+                    self.stats.token_usage / 0.9,
+                )
diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
index 5d8545dac..750de6689 100644
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -91,7 +91,7 @@ class SchedulerOutputProcessorMixin:
                 if req.finished():
                     self.tree_cache.cache_finished_req(req)
-                    req.time_stats.completion_time = time.time()
+                    req.time_stats.completion_time = time.perf_counter()
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     # This updates radix so others can match
                     self.tree_cache.cache_unfinished_req(req)
@@ -257,7 +257,7 @@ class SchedulerOutputProcessorMixin:
                 else:
                     self.tree_cache.cache_finished_req(req)
-                    req.time_stats.completion_time = time.time()
+                    req.time_stats.completion_time = time.perf_counter()
 
                 if req.return_logprob and batch.spec_algorithm.is_none():
                     # speculative worker handles logprob in speculative decoding
diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
index 1c541914c..c8df235cb 100644
--- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py
+++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
@@ -5,6 +5,7 @@ import copy
 import logging
 import os
 import time
+import uuid
 from collections import deque
 from typing import (
     TYPE_CHECKING,
@@ -24,6 +25,7 @@ import zmq
 from sglang.srt.managers.io_struct import (
     ClearHiCacheReqInput,
     ClearHiCacheReqOutput,
+    CloseSessionReqInput,
     DestroyWeightsUpdateGroupReqInput,
     DestroyWeightsUpdateGroupReqOutput,
     ExpertDistributionReq,
@@ -44,6 +46,7 @@ from sglang.srt.managers.io_struct import (
     LoadLoRAAdapterReqOutput,
     LoRAUpdateResult,
     MultiTokenizerWrapper,
+    OpenSessionReqInput,
     ProfileReq,
     ProfileReqOutput,
     ProfileReqType,
@@ -588,3 +591,81 @@ class TokenizerCommunicatorMixin:
     async def get_load(self: TokenizerManager) -> List[GetLoadReqOutput]:
         req = GetLoadReqInput()
         return await self.get_load_communicator(req)
+
+    async def open_session(
+        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
+    ):
+        self.auto_create_handle_loop()
+
+        if obj.session_id is None:
+            obj.session_id = 
uuid.uuid4().hex + elif obj.session_id in self.session_futures: + return None + + if self.server_args.tokenizer_worker_num > 1: + obj = MultiTokenizerWrapper(self.worker_id, obj) + self.send_to_scheduler.send_pyobj(obj) + + self.session_futures[obj.session_id] = asyncio.Future() + session_id = await self.session_futures[obj.session_id] + del self.session_futures[obj.session_id] + return session_id + + async def close_session( + self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None + ): + await self.send_to_scheduler.send_pyobj(obj) + + def get_log_request_metadata(self): + max_length = None + skip_names = None + out_skip_names = None + if self.log_requests: + if self.log_requests_level == 0: + max_length = 1 << 30 + skip_names = set( + [ + "text", + "input_ids", + "input_embeds", + "image_data", + "audio_data", + "lora_path", + "sampling_params", + ] + ) + out_skip_names = set( + [ + "text", + "output_ids", + "embedding", + ] + ) + elif self.log_requests_level == 1: + max_length = 1 << 30 + skip_names = set( + [ + "text", + "input_ids", + "input_embeds", + "image_data", + "audio_data", + "lora_path", + ] + ) + out_skip_names = set( + [ + "text", + "output_ids", + "embedding", + ] + ) + elif self.log_requests_level == 2: + max_length = 2048 + elif self.log_requests_level == 3: + max_length = 1 << 30 + else: + raise ValueError( + f"Invalid --log-requests-level: {self.log_requests_level=}" + ) + return max_length, skip_names, out_skip_names diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index cc4b8c038..65fccb1dc 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -164,6 +164,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): else None ) self.crash_dump_folder = server_args.crash_dump_folder + self.enable_trace = server_args.enable_trace # Read model args self.model_path = server_args.model_path @@ -381,23 +382,8 @@ class TokenizerManager(TokenizerCommunicatorMixin): # If it's a single value, add worker_id prefix obj.rid = f"{self.worker_id}_{obj.rid}" - if obj.is_single: - bootstrap_room = ( - obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None - ) - trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9)) - trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True) - else: - for i in range(len(obj.rid)): - bootstrap_room = ( - obj.bootstrap_room[i] - if hasattr(obj, "bootstrap_room") and obj.bootstrap_room - else None - ) - trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9)) - trace_slice_start( - "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True - ) + if self.enable_trace: + self._trace_request_start(obj, created_time) if self.log_requests: max_length, skip_names, _ = self.log_request_metadata @@ -1055,7 +1041,10 @@ class TokenizerManager(TokenizerCommunicatorMixin): req = AbortReq(rid, abort_all) self.send_to_scheduler.send_pyobj(req) if self.enable_metrics: - self.metrics_collector.observe_one_aborted_request() + # TODO: also use custom_labels from the request + self.metrics_collector.observe_one_aborted_request( + self.metrics_collector.labels + ) async def pause_generation(self): async with self.is_pause_cond: @@ -1117,84 +1106,6 @@ class TokenizerManager(TokenizerCommunicatorMixin): all_paused_requests = [r.num_paused_requests for r in result] return all_success, all_message, all_paused_requests - async def open_session( - self, obj: OpenSessionReqInput, request: 
Optional[fastapi.Request] = None - ): - self.auto_create_handle_loop() - - if obj.session_id is None: - obj.session_id = uuid.uuid4().hex - elif obj.session_id in self.session_futures: - return None - - if self.server_args.tokenizer_worker_num > 1: - obj = MultiTokenizerWrapper(self.worker_id, obj) - self.send_to_scheduler.send_pyobj(obj) - - self.session_futures[obj.session_id] = asyncio.Future() - session_id = await self.session_futures[obj.session_id] - del self.session_futures[obj.session_id] - return session_id - - async def close_session( - self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None - ): - await self.send_to_scheduler.send_pyobj(obj) - - def get_log_request_metadata(self): - max_length = None - skip_names = None - out_skip_names = None - if self.log_requests: - if self.log_requests_level == 0: - max_length = 1 << 30 - skip_names = set( - [ - "text", - "input_ids", - "input_embeds", - "image_data", - "audio_data", - "lora_path", - "sampling_params", - ] - ) - out_skip_names = set( - [ - "text", - "output_ids", - "embedding", - ] - ) - elif self.log_requests_level == 1: - max_length = 1 << 30 - skip_names = set( - [ - "text", - "input_ids", - "input_embeds", - "image_data", - "audio_data", - "lora_path", - ] - ) - out_skip_names = set( - [ - "text", - "output_ids", - "embedding", - ] - ) - elif self.log_requests_level == 2: - max_length = 2048 - elif self.log_requests_level == 3: - max_length = 1 << 30 - else: - raise ValueError( - f"Invalid --log-requests-level: {self.log_requests_level=}" - ) - return max_length, skip_names, out_skip_names - def configure_logging(self, obj: ConfigureLoggingReq): if obj.log_requests is not None: self.log_requests = obj.log_requests @@ -1353,12 +1264,12 @@ class TokenizerManager(TokenizerCommunicatorMixin): # Drain requests while True: remain_num_req = len(self.rid_to_state) + remaining_rids = list(self.rid_to_state.keys()) if self.server_status == ServerStatus.UnHealthy: # if health check failed, we should exit immediately logger.error( - "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d", - remain_num_req, + "Signal SIGTERM received while health check failed. Force exiting." ) self.dump_requests_before_crash() break @@ -1366,13 +1277,12 @@ class TokenizerManager(TokenizerCommunicatorMixin): elif get_bool_env_var("SGL_FORCE_SHUTDOWN"): # if force shutdown flag set, exit immediately logger.error( - "Signal SIGTERM received while force shutdown flag set. Force exiting... remaining number of requests: %d", - remain_num_req, + "Signal SIGTERM received while force shutdown flag set. Force exiting." ) break logger.info( - f"Gracefully exiting... remaining number of requests {remain_num_req}" + f"Gracefully exiting... Remaining number of requests {remain_num_req}. Remaining requests {remaining_rids=}." 
) if remain_num_req > 0: await asyncio.sleep(5) @@ -1888,6 +1798,29 @@ class TokenizerManager(TokenizerCommunicatorMixin): load_udpate_req = WatchLoadUpdateReq(loads=loads) self.send_to_scheduler.send_pyobj(load_udpate_req) + def _trace_request_start( + self, + obj: Union[GenerateReqInput, EmbeddingReqInput], + created_time: Optional[float] = None, + ): + if obj.is_single: + bootstrap_room = ( + obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None + ) + trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9)) + trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True) + else: + for i in range(len(obj.rid)): + bootstrap_room = ( + obj.bootstrap_room[i] + if hasattr(obj, "bootstrap_room") and obj.bootstrap_room + else None + ) + trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9)) + trace_slice_start( + "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True + ) + class ServerStatus(Enum): Up = "Up" @@ -1933,7 +1866,7 @@ class SignalHandler: def running_phase_sigquit_handler(self, signum=None, frame=None): logger.error( - "Received sigquit from a child process. It usually means the child failed." + f"SIGQUIT received. {signum=}, {frame=}. It usually means one child failed." ) self.tokenizer_manager.dump_requests_before_crash() kill_process_tree(os.getpid()) diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py index 884f4e211..01ebc8063 100644 --- a/python/sglang/srt/metrics/collector.py +++ b/python/sglang/srt/metrics/collector.py @@ -14,9 +14,9 @@ """Utilities for Prometheus Metrics Collection.""" import time from dataclasses import dataclass, field -from enum import Enum from typing import Dict, List, Optional, Union +from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.metrics.utils import exponential_buckets, generate_buckets from sglang.srt.server_args import ServerArgs from sglang.srt.utils import get_bool_env_var @@ -34,6 +34,7 @@ class TimeStats: Decode: prealloc_queue -> transfer_queue -> wait_queue -> forward -> completion """ + disagg_mode: DisaggregationMode = DisaggregationMode.NULL lb_entry_time: float = 0.0 wait_queue_entry_time: float = 0.0 forward_entry_time: float = 0.0 @@ -43,20 +44,11 @@ class TimeStats: decode_prealloc_queue_entry_time: float = 0.0 decode_transfer_queue_entry_time: float = 0.0 - class RequestType(Enum): - UNIFIED = "unified" - PREFILL = "prefill" - DECODE = "decode" - INVALID = "invalid" - def get_queueing_time(self) -> float: return self.forward_entry_time - self.wait_queue_entry_time - def __str__(self) -> str: - # if unified - _type = self.get_type() - - if _type == self.RequestType.UNIFIED: + def convert_to_duration(self) -> str: + if self.disagg_mode == DisaggregationMode.NULL: queue_duration = self.forward_entry_time - self.wait_queue_entry_time forward_duration = self.completion_time - self.forward_entry_time @@ -65,30 +57,28 @@ class TimeStats: queue_duration >= 0 and forward_duration >= 0 ), f"queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0" - return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time}" - elif _type == self.RequestType.PREFILL: + return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time:.3f}" + elif self.disagg_mode == DisaggregationMode.PREFILL: bootstrap_duration = ( 
self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time ) - queue_duration = self.forward_entry_time - self.wait_queue_entry_time - forward_duration = self.completion_time - self.forward_entry_time if SGLANG_TEST_REQUEST_TIME_STATS: - assert ( - bootstrap_duration >= 0 - and queue_duration >= 0 - and forward_duration >= 0 - ), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0" - return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time}" - # if decode - elif _type == self.RequestType.DECODE: + if self.wait_queue_entry_time > 0: + assert ( + bootstrap_duration >= 0 + and queue_duration >= 0 + and forward_duration >= 0 + ), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0" + + return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time:.3f}" + elif self.disagg_mode == DisaggregationMode.DECODE: prealloc_duration = ( self.decode_transfer_queue_entry_time - self.decode_prealloc_queue_entry_time ) - transfer_duration = ( self.wait_queue_entry_time - self.decode_transfer_queue_entry_time ) @@ -96,42 +86,30 @@ class TimeStats: forward_duration = self.completion_time - self.forward_entry_time if SGLANG_TEST_REQUEST_TIME_STATS: - assert ( - prealloc_duration >= 0 - and transfer_duration >= 0 - and queue_duration >= 0 - and forward_duration >= 0 - ), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0" + if self.wait_queue_entry_time > 0: + assert ( + prealloc_duration >= 0 + and transfer_duration >= 0 + and queue_duration >= 0 + and forward_duration >= 0 + ), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0. 
{self=}" - return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time}" + return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time:.3f}" else: - return "Invalid Time Stats" + return "Unknown Time Stats" def format_duration(self, duration: float) -> str: return f"{duration * 1e3:.2f}ms" - def get_type(self) -> RequestType: - """Determine the type of request based on timestamp values.""" - if ( - self.prefill_bootstrap_queue_entry_time == 0.0 - and self.prefill_transfer_queue_entry_time == 0.0 - and self.decode_prealloc_queue_entry_time == 0.0 - and self.decode_transfer_queue_entry_time == 0.0 - ): - return self.RequestType.UNIFIED - elif ( - self.prefill_bootstrap_queue_entry_time > 0.0 - and self.prefill_transfer_queue_entry_time > 0.0 - ): - return self.RequestType.PREFILL - elif ( - self.decode_prealloc_queue_entry_time > 0.0 - and self.decode_transfer_queue_entry_time > 0.0 - and self.wait_queue_entry_time > 0.0 - ): - return self.RequestType.DECODE + def disagg_mode_str(self) -> str: + if self.disagg_mode == DisaggregationMode.NULL: + return "unified" + elif self.disagg_mode == DisaggregationMode.DECODE: + return "decode" + elif self.disagg_mode == DisaggregationMode.PREFILL: + return "prefill" else: - return self.RequestType.INVALID + return "unknown" @dataclass @@ -145,12 +123,15 @@ class SchedulerStats: num_queue_reqs: int = 0 num_grammar_queue_reqs: int = 0 num_running_reqs_offline_batch: int = 0 - avg_request_queue_latency: float = 0.0 cache_hit_rate: float = 0.0 # Speculative decoding spec_accept_length: float = 0.0 + # Retract + num_retracted_reqs: int = 0 + num_paused_reqs: int = 0 + # PD disaggregation num_prefill_prealloc_queue_reqs: int = 0 num_prefill_inflight_queue_reqs: int = 0 @@ -159,11 +140,6 @@ class SchedulerStats: kv_transfer_speed_gb_s: float = 0.0 kv_transfer_latency_ms: float = 0.0 - # Retract - total_retracted_reqs: int = 0 - num_retracted_reqs: int = 0 - num_paused_reqs: int = 0 - # Utilization utilization: float = 0.0 max_running_requests_under_SLO: Optional[int] = None @@ -230,12 +206,6 @@ class SchedulerMetricsCollector: labelnames=labels.keys(), multiprocess_mode="mostrecent", ) - self.avg_request_queue_latency = Gauge( - name="sglang:avg_request_queue_latency", - documentation="The average request queue latency for the last batch of requests in seconds.", - labelnames=labels.keys(), - multiprocess_mode="mostrecent", - ) self.cache_hit_rate = Gauge( name="sglang:cache_hit_rate", documentation="The prefix cache hit rate.", @@ -251,6 +221,18 @@ class SchedulerMetricsCollector: multiprocess_mode="mostrecent", ) + # Retract + self.num_retracted_reqs = Gauge( + name="sglang:num_retracted_reqs", + documentation="The number of retracted requests.", + labelnames=labels.keys(), + ) + self.num_paused_reqs = Gauge( + name="sglang:num_paused_reqs", + documentation="The number of paused requests by async weight sync.", + labelnames=labels.keys(), + ) + # PD disaggregation self.num_prefill_prealloc_queue_reqs = Gauge( name="sglang:num_prefill_prealloc_queue_reqs", @@ -299,24 +281,6 @@ class 
SchedulerMetricsCollector: multiprocess_mode="mostrecent", ) - # Retract - self.total_retracted_reqs = Gauge( - name="sglang:total_retracted_reqs", - documentation="The total number of retracted requests due to kvcache full.", - labelnames=labels.keys(), - multiprocess_mode="mostrecent", - ) - self.num_retracted_reqs = Gauge( - name="sglang:num_retracted_reqs", - documentation="The number of retracted requests.", - labelnames=labels.keys(), - ) - self.num_paused_reqs = Gauge( - name="sglang:num_paused_reqs", - documentation="The number of paused requests by async weight sync.", - labelnames=labels.keys(), - ) - # Utilization self.utilization = Gauge( name="sglang:utilization", @@ -347,7 +311,7 @@ class SchedulerMetricsCollector: # Additional queueing time histogram self.queue_time = Histogram( - name="sglang:queue_time_s", + name="sglang:queue_time_seconds", documentation="Histogram of queueing time in seconds.", labelnames=labels.keys(), buckets=[ @@ -513,8 +477,8 @@ class SchedulerMetricsCollector: buckets=tree_traversal_time_buckets, ) - self.request_latency_seconds = Histogram( - name="sglang:request_latency_seconds", + self.per_stage_req_latency_seconds = Histogram( + name="sglang:per_stage_req_latency_seconds", documentation="The latency of each stage of requests.", # captures latency in range [1ms - ~1191s] buckets=exponential_buckets(start=0.001, width=1.62, length=30), @@ -525,7 +489,7 @@ class SchedulerMetricsCollector: # Convenience function for logging to gauge. gauge.labels(**self.labels).set(data) - def log_histogram(self, histogram, data: Union[int, float]) -> None: + def _log_histogram(self, histogram, data: Union[int, float]) -> None: histogram.labels(**self.labels).observe(data) def increment_bootstrap_failed_reqs(self) -> None: @@ -534,9 +498,12 @@ class SchedulerMetricsCollector: def increment_transfer_failed_reqs(self) -> None: self.num_transfer_failed_reqs.labels(**self.labels).inc(1) - def observe_request_latency_seconds(self, stage: str, latency: float) -> None: + def observe_per_stage_req_latency(self, stage: str, latency: float) -> None: labels_with_stage = {**self.labels, "stage": stage} - self.request_latency_seconds.labels(**labels_with_stage).observe(latency) + self.per_stage_req_latency_seconds.labels(**labels_with_stage).observe(latency) + + def observe_queue_time(self, latency: float) -> None: + self._log_histogram(self.queue_time, latency) def log_stats(self, stats: SchedulerStats) -> None: self._log_gauge(self.num_running_reqs, stats.num_running_reqs) @@ -550,7 +517,6 @@ class SchedulerMetricsCollector: self.num_running_reqs_offline_batch, stats.num_running_reqs_offline_batch ) self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate) - self._log_gauge(self.avg_request_queue_latency, stats.avg_request_queue_latency) # Speculative decoding self._log_gauge(self.spec_accept_length, stats.spec_accept_length) @@ -572,7 +538,6 @@ class SchedulerMetricsCollector: self._log_gauge(self.kv_transfer_latency_ms, stats.kv_transfer_latency_ms) # Retract - self._log_gauge(self.total_retracted_reqs, stats.total_retracted_reqs) self._log_gauge(self.num_retracted_reqs, stats.num_retracted_reqs) self._log_gauge(self.num_paused_reqs, stats.num_paused_reqs) @@ -596,19 +561,19 @@ class SchedulerMetricsCollector: def log_grammar_stats(self, grammar_stats) -> None: # Duck-typed GrammarStats to avoid cross-package dependency if getattr(grammar_stats, "compilation_time", None) is not None: - self.log_histogram( + self._log_histogram( self.grammar_compilation_time, 
grammar_stats.compilation_time ) if getattr(grammar_stats, "schema_count", None) is not None: - self.log_histogram(self.grammar_schema_count, grammar_stats.schema_count) + self._log_histogram(self.grammar_schema_count, grammar_stats.schema_count) if getattr(grammar_stats, "ebnf_size", None) is not None: - self.log_histogram(self.grammar_ebnf_size, grammar_stats.ebnf_size) + self._log_histogram(self.grammar_ebnf_size, grammar_stats.ebnf_size) tree_times = getattr(grammar_stats, "tree_traversal_time", None) if tree_times: max_time = max(tree_times) avg_time = sum(tree_times) / len(tree_times) - self.log_histogram(self.grammar_tree_traversal_time_max, max_time) - self.log_histogram(self.grammar_tree_traversal_time_avg, avg_time) + self._log_histogram(self.grammar_tree_traversal_time_max, max_time) + self._log_histogram(self.grammar_tree_traversal_time_avg, avg_time) if getattr(grammar_stats, "is_cache_hit", False): self.num_grammar_cache_hit.labels(**self.labels).inc(1) if getattr(grammar_stats, "is_grammar_aborted", False): @@ -714,7 +679,7 @@ class TokenizerMetricsCollector: ) self.num_aborted_requests_total = Counter( - name="sglang:num_aborted_requests", + name="sglang:num_aborted_requests_total", documentation="Number of requests aborted.", labelnames=labels.keys(), ) @@ -801,7 +766,7 @@ class TokenizerMetricsCollector: buckets=bucket_time_to_first_token, ) - self.histogram_inter_token_latency_seconds = Histogram( + self.histogram_inter_token_latency = Histogram( name="sglang:inter_token_latency_seconds", documentation="Histogram of inter-token latency in seconds.", labelnames=labels.keys(), @@ -815,14 +780,6 @@ class TokenizerMetricsCollector: buckets=bucket_e2e_request_latency, ) - # Offline batch specific TTFB histogram - self.histogram_time_to_first_token_offline_batch = Histogram( - name="sglang:time_to_first_token_seconds_offline_batch", - documentation="Histogram of time to first token in seconds for offline batch requests.", - labelnames=labels.keys(), - buckets=bucket_time_to_first_token, - ) - def observe_one_finished_request( self, labels: Dict[str, str], @@ -846,15 +803,8 @@ class TokenizerMetricsCollector: float(generation_tokens) ) - def observe_time_to_first_token( - self, labels: Dict[str, str], value: float, type: str = "" - ): - if type == "batch": - self.histogram_time_to_first_token_offline_batch.labels(**labels).observe( - value - ) - else: - self.histogram_time_to_first_token.labels(**labels).observe(value) + def observe_time_to_first_token(self, labels: Dict[str, str], value: float): + self.histogram_time_to_first_token.labels(**labels).observe(value) def check_time_to_first_token_straggler(self, value: float) -> bool: his = self.histogram_time_to_first_token.labels(**self.labels) @@ -876,7 +826,7 @@ class TokenizerMetricsCollector: # A faster version of the Histogram::observe which observes multiple values at the same time. 
# reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639 - his = self.histogram_inter_token_latency_seconds.labels(**labels) + his = self.histogram_inter_token_latency.labels(**labels) his._sum.inc(internval) for i, bound in enumerate(his._upper_bounds): @@ -884,8 +834,8 @@ class TokenizerMetricsCollector: his._buckets[i].inc(num_new_tokens) break - def observe_one_aborted_request(self): - self.num_aborted_requests_total.labels(**self.labels).inc(1) + def observe_one_aborted_request(self, labels: Dict[str, str]): + self.num_aborted_requests_total.labels(**labels).inc(1) @dataclass diff --git a/python/sglang/srt/tracing/trace.py b/python/sglang/srt/tracing/trace.py index b4ccbfa9f..f637a8d77 100644 --- a/python/sglang/srt/tracing/trace.py +++ b/python/sglang/srt/tracing/trace.py @@ -15,7 +15,6 @@ from __future__ import annotations -import ctypes import logging import os import random @@ -23,7 +22,10 @@ import threading import time import uuid from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +if TYPE_CHECKING: + from sglang.srt.managers.scheduler import Req logger = logging.getLogger(__name__) opentelemetry_imported = False @@ -407,9 +409,11 @@ def trace_slice_start( ts: Optional[int] = None, anonymous: bool = False, ): + if not tracing_enabled: + return rid = str(rid) - if not tracing_enabled or rid not in reqs_context: + if rid not in reqs_context: return pid = threading.get_native_id() @@ -458,8 +462,11 @@ def trace_slice_end( auto_next_anon: bool = False, thread_finish_flag: bool = False, ): + if not tracing_enabled: + return + rid = str(rid) - if not tracing_enabled or rid not in reqs_context: + if rid not in reqs_context: return pid = threading.get_native_id() @@ -512,10 +519,13 @@ trace_slice = trace_slice_end # Add event to the current slice on the same thread with the same rid. def trace_event(name: str, rid: str, ts: Optional[int] = None): - if not tracing_enabled or rid not in reqs_context: + if not tracing_enabled: return rid = str(rid) + if rid not in reqs_context: + return + pid = threading.get_native_id() if pid not in reqs_context[rid].threads_context: return @@ -534,10 +544,13 @@ def trace_event(name: str, rid: str, ts: Optional[int] = None): # Add attrs to the current slice on the same thread with the same rid. def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]): - if not tracing_enabled or rid not in reqs_context: + if not tracing_enabled: return rid = str(rid) + if rid not in reqs_context: + return + pid = threading.get_native_id() if pid not in reqs_context[rid].threads_context: return @@ -550,3 +563,16 @@ def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]): slice_info = thread_context.cur_slice_stack[-1] slice_info.span.set_attributes(attrs) + + +def trace_slice_batch( + name: str, + reqs: List[Req], +): + for req in reqs: + trace_slice( + name, + req.rid, + auto_next_anon=not req.finished(), + thread_finish_flag=req.finished(), + )
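For reference, the renamed per-stage latency histogram can be exercised standalone with `prometheus_client`. This is a hedged sketch, not the collector's actual setup: the scheduler's identity labels and multiprocess mode are omitted, `"prefill_waiting"` is an assumed stage label value, and `exponential_buckets` below is a local stand-in for `sglang.srt.metrics.utils.exponential_buckets` (assumed to produce `start * width**i`):

```python
import time

from prometheus_client import Histogram


def exponential_buckets(start: float, width: float, length: int) -> list:
    # Local stand-in for sglang.srt.metrics.utils.exponential_buckets.
    return [start * (width**i) for i in range(length)]


# Mirrors the metric declared in SchedulerMetricsCollector, reduced to a
# single "stage" label; covers roughly [1ms, ~1191s] as the diff notes.
per_stage_req_latency_seconds = Histogram(
    "sglang:per_stage_req_latency_seconds",
    "The latency of each stage of requests.",
    labelnames=["stage"],
    buckets=exponential_buckets(start=0.001, width=1.62, length=30),
)

# Usage matches Req.add_latency(): measure the gap since the previous stage
# boundary on a monotonic clock and attach the stage name as a label.
last_tic = time.monotonic()
time.sleep(0.005)  # stand-in for time spent in the waiting queue
now = time.monotonic()
per_stage_req_latency_seconds.labels(stage="prefill_waiting").observe(now - last_tic)
```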