Fix metrics and request tracing (TimeStats) (#11123)
@@ -14,18 +14,17 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
]
dependencies = [
"aiohttp",
"requests",
"tqdm",
"numpy",
"IPython",
"setproctitle",
"aiohttp",
"anthropic>=0.20.0",
"blobfile==3.0.0",
"build",
"compressed-tensors",
"cuda-python",
"datasets",
"einops",
"fastapi",
"flashinfer_python==0.4.0rc3",
"hf_transfer",
"huggingface_hub",
"interegular",
@@ -33,8 +32,10 @@ dependencies = [
"modelscope",
"msgspec",
"ninja",
"openai==1.99.1",
"numpy",
"nvidia-cutlass-dsl==4.2.1",
"openai-harmony==0.0.4",
"openai==1.99.1",
"orjson",
"outlines==0.1.11",
"packaging",
@@ -42,32 +43,30 @@ dependencies = [
"pillow",
"prometheus-client>=0.20.0",
"psutil",
"py-spy",
"pybase64",
"pydantic",
"pynvml",
"python-multipart",
"pyzmq>=25.1.2",
"requests",
"scipy",
"sentencepiece",
"setproctitle",
"sgl-kernel==0.3.13",
"soundfile==0.13.1",
"timm==1.0.16",
"tiktoken",
"timm==1.0.16",
"torch==2.8.0",
"torch_memory_saver==0.0.8",
"torchao==0.9.0",
"torchaudio==2.8.0",
"torchvision",
"tqdm",
"transformers==4.56.1",
"uvicorn",
"uvloop",
"xgrammar==0.1.24",
"sgl-kernel==0.3.13",
"torch==2.8.0",
"torchaudio==2.8.0",
"torchvision",
"cuda-python",
"flashinfer_python==0.4.0rc3",
"openai==1.99.1",
"tiktoken",
"anthropic>=0.20.0",
"torch_memory_saver==0.0.8",
"nvidia-cutlass-dsl==4.2.1",
"xgrammar==0.1.24"
]

[project.optional-dependencies]
@@ -79,15 +78,15 @@ test = [
"matplotlib",
"pandas",
"peft",
"sentence_transformers",
"pytest",
"sentence_transformers",
"tabulate",
]
tracing = [
"opentelemetry-sdk",
"opentelemetry-api",
"opentelemetry-exporter-otlp",
"opentelemetry-exporter-otlp-proto-grpc",
"opentelemetry-sdk",
]
all = ["sglang[test]", "sglang[decord]"]
blackwell = ["sglang[test]", "sglang[decord]"]

@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
from __future__ import annotations

import logging
import time
from collections import deque
from dataclasses import dataclass
from http import HTTPStatus
@@ -422,9 +423,13 @@ class DecodePreallocQueue:
kv_indices, self.token_to_kv_pool_allocator.page_size
)
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)

preallocated_reqs.append(decode_req)
indices_to_remove.add(i)
decode_req.req.time_stats.decode_transfer_queue_entry_time = (
time.perf_counter()
)
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)

self.queue = [
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
@@ -625,6 +630,7 @@ class DecodeTransferQueue:
decode_req.req.output_topk_p = output_topk_p
decode_req.req.output_topk_index = output_topk_index
decode_req.req.hidden_states_tensor = output_hidden_states

if decode_req.req.return_logprob:
decode_req.req.output_token_logprobs_val.append(
output_token_logprobs_val[0].item()
@@ -645,10 +651,17 @@ class DecodeTransferQueue:

if hasattr(decode_req.kv_receiver, "clear"):
decode_req.kv_receiver.clear()
decode_req.kv_receiver = None

indices_to_remove.add(i)
decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter()

# special handling for sampling_params.max_new_tokens == 1
if decode_req.req.sampling_params.max_new_tokens == 1:
# finish immediately
decode_req.req.time_stats.forward_entry_time = (
decode_req.req.time_stats.completion_time
) = time.perf_counter()
decode_req.req.check_finished()
self.scheduler.stream_output(
[decode_req.req], decode_req.req.return_logprob
@@ -656,8 +669,6 @@ class DecodeTransferQueue:
self.tree_cache.cache_finished_req(decode_req.req)
else:
transferred_reqs.append(decode_req.req)

indices_to_remove.add(i)
elif poll in [
KVPoll.Bootstrapping,
KVPoll.WaitingForInput,
@@ -877,6 +888,9 @@ class SchedulerDisaggregationDecodeMixin:
if len(can_run_list) == 0:
return None

for req in can_run_list:
req.time_stats.forward_entry_time = time.perf_counter()

# construct a schedule batch with those requests and mark as decode
new_batch = ScheduleBatch.init_new(
can_run_list,

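The decode-side changes above all follow one pattern: each queue transition stamps a monotonic timestamp on the request's time_stats, and durations are derived later by subtraction. A minimal, self-contained sketch of that pattern (the _TimeStatsSketch dataclass is an illustrative stand-in, not the sglang class):

import time
from dataclasses import dataclass

@dataclass
class _TimeStatsSketch:
    decode_prealloc_queue_entry_time: float = 0.0
    decode_transfer_queue_entry_time: float = 0.0
    wait_queue_entry_time: float = 0.0

stats = _TimeStatsSketch()
stats.decode_prealloc_queue_entry_time = time.perf_counter()
# ... request leaves the prealloc queue and enters the transfer queue ...
stats.decode_transfer_queue_entry_time = time.perf_counter()
prealloc_duration = (
    stats.decode_transfer_queue_entry_time - stats.decode_prealloc_queue_entry_time
)
print(f"prealloc_duration={prealloc_duration * 1e3:.2f}ms")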
@@ -21,6 +21,7 @@ from __future__ import annotations

import logging
import threading
import time
from collections import deque
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Type
@@ -263,9 +264,10 @@ class PrefillBootstrapQueue:
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)

req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
bootstrapped_reqs.append(req)
indices_to_remove.add(i)
req.time_stats.wait_queue_entry_time = time.perf_counter()
req.add_latency(RequestStage.PREFILL_BOOTSTRAP)

self.queue = [
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
@@ -407,7 +409,6 @@ class SchedulerDisaggregationPrefillMixin:
for i, (req, next_token_id) in enumerate(
zip(batch.reqs, next_token_ids, strict=True)
):
req: Req
if req.is_chunked <= 0:
# There is no output_ids for prefill
req.output_ids.append(next_token_id)
@@ -450,6 +451,7 @@ class SchedulerDisaggregationPrefillMixin:
)
logprob_pt += num_input_logprobs
self.send_kv_chunk(req, last_chunk=True)
req.time_stats.prefill_transfer_queue_entry_time = time.perf_counter()

if req.grammar is not None:
# FIXME: this try-except block is for handling unexpected xgrammar issue.
@@ -547,6 +549,9 @@ class SchedulerDisaggregationPrefillMixin:
else:
assert False, f"Unexpected polling state {poll=}"

for req in done_reqs:
req.time_stats.completion_time = time.perf_counter()

# Stream requests which have finished transfer
self.stream_output(
done_reqs,

@@ -5,7 +5,7 @@ import random
from collections import deque
from contextlib import nullcontext
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Type, Union
from typing import TYPE_CHECKING, Optional, Type

import numpy as np
import torch

@@ -41,7 +41,7 @@ import time
from enum import Enum, auto
from http import HTTPStatus
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

import numpy as np
import torch
@@ -54,6 +54,7 @@ from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
@@ -452,6 +453,7 @@ class Req:
bootstrap_host: Optional[str] = None,
bootstrap_port: Optional[int] = None,
bootstrap_room: Optional[int] = None,
disagg_mode: Optional[DisaggregationMode] = None,
data_parallel_rank: Optional[int] = None,
vocab_size: Optional[int] = None,
priority: Optional[int] = None,
@@ -628,10 +630,8 @@ class Req:

# For metrics
self.metrics_collector = metrics_collector
self.time_stats: TimeStats = TimeStats()
self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
self.has_log_time_stats: bool = False
self.queue_time_start = None
self.queue_time_end = None
self.last_tic = time.monotonic()

# For disaggregation
@@ -668,9 +668,9 @@ class Req:
def add_latency(self, stage: RequestStage):
if self.metrics_collector is None:
return
assert stage.name in RequestStage.__members__, f"{stage=} is invalid"

now = time.monotonic()
self.metrics_collector.observe_request_latency_seconds(
self.metrics_collector.observe_per_stage_req_latency(
stage.value, now - self.last_tic
)
self.last_tic = now
@@ -834,10 +834,10 @@ class Req:
return

if self.bootstrap_room is not None:
prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
else:
prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
logger.info(f"{prefix}: {self.time_stats}")
prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
logger.info(f"{prefix}: {self.time_stats.convert_to_duration()}")
self.has_log_time_stats = True

def set_finish_with_abort(self, error_msg: str):
@@ -1544,7 +1544,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
) / total_max_new_tokens
new_estimate_ratio = min(1.0, new_estimate_ratio)

return retracted_reqs, new_estimate_ratio
return retracted_reqs, new_estimate_ratio, []

def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
req = self.reqs[idx]

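Req.add_latency above measures the time elapsed since the previous stage boundary on a monotonic clock, then resets the reference point. A hedged sketch of that mechanic (the observe callback stands in for the Prometheus histogram; this is not sglang's actual collector API):

import time

class _LatencyRecorder:
    def __init__(self, observe):
        self.observe = observe  # callback standing in for a histogram observe
        self.last_tic = time.monotonic()

    def add_latency(self, stage: str) -> None:
        now = time.monotonic()
        self.observe(stage, now - self.last_tic)  # time spent in this stage
        self.last_tic = now

rec = _LatencyRecorder(lambda stage, dt: print(f"{stage}: {dt:.6f}s"))
rec.add_latency("prefill_waiting")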
@@ -276,9 +276,13 @@ class SchedulePolicy:
) -> None:
"""Sorts the waiting queue based on the request priority and then the received timestamp."""
if schedule_low_priority_values_first:
waiting_queue.sort(key=lambda x: (x.priority, x.queue_time_start))
waiting_queue.sort(
key=lambda x: (x.priority, x.time_stats.wait_queue_entry_time)
)
else:
waiting_queue.sort(key=lambda x: (-x.priority, x.queue_time_start))
waiting_queue.sort(
key=lambda x: (-x.priority, x.time_stats.wait_queue_entry_time)
)

@staticmethod
def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
@@ -642,12 +646,12 @@ class PrefillAdder:
if server_args.schedule_low_priority_values_first:
sorted_running_reqs = sorted(
self.running_batch.reqs,
key=lambda x: (-x.priority, -x.queue_time_start),
key=lambda x: (-x.priority, -x.time_stats.wait_queue_entry_time),
)
else:
sorted_running_reqs = sorted(
self.running_batch.reqs,
key=lambda x: (x.priority, -x.queue_time_start),
key=lambda x: (x.priority, -x.time_stats.wait_queue_entry_time),
)
preemptible_reqs = []
min_tokens_to_remove = (

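The scheduling-policy change swaps queue_time_start for time_stats.wait_queue_entry_time as the tie-breaker; the sort semantics are unchanged: within equal priority, earlier arrivals win. A small illustration (stand-in dataclass, not the real Req):

from dataclasses import dataclass

@dataclass
class _ReqSketch:
    priority: int
    wait_queue_entry_time: float

queue = [_ReqSketch(1, 10.0), _ReqSketch(0, 12.0), _ReqSketch(0, 11.0)]
# schedule_low_priority_values_first: smaller priority value runs first
queue.sort(key=lambda x: (x.priority, x.wait_queue_entry_time))
assert [(r.priority, r.wait_queue_entry_time) for r in queue] == [
    (0, 11.0),
    (0, 12.0),
    (1, 10.0),
]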
@@ -157,10 +157,9 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.tracing.trace import (
process_tracing_init,
trace_event,
trace_set_proc_propagate_context,
trace_set_thread_info,
trace_slice,
trace_slice_batch,
trace_slice_end,
trace_slice_start,
)
@@ -263,6 +262,7 @@ class Scheduler(
server_args.enable_metrics_for_all_schedulers
)
self.enable_kv_cache_events = server_args.kv_events_config and tp_rank == 0
self.enable_trace = server_args.enable_trace
self.stream_interval = server_args.stream_interval
self.spec_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
@@ -899,10 +899,6 @@ class Scheduler(
batch = self.get_next_batch_to_run()
self.cur_batch = batch

if batch:
for req in batch.reqs:
trace_event("schedule", req.rid)

if batch:
result = self.run_batch(batch)
self.process_batch_result(batch, result)
@@ -924,10 +920,6 @@ class Scheduler(
batch = self.get_next_batch_to_run()
self.cur_batch = batch

if batch:
for req in batch.reqs:
trace_event("schedule", req.rid)

if batch:
batch.launch_done = threading.Event()
result = self.run_batch(batch)
@@ -1192,10 +1184,13 @@ class Scheduler(
src=self.tp_group.ranks[0],
)

for req in recv_reqs:
if isinstance(req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)):
trace_set_proc_propagate_context(req.rid, req.trace_context)
trace_slice_start("", req.rid, anonymous=True)
if self.enable_trace:
for req in recv_reqs:
if isinstance(
req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
):
trace_set_proc_propagate_context(req.rid, req.trace_context)
trace_slice_start("", req.rid, anonymous=True)

return recv_reqs

@@ -1277,6 +1272,7 @@ class Scheduler(
bootstrap_host=recv_req.bootstrap_host,
bootstrap_port=recv_req.bootstrap_port,
bootstrap_room=recv_req.bootstrap_room,
disagg_mode=self.disaggregation_mode,
data_parallel_rank=recv_req.data_parallel_rank,
vocab_size=self.model_config.vocab_size,
priority=recv_req.priority,
@@ -1403,7 +1399,6 @@ class Scheduler(
req.set_finish_with_abort(error_msg)

if add_to_grammar_queue:
req.queue_time_start = time.perf_counter()
self.grammar_queue.append(req)
else:
self._add_request_to_queue(req)
@@ -1419,23 +1414,6 @@ class Scheduler(
for tokenized_req in recv_req:
self.handle_generate_request(tokenized_req)

def _add_request_to_queue(self, req: Req):
req.queue_time_start = time.perf_counter()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self._prefetch_kvcache(req)
self.disagg_prefill_bootstrap_queue.add(
req, self.model_config.num_key_value_heads
)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.disagg_decode_prealloc_queue.add(req)
else:
self._set_or_validate_priority(req)
if self._abort_on_queued_limit(req):
return
self._prefetch_kvcache(req)
self.waiting_queue.append(req)
trace_slice_end("process req", req.rid, auto_next_anon=True)

def _prefetch_kvcache(self, req: Req):
if self.enable_hicache_storage:
req.init_next_round_input(self.tree_cache)
@@ -1449,19 +1427,27 @@ class Scheduler(
req.rid, req.last_host_node, new_input_tokens, last_hash
)

def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.disagg_prefill_bootstrap_queue.extend(
reqs, self.model_config.num_key_value_heads
def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
if self.disaggregation_mode == DisaggregationMode.NULL:
self._set_or_validate_priority(req)
if self._abort_on_queued_limit(req):
return
self._prefetch_kvcache(req)
self.waiting_queue.append(req)
req.time_stats.wait_queue_entry_time = time.perf_counter()
trace_slice_end("process req", req.rid, auto_next_anon=True)
elif self.disaggregation_mode == DisaggregationMode.PREFILL:
self._prefetch_kvcache(req)
self.disagg_prefill_bootstrap_queue.add(
req, self.model_config.num_key_value_heads
)
req.time_stats.prefill_bootstrap_queue_entry_time = time.perf_counter()
elif self.disaggregation_mode == DisaggregationMode.DECODE:
# If this is a decode server, we put the request to the decode pending prealloc queue
self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
if not is_retracted:
req.time_stats.decode_prealloc_queue_entry_time = time.perf_counter()
else:
for req in reqs:
self._set_or_validate_priority(req)
if not self._abort_on_queued_limit(req):
self.waiting_queue.append(req)
raise ValueError(f"Invalid {self.disaggregation_mode=}")

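The refactor above folds _extend_requests_to_queue into a single per-request _add_request_to_queue that both routes the request and stamps the matching time_stats field. A simplified sketch of that dispatch shape (dict-based stand-ins, not the sglang types):

import time
from enum import Enum

class _Mode(Enum):
    NULL = "null"
    PREFILL = "prefill"
    DECODE = "decode"

def _add_request_to_queue(req: dict, mode: _Mode, queues: dict) -> None:
    now = time.perf_counter()
    if mode == _Mode.NULL:
        queues["waiting"].append(req)
        req["wait_queue_entry_time"] = now
    elif mode == _Mode.PREFILL:
        queues["bootstrap"].append(req)
        req["prefill_bootstrap_queue_entry_time"] = now
    elif mode == _Mode.DECODE:
        queues["prealloc"].append(req)
        req["decode_prealloc_queue_entry_time"] = now
    else:
        raise ValueError(f"Invalid {mode=}")

queues = {"waiting": [], "bootstrap": [], "prealloc": []}
_add_request_to_queue({"rid": "r0"}, _Mode.DECODE, queues)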
def _set_or_validate_priority(self, req: Req):
"""Set the default priority value, or abort the request based on the priority scheduling mode."""
@@ -1500,7 +1486,7 @@ class Scheduler(
direction = 1 if self.schedule_low_priority_values_first else -1
key_fn = lambda item: (
direction * item[1].priority,
item[1].queue_time_start,
item[1].time_stats.wait_queue_entry_time,
)
idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
abort_existing_req = (
@@ -1902,14 +1888,14 @@ class Scheduler(
if self.enable_metrics:
# only record queue time when enable_metrics is True to avoid overhead
for req in can_run_list:
req.queue_time_end = time.perf_counter()
req.add_latency(RequestStage.PREFILL_WAITING)

self.waiting_queue = [
x for x in self.waiting_queue if x not in set(can_run_list)
]
if adder.preempt_list:
self._extend_requests_to_queue(adder.preempt_list)
for req in adder.preempt_list:
self._add_request_to_queue(req)

if adder.new_chunked_req is not None:
assert self.chunked_req is None
@@ -1920,7 +1906,16 @@ class Scheduler(

# Print stats
if self.current_scheduler_metrics_enabled():
self.log_prefill_stats(adder, can_run_list, running_bs)
self.log_prefill_stats(adder, can_run_list, running_bs, 0)

for req in can_run_list:
if req.time_stats.forward_entry_time == 0:
# Avoid updating the chunked request many times
req.time_stats.forward_entry_time = time.perf_counter()
if self.enable_metrics:
self.metrics_collector.observe_queue_time(
req.time_stats.get_queueing_time(),
)

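Queueing time is now derived from the two time_stats stamps rather than the old queue_time_start/queue_time_end pair, and the forward_entry_time == 0 guard above ensures a chunked prefill is only stamped (and therefore only counted) once. The arithmetic in isolation:

import time

wait_queue_entry_time = time.perf_counter()
# ... the request waits, then is picked into a prefill batch ...
forward_entry_time = time.perf_counter()
queueing_time = forward_entry_time - wait_queue_entry_time
assert queueing_time >= 0  # both stamps come from the same monotonic clock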
# Create a new batch
new_batch = ScheduleBatch.init_new(
@@ -1975,19 +1970,25 @@ class Scheduler(
TEST_RETRACT and batch.batch_size() > 10
):
old_ratio = self.new_token_ratio

retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
num_retracted_reqs = len(retracted_reqs)
retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
self.server_args
)
self.num_retracted_reqs = len(retracted_reqs)
self.new_token_ratio = new_token_ratio
for req in reqs_to_abort:
self.send_to_tokenizer.send_pyobj(
AbortReq(req.rid, abort_reason=req.to_abort_message)
)

logger.info(
"KV cache pool is full. Retract requests. "
f"#retracted_reqs: {num_retracted_reqs}, "
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
f"#retracted_reqs: {len(retracted_reqs)}, "
f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
)

self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
self.total_retracted_reqs += num_retracted_reqs
for req in retracted_reqs:
self._add_request_to_queue(req, is_retracted=True)
else:
self.new_token_ratio = max(
self.new_token_ratio - self.new_token_ratio_decay,
@@ -2086,23 +2087,14 @@ class Scheduler(
):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result, launch_done)
for req in batch.reqs:
trace_slice(
"decode loop",
req.rid,
auto_next_anon=not req.finished(),
thread_finish_flag=req.finished(),
)
if self.enable_trace:
trace_slice_batch("decode loop", batch.reqs)

elif batch.forward_mode.is_extend():
self.process_batch_result_prefill(batch, result, launch_done)
for req in batch.reqs:
trace_slice(
"prefill",
req.rid,
auto_next_anon=not req.finished(),
thread_finish_flag=req.finished(),
)
if self.enable_trace:
trace_slice_batch("prefill", batch.reqs)

elif batch.forward_mode.is_idle():
if self.enable_overlap:
self.tp_worker.resolve_last_batch_result(launch_done)
@@ -2261,12 +2253,13 @@ class Scheduler(
if req.finished():  # It is aborted by AbortReq
num_ready_reqs += 1
continue

req.grammar = req.grammar.result(timeout=0.03)
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
if req.grammar is INVALID_GRAMMAR_OBJ:
req.set_finish_with_abort(
f"Invalid grammar request: {req.grammar_key=}"
)
error_msg = f"Invalid grammar request: {req.grammar_key=}"
req.set_finish_with_abort(error_msg)

num_ready_reqs += 1
except futures._base.TimeoutError:
req.grammar_wait_ct += 1
@@ -2298,9 +2291,8 @@ class Scheduler(
req.grammar = req.grammar.result()
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
if req.grammar is INVALID_GRAMMAR_OBJ:
req.set_finish_with_abort(
f"Invalid grammar request: {req.grammar_key=}"
)
error_msg = f"Invalid grammar request: {req.grammar_key=}"
req.set_finish_with_abort(error_msg)
else:
num_ready_reqs_max = num_ready_reqs
num_timeout_reqs_max = num_timeout_reqs
@@ -2308,12 +2300,14 @@ class Scheduler(
for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
req = self.grammar_queue[i]
req.grammar.cancel()
self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
req.set_finish_with_abort(error_msg)
self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)

num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max

self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
for req in self.grammar_queue[:num_ready_reqs]:
self._add_request_to_queue(req)
self.grammar_queue = self.grammar_queue[num_ready_reqs:]

def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
@@ -2795,17 +2789,11 @@ def run_scheduler_process(
pipe_writer,
balance_meta: Optional[DPBalanceMeta] = None,
):
if server_args.enable_trace:
process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
if server_args.disaggregation_mode == "null":
thread_label = "Scheduler"
trace_set_thread_info(thread_label, tp_rank, dp_rank)

if (numa_node := server_args.numa_node) is not None:
numa_bind_to_node(numa_node[gpu_id])

# Generate the prefix
# Generate the logger prefix
prefix = ""
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
# [For Router] if env var "SGLANG_DP_RANK" exists, set dp_rank to the value of the env var
dp_rank = int(os.environ["SGLANG_DP_RANK"])
if dp_rank is not None:
prefix += f" DP{dp_rank}"
if server_args.tp_size > 1:
@@ -2821,10 +2809,6 @@ def run_scheduler_process(
kill_itself_when_parent_died()
parent_process = psutil.Process().parent()

# [For Router] if env var "SGLANG_DP_RANK" exists, set dp_rank to the value of the env var
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
dp_rank = int(os.environ["SGLANG_DP_RANK"])

# Configure the logger
configure_logger(server_args, prefix=prefix)
suppress_other_loggers()
@@ -2832,6 +2816,15 @@ def run_scheduler_process(
# Set cpu affinity to this gpu process
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
if (numa_node := server_args.numa_node) is not None:
numa_bind_to_node(numa_node[gpu_id])

# Set up tracing
if server_args.enable_trace:
process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
if server_args.disaggregation_mode == "null":
thread_label = "Scheduler"
trace_set_thread_info(thread_label, tp_rank, dp_rank)

# Create a scheduler and run the event loop
try:

@@ -47,8 +47,11 @@ class SchedulerMetricsMixin:
self.spec_num_total_forward_ct = 0
self.cum_spec_accept_length = 0
self.cum_spec_accept_count = 0
self.total_retracted_reqs = 0
self.kv_transfer_speed_gb_s: float = 0.0
self.kv_transfer_latency_ms: float = 0.0

self.stats = SchedulerStats()

if self.enable_metrics:
engine_type = "unified"
labels = {
@@ -82,12 +85,14 @@ class SchedulerMetricsMixin:
adder: PrefillAdder,
can_run_list: List[Req],
running_bs: int,
running_bs_offline_batch: int,
):
gap_latency = time.perf_counter() - self.last_prefill_stats_tic
self.last_prefill_stats_tic = time.perf_counter()
self.last_input_throughput = self.last_prefill_tokens / gap_latency
self.last_prefill_tokens = adder.log_input_tokens

# TODO: generalize this for various memory pools
if self.is_hybrid:
(
full_num_used,
@@ -101,51 +106,53 @@ class SchedulerMetricsMixin:
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
token_msg = (
token_usage_msg = (
f"full token usage: {full_token_usage:.2f}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_msg = f"token usage: {token_usage:.2f}, "
token_usage_msg = f"token usage: {token_usage:.2f}, "

num_new_seq = len(can_run_list)
f = (
f"Prefill batch. "
f"#new-seq: {num_new_seq}, "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"{token_msg}"
f"{token_usage_msg}"
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue)}, "
)

if self.disaggregation_mode == DisaggregationMode.PREFILL:
f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
f += f"#queue-req: {len(self.waiting_queue)}, "
f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
else:
f += f"#running-req: {running_bs}, "
f += f"#queue-req: {len(self.waiting_queue)}, "
f += f"#prealloc-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
f += f"#inflight-req: {len(self.disagg_prefill_inflight_queue)}, "

logger.info(f)

if self.enable_metrics:
# Basics
total_tokens = adder.log_input_tokens + adder.log_hit_tokens

cache_hit_rate = (
adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
)

self.stats.num_running_reqs = running_bs
self.stats.num_running_reqs_offline_batch = running_bs_offline_batch
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(token_usage, 2)
self.stats.token_usage = token_usage
if self.is_hybrid:
self.stats.swa_token_usage = swa_token_usage
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
self.stats.cache_hit_rate = cache_hit_rate

total_queue_latency = 0
for req in can_run_list:
total_queue_latency += req.queue_time_end - req.queue_time_start
self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
# Retract
self.stats.num_retracted_reqs = self.num_retracted_reqs
self.stats.num_paused_reqs = self.num_paused_reqs
self.num_retracted_reqs = self.num_paused_reqs = 0

# PD disaggregation
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.stats.num_prefill_prealloc_queue_reqs = len(
self.disagg_prefill_bootstrap_queue.queue
@@ -153,7 +160,18 @@ class SchedulerMetricsMixin:
self.stats.num_prefill_inflight_queue_reqs = len(
self.disagg_prefill_inflight_queue
)
self.stats.kv_transfer_speed_gb_s = self.kv_transfer_speed_gb_s
self.stats.kv_transfer_latency_ms = self.kv_transfer_latency_ms
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.stats.num_decode_prealloc_queue_reqs = len(
self.disagg_decode_prealloc_queue.queue
)
self.stats.num_decode_transfer_queue_reqs = len(
self.disagg_decode_transfer_queue.queue
)

# Others
self.calculate_utilization()
self.metrics_collector.log_stats(self.stats)
self._emit_kv_metrics()
self._publish_kv_events()
@@ -166,8 +184,12 @@ class SchedulerMetricsMixin:
gap_latency = time.perf_counter() - self.last_decode_stats_tic
self.last_decode_stats_tic = time.perf_counter()
self.last_gen_throughput = self.num_generated_tokens / gap_latency

self.num_generated_tokens = 0
num_running_reqs = len(batch.reqs)
num_running_reqs_offline_batch = 0

# TODO: generalize this for various memory pools
if self.is_hybrid:
(
full_num_used,
@@ -181,7 +203,7 @@ class SchedulerMetricsMixin:
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
token_msg = (
token_usage_msg = (
f"#full token: {full_num_used}, "
f"full token usage: {full_token_usage:.2f}, "
f"#swa token: {swa_num_used}, "
@@ -189,14 +211,14 @@ class SchedulerMetricsMixin:
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
token_usage_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, "

if RECORD_STEP_TIME:
self.step_time_dict[num_running_reqs].append(
gap_latency / self.server_args.decode_log_interval
)

msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
msg = f"Decode batch. #running-req: {num_running_reqs}, {token_usage_msg}"

if self.spec_algorithm.is_none():
spec_accept_length = 0
@@ -208,41 +230,66 @@ class SchedulerMetricsMixin:
self.cum_spec_accept_count += self.spec_num_total_forward_ct
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
msg += f"accept len: {spec_accept_length:.2f}, "
cache_hit_rate = 0.0

if self.disaggregation_mode == DisaggregationMode.DECODE:
msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
msg += f"#prealloc-req: {len(self.disagg_decode_prealloc_queue.queue)}, "
msg += f"#transfer-req: {len(self.disagg_decode_transfer_queue.queue)}, "
msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "

msg += (
f"{'cpu graph' if self.device == 'cpu' else 'cuda graph'}: {can_run_cuda_graph}, "
f"{'cuda graph' if self.device == 'cuda' else 'cpu graph'}: {can_run_cuda_graph}, "
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}, "
)

logger.info(msg)
if self.enable_metrics:
# Basics
self.stats.num_running_reqs = num_running_reqs
self.stats.num_running_reqs_offline_batch = num_running_reqs_offline_batch
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(token_usage, 2)
self.stats.cache_hit_rate = 0.0
self.stats.token_usage = token_usage
if self.is_hybrid:
self.stats.swa_token_usage = swa_token_usage
self.stats.gen_throughput = self.last_gen_throughput
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
self.stats.cache_hit_rate = cache_hit_rate
self.stats.spec_accept_length = spec_accept_length
self.stats.total_retracted_reqs = self.total_retracted_reqs
self.stats.avg_request_queue_latency = 0.0
if self.disaggregation_mode == DisaggregationMode.DECODE:

# Retract
self.stats.num_retracted_reqs = self.num_retracted_reqs
self.stats.num_paused_reqs = self.num_paused_reqs
self.num_retracted_reqs = self.num_paused_reqs = 0

# PD disaggregation
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.stats.num_prefill_prealloc_queue_reqs = len(
self.disagg_prefill_bootstrap_queue.queue
)
self.stats.num_prefill_inflight_queue_reqs = len(
self.disagg_prefill_inflight_queue
)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.stats.num_decode_prealloc_queue_reqs = len(
self.disagg_decode_prealloc_queue.queue
)
self.stats.num_decode_transfer_queue_reqs = len(
self.disagg_decode_transfer_queue.queue
)

# Others
self.calculate_utilization()
self.metrics_collector.log_stats(self.stats)
self._emit_kv_metrics()
self._publish_kv_events()

def _emit_kv_metrics(self: Scheduler):
if not self.enable_kv_cache_events:
return

kv_metrics = KvMetrics()
kv_metrics.request_active_slots = self.stats.num_running_reqs
kv_metrics.request_total_slots = self.max_running_requests
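log_prefill_stats and log_decode_stats both funnel into metrics_collector.log_stats, which writes labeled gauges. A hedged, runnable sketch of that gauge pattern using prometheus_client directly (hypothetical metric name, deliberately different from sglang's real ones):

from prometheus_client import Gauge

labels = {"model_name": "demo-model", "engine_type": "unified"}
num_queue_reqs = Gauge(
    "demo_num_queue_reqs",  # hypothetical name for illustration
    "The number of requests in the waiting queue.",
    labelnames=labels.keys(),
)
num_queue_reqs.labels(**labels).set(3)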
@@ -259,11 +306,13 @@ class SchedulerMetricsMixin:
self.send_metrics_from_scheduler.send_pyobj(kv_metrics)

def _publish_kv_events(self: Scheduler):
if self.enable_kv_cache_events:
events = self.tree_cache.take_events()
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
if not self.enable_kv_cache_events:
return

events = self.tree_cache.take_events()
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)

def maybe_update_dp_balance_data(
self: Scheduler, recv_req: TokenizedGenerateReqInput
@@ -349,3 +398,17 @@ class SchedulerMetricsMixin:
# 2. Atomically write local_tokens and onfly into shm under the mutex
meta.set_shared_onfly_info(onfly_list)
meta.set_shared_local_tokens(local_tokens)

def calculate_utilization(self):
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.stats.utilization = -1
else:
if (
self.stats.max_running_requests_under_SLO is not None
and self.stats.max_running_requests_under_SLO > 0
):
self.stats.utilization = max(
self.stats.num_running_reqs
/ self.stats.max_running_requests_under_SLO,
self.stats.token_usage / 0.9,
)

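The _publish_kv_events change above is a pure guard-clause refactor: identical behavior, one nesting level less. The shape, reduced to its essentials:

def publish_kv_events(enabled: bool, take_events, publish) -> None:
    # Equivalent to: if enabled: events = take_events(); if events: publish(events)
    if not enabled:
        return
    events = take_events()
    if events:
        publish(events)

publish_kv_events(True, lambda: ["evt"], print)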
@@ -91,7 +91,7 @@ class SchedulerOutputProcessorMixin:

if req.finished():
self.tree_cache.cache_finished_req(req)
req.time_stats.completion_time = time.time()
req.time_stats.completion_time = time.perf_counter()
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
# This updates radix so others can match
self.tree_cache.cache_unfinished_req(req)
@@ -257,7 +257,7 @@ class SchedulerOutputProcessorMixin:
else:
self.tree_cache.cache_finished_req(req)

req.time_stats.completion_time = time.time()
req.time_stats.completion_time = time.perf_counter()

if req.return_logprob and batch.spec_algorithm.is_none():
# speculative worker handles logprob in speculative decoding
@@ -707,6 +707,7 @@ class SchedulerOutputProcessorMixin:
and self.tp_rank == 0
and self.server_args.enable_request_time_stats_logging
):
print(f"{req.finished_reason=}")
req.log_time_stats()

# Send to detokenizer

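Switching completion_time from time.time() to time.perf_counter() matters because every other TimeStats stamp already uses perf_counter, and durations must come from a single monotonic clock. A quick demonstration of why:

import time

t0 = time.perf_counter()
time.sleep(0.01)
t1 = time.perf_counter()
assert t1 > t0  # perf_counter is monotonic, so intervals are always >= 0
# time.time() is wall-clock time: NTP steps or manual clock changes can move it
# backwards, so a duration computed by mixing it with perf_counter stamps
# (or even with other time.time() stamps) can come out negative.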
@@ -5,6 +5,7 @@ import copy
import logging
import os
import time
import uuid
from collections import deque
from typing import (
TYPE_CHECKING,
@@ -24,6 +25,7 @@ import zmq
from sglang.srt.managers.io_struct import (
ClearHiCacheReqInput,
ClearHiCacheReqOutput,
CloseSessionReqInput,
DestroyWeightsUpdateGroupReqInput,
DestroyWeightsUpdateGroupReqOutput,
ExpertDistributionReq,
@@ -44,6 +46,7 @@ from sglang.srt.managers.io_struct import (
LoadLoRAAdapterReqOutput,
LoRAUpdateResult,
MultiTokenizerWrapper,
OpenSessionReqInput,
ProfileReq,
ProfileReqOutput,
ProfileReqType,
@@ -588,3 +591,81 @@ class TokenizerCommunicatorMixin:
async def get_load(self: TokenizerManager) -> List[GetLoadReqOutput]:
req = GetLoadReqInput()
return await self.get_load_communicator(req)

async def open_session(
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
):
self.auto_create_handle_loop()

if obj.session_id is None:
obj.session_id = uuid.uuid4().hex
elif obj.session_id in self.session_futures:
return None

if self.server_args.tokenizer_worker_num > 1:
obj = MultiTokenizerWrapper(self.worker_id, obj)
self.send_to_scheduler.send_pyobj(obj)

self.session_futures[obj.session_id] = asyncio.Future()
session_id = await self.session_futures[obj.session_id]
del self.session_futures[obj.session_id]
return session_id

async def close_session(
self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
):
await self.send_to_scheduler.send_pyobj(obj)

def get_log_request_metadata(self):
max_length = None
skip_names = None
out_skip_names = None
if self.log_requests:
if self.log_requests_level == 0:
max_length = 1 << 30
skip_names = set(
[
"text",
"input_ids",
"input_embeds",
"image_data",
"audio_data",
"lora_path",
"sampling_params",
]
)
out_skip_names = set(
[
"text",
"output_ids",
"embedding",
]
)
elif self.log_requests_level == 1:
max_length = 1 << 30
skip_names = set(
[
"text",
"input_ids",
"input_embeds",
"image_data",
"audio_data",
"lora_path",
]
)
out_skip_names = set(
[
"text",
"output_ids",
"embedding",
]
)
elif self.log_requests_level == 2:
max_length = 2048
elif self.log_requests_level == 3:
max_length = 1 << 30
else:
raise ValueError(
f"Invalid --log-requests-level: {self.log_requests_level=}"
)
return max_length, skip_names, out_skip_names

@@ -164,6 +164,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
else None
)
self.crash_dump_folder = server_args.crash_dump_folder
self.enable_trace = server_args.enable_trace

# Read model args
self.model_path = server_args.model_path
@@ -381,23 +382,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
# If it's a single value, add worker_id prefix
obj.rid = f"{self.worker_id}_{obj.rid}"

if obj.is_single:
bootstrap_room = (
obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
)
trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
else:
for i in range(len(obj.rid)):
bootstrap_room = (
obj.bootstrap_room[i]
if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
else None
)
trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
trace_slice_start(
"", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
)
if self.enable_trace:
self._trace_request_start(obj, created_time)

if self.log_requests:
max_length, skip_names, _ = self.log_request_metadata
@@ -1055,7 +1041,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
req = AbortReq(rid, abort_all)
self.send_to_scheduler.send_pyobj(req)
if self.enable_metrics:
self.metrics_collector.observe_one_aborted_request()
# TODO: also use custom_labels from the request
self.metrics_collector.observe_one_aborted_request(
self.metrics_collector.labels
)

async def pause_generation(self):
async with self.is_pause_cond:
@@ -1117,84 +1106,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
all_paused_requests = [r.num_paused_requests for r in result]
return all_success, all_message, all_paused_requests

async def open_session(
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
):
self.auto_create_handle_loop()

if obj.session_id is None:
obj.session_id = uuid.uuid4().hex
elif obj.session_id in self.session_futures:
return None

if self.server_args.tokenizer_worker_num > 1:
obj = MultiTokenizerWrapper(self.worker_id, obj)
self.send_to_scheduler.send_pyobj(obj)

self.session_futures[obj.session_id] = asyncio.Future()
session_id = await self.session_futures[obj.session_id]
del self.session_futures[obj.session_id]
return session_id

async def close_session(
self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
):
await self.send_to_scheduler.send_pyobj(obj)

def get_log_request_metadata(self):
max_length = None
skip_names = None
out_skip_names = None
if self.log_requests:
if self.log_requests_level == 0:
max_length = 1 << 30
skip_names = set(
[
"text",
"input_ids",
"input_embeds",
"image_data",
"audio_data",
"lora_path",
"sampling_params",
]
)
out_skip_names = set(
[
"text",
"output_ids",
"embedding",
]
)
elif self.log_requests_level == 1:
max_length = 1 << 30
skip_names = set(
[
"text",
"input_ids",
"input_embeds",
"image_data",
"audio_data",
"lora_path",
]
)
out_skip_names = set(
[
"text",
"output_ids",
"embedding",
]
)
elif self.log_requests_level == 2:
max_length = 2048
elif self.log_requests_level == 3:
max_length = 1 << 30
else:
raise ValueError(
f"Invalid --log-requests-level: {self.log_requests_level=}"
)
return max_length, skip_names, out_skip_names

def configure_logging(self, obj: ConfigureLoggingReq):
if obj.log_requests is not None:
self.log_requests = obj.log_requests
@@ -1353,12 +1264,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
# Drain requests
while True:
remain_num_req = len(self.rid_to_state)
remaining_rids = list(self.rid_to_state.keys())

if self.server_status == ServerStatus.UnHealthy:
# if health check failed, we should exit immediately
logger.error(
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
remain_num_req,
"Signal SIGTERM received while health check failed. Force exiting."
)
self.dump_requests_before_crash()
break
@@ -1366,13 +1277,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
# if force shutdown flag set, exit immediately
logger.error(
"Signal SIGTERM received while force shutdown flag set. Force exiting... remaining number of requests: %d",
remain_num_req,
"Signal SIGTERM received while force shutdown flag set. Force exiting."
)
break

logger.info(
f"Gracefully exiting... remaining number of requests {remain_num_req}"
f"Gracefully exiting... Remaining number of requests {remain_num_req}. Remaining requests {remaining_rids=}."
)
if remain_num_req > 0:
await asyncio.sleep(5)
@@ -1888,6 +1798,29 @@ class TokenizerManager(TokenizerCommunicatorMixin):
load_udpate_req = WatchLoadUpdateReq(loads=loads)
self.send_to_scheduler.send_pyobj(load_udpate_req)

def _trace_request_start(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
created_time: Optional[float] = None,
):
if obj.is_single:
bootstrap_room = (
obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
)
trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
else:
for i in range(len(obj.rid)):
bootstrap_room = (
obj.bootstrap_room[i]
if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
else None
)
trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
trace_slice_start(
"", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
)


class ServerStatus(Enum):
Up = "Up"
@@ -1933,7 +1866,7 @@ class SignalHandler:

def running_phase_sigquit_handler(self, signum=None, frame=None):
logger.error(
"Received sigquit from a child process. It usually means the child failed."
f"SIGQUIT received. {signum=}, {frame=}. It usually means one child failed."
)
self.tokenizer_manager.dump_requests_before_crash()
kill_process_tree(os.getpid())

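_trace_request_start above normalizes the single-rid and batched-rid cases before emitting trace starts with nanosecond timestamps. A condensed sketch of that normalization (the emit callback stands in for trace_req_start):

import time

def _trace_request_start_sketch(rids, bootstrap_rooms, created_time: float, emit) -> None:
    ts = int(created_time * 1e9)  # trace timestamps are in nanoseconds
    if isinstance(rids, str):  # single request
        emit(rids, bootstrap_rooms, ts)
    else:  # batched request: one trace start per rid
        for i, rid in enumerate(rids):
            room = bootstrap_rooms[i] if bootstrap_rooms else None
            emit(rid, room, ts)

_trace_request_start_sketch(
    "req-0", None, time.time(), lambda rid, room, ts: print(rid, room, ts)
)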
@@ -14,9 +14,9 @@
"""Utilities for Prometheus Metrics Collection."""
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Union

from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.metrics.utils import exponential_buckets, generate_buckets
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var
@@ -34,6 +34,7 @@ class TimeStats:
Decode: prealloc_queue -> transfer_queue -> wait_queue -> forward -> completion
"""

disagg_mode: DisaggregationMode = DisaggregationMode.NULL
lb_entry_time: float = 0.0
wait_queue_entry_time: float = 0.0
forward_entry_time: float = 0.0
@@ -43,20 +44,11 @@ class TimeStats:
decode_prealloc_queue_entry_time: float = 0.0
decode_transfer_queue_entry_time: float = 0.0

class RequestType(Enum):
UNIFIED = "unified"
PREFILL = "prefill"
DECODE = "decode"
INVALID = "invalid"

def get_queueing_time(self) -> float:
return self.forward_entry_time - self.wait_queue_entry_time

def __str__(self) -> str:
# if unified
_type = self.get_type()

if _type == self.RequestType.UNIFIED:
def convert_to_duration(self) -> str:
if self.disagg_mode == DisaggregationMode.NULL:
queue_duration = self.forward_entry_time - self.wait_queue_entry_time
forward_duration = self.completion_time - self.forward_entry_time

@@ -65,30 +57,28 @@ class TimeStats:
queue_duration >= 0 and forward_duration >= 0
), f"queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"

return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time}"
elif _type == self.RequestType.PREFILL:
return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time:.3f}"
elif self.disagg_mode == DisaggregationMode.PREFILL:
bootstrap_duration = (
self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time
)

queue_duration = self.forward_entry_time - self.wait_queue_entry_time

forward_duration = self.completion_time - self.forward_entry_time

if SGLANG_TEST_REQUEST_TIME_STATS:
assert (
bootstrap_duration >= 0
and queue_duration >= 0
and forward_duration >= 0
), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time}"
# if decode
elif _type == self.RequestType.DECODE:
if self.wait_queue_entry_time > 0:
assert (
bootstrap_duration >= 0
and queue_duration >= 0
and forward_duration >= 0
), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"

return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time:.3f}"
elif self.disagg_mode == DisaggregationMode.DECODE:
prealloc_duration = (
self.decode_transfer_queue_entry_time
- self.decode_prealloc_queue_entry_time
)

transfer_duration = (
self.wait_queue_entry_time - self.decode_transfer_queue_entry_time
)
@@ -96,42 +86,30 @@ class TimeStats:
forward_duration = self.completion_time - self.forward_entry_time

if SGLANG_TEST_REQUEST_TIME_STATS:
assert (
prealloc_duration >= 0
and transfer_duration >= 0
and queue_duration >= 0
and forward_duration >= 0
), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
if self.wait_queue_entry_time > 0:
assert (
prealloc_duration >= 0
and transfer_duration >= 0
and queue_duration >= 0
and forward_duration >= 0
), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0. {self=}"

return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time}"
return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time:.3f}"
else:
return "Invalid Time Stats"
return "Unknown Time Stats"

def format_duration(self, duration: float) -> str:
return f"{duration * 1e3:.2f}ms"

def get_type(self) -> RequestType:
"""Determine the type of request based on timestamp values."""
if (
self.prefill_bootstrap_queue_entry_time == 0.0
and self.prefill_transfer_queue_entry_time == 0.0
and self.decode_prealloc_queue_entry_time == 0.0
and self.decode_transfer_queue_entry_time == 0.0
):
return self.RequestType.UNIFIED
elif (
self.prefill_bootstrap_queue_entry_time > 0.0
and self.prefill_transfer_queue_entry_time > 0.0
):
return self.RequestType.PREFILL
elif (
self.decode_prealloc_queue_entry_time > 0.0
and self.decode_transfer_queue_entry_time > 0.0
and self.wait_queue_entry_time > 0.0
):
return self.RequestType.DECODE
def disagg_mode_str(self) -> str:
if self.disagg_mode == DisaggregationMode.NULL:
return "unified"
elif self.disagg_mode == DisaggregationMode.DECODE:
return "decode"
elif self.disagg_mode == DisaggregationMode.PREFILL:
return "prefill"
else:
return self.RequestType.INVALID
return "unknown"


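With disagg_mode stored on TimeStats, convert_to_duration no longer has to infer the request type from which timestamps happen to be non-zero. The duration formatting itself is simple arithmetic; a self-contained illustration with fabricated timestamps and a standalone copy of format_duration:

import time

def format_duration(duration: float) -> str:
    return f"{duration * 1e3:.2f}ms"

wait_queue_entry_time = time.perf_counter()
forward_entry_time = wait_queue_entry_time + 0.050  # pretend scheduling took 50 ms
completion_time = forward_entry_time + 0.300        # pretend forward took 300 ms
print(
    f"queue_duration={format_duration(forward_entry_time - wait_queue_entry_time)}, "
    f"forward_duration={format_duration(completion_time - forward_entry_time)}"
)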
@dataclass
|
||||
@@ -145,12 +123,15 @@ class SchedulerStats:
|
||||
num_queue_reqs: int = 0
|
||||
num_grammar_queue_reqs: int = 0
|
||||
num_running_reqs_offline_batch: int = 0
|
||||
avg_request_queue_latency: float = 0.0
|
||||
cache_hit_rate: float = 0.0
|
||||
|
||||
# Speculative decoding
|
||||
spec_accept_length: float = 0.0
|
||||
|
||||
# Retract
|
||||
num_retracted_reqs: int = 0
|
||||
num_paused_reqs: int = 0
|
||||
|
||||
# PD disaggregation
|
||||
num_prefill_prealloc_queue_reqs: int = 0
|
||||
num_prefill_inflight_queue_reqs: int = 0
|
||||
@@ -159,11 +140,6 @@ class SchedulerStats:
|
||||
kv_transfer_speed_gb_s: float = 0.0
|
||||
kv_transfer_latency_ms: float = 0.0
|
||||
|
||||
# Retract
|
||||
total_retracted_reqs: int = 0
|
||||
num_retracted_reqs: int = 0
|
||||
num_paused_reqs: int = 0
|
||||
|
||||
# Utilization
|
||||
utilization: float = 0.0
|
||||
max_running_requests_under_SLO: Optional[int] = None
|
||||
@@ -230,12 +206,6 @@ class SchedulerMetricsCollector:
            labelnames=labels.keys(),
            multiprocess_mode="mostrecent",
        )
        self.avg_request_queue_latency = Gauge(
            name="sglang:avg_request_queue_latency",
            documentation="The average request queue latency for the last batch of requests in seconds.",
            labelnames=labels.keys(),
            multiprocess_mode="mostrecent",
        )
        self.cache_hit_rate = Gauge(
            name="sglang:cache_hit_rate",
            documentation="The prefix cache hit rate.",
@@ -251,6 +221,18 @@ class SchedulerMetricsCollector:
            multiprocess_mode="mostrecent",
        )

        # Retract
        self.num_retracted_reqs = Gauge(
            name="sglang:num_retracted_reqs",
            documentation="The number of retracted requests.",
            labelnames=labels.keys(),
        )
        self.num_paused_reqs = Gauge(
            name="sglang:num_paused_reqs",
            documentation="The number of paused requests by async weight sync.",
            labelnames=labels.keys(),
        )

        # PD disaggregation
        self.num_prefill_prealloc_queue_reqs = Gauge(
            name="sglang:num_prefill_prealloc_queue_reqs",
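Note: every gauge here follows the same prometheus_client pattern: declare once with labelnames, write per-label values with set(). A minimal standalone sketch with an invented metric name (multiprocess_mode="mostrecent" makes the last write win when samples from multiple worker processes are merged):

    from prometheus_client import Gauge

    demo_gauge = Gauge(
        name="sglang:demo_gauge",
        documentation="Illustrative gauge following the pattern above.",
        labelnames=["model_name"],
        multiprocess_mode="mostrecent",  # only takes effect under PROMETHEUS_MULTIPROC_DIR
    )
    demo_gauge.labels(model_name="llama").set(3)
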
@@ -299,24 +281,6 @@ class SchedulerMetricsCollector:
            multiprocess_mode="mostrecent",
        )

        # Retract
        self.total_retracted_reqs = Gauge(
            name="sglang:total_retracted_reqs",
            documentation="The total number of retracted requests due to kvcache full.",
            labelnames=labels.keys(),
            multiprocess_mode="mostrecent",
        )
        self.num_retracted_reqs = Gauge(
            name="sglang:num_retracted_reqs",
            documentation="The number of retracted requests.",
            labelnames=labels.keys(),
        )
        self.num_paused_reqs = Gauge(
            name="sglang:num_paused_reqs",
            documentation="The number of paused requests by async weight sync.",
            labelnames=labels.keys(),
        )

        # Utilization
        self.utilization = Gauge(
            name="sglang:utilization",
@@ -347,7 +311,7 @@ class SchedulerMetricsCollector:

        # Additional queueing time histogram
        self.queue_time = Histogram(
            name="sglang:queue_time_s",
            name="sglang:queue_time_seconds",
            documentation="Histogram of queueing time in seconds.",
            labelnames=labels.keys(),
            buckets=[
@@ -513,8 +477,8 @@ class SchedulerMetricsCollector:
            buckets=tree_traversal_time_buckets,
        )

        self.request_latency_seconds = Histogram(
            name="sglang:request_latency_seconds",
        self.per_stage_req_latency_seconds = Histogram(
            name="sglang:per_stage_req_latency_seconds",
            documentation="The latency of each stage of requests.",
            # captures latency in range [1ms - ~1191s]
            buckets=exponential_buckets(start=0.001, width=1.62, length=30),
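Note: the bucket comment checks out arithmetically. Assuming exponential_buckets generates start * width**i for i in range(length) (the helper's definition is outside this hunk), the top boundary is 0.001 * 1.62**29 ≈ 1191 s:

    def exponential_buckets(start: float, width: float, length: int) -> list:
        # Assumed shape of the helper referenced above.
        return [start * (width**i) for i in range(length)]

    buckets = exponential_buckets(start=0.001, width=1.62, length=30)
    print(f"{buckets[0] * 1e3:.0f}ms .. {buckets[-1]:.0f}s")  # 1ms .. 1191s
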
@@ -525,7 +489,7 @@ class SchedulerMetricsCollector:
        # Convenience function for logging to gauge.
        gauge.labels(**self.labels).set(data)

    def log_histogram(self, histogram, data: Union[int, float]) -> None:
    def _log_histogram(self, histogram, data: Union[int, float]) -> None:
        histogram.labels(**self.labels).observe(data)

    def increment_bootstrap_failed_reqs(self) -> None:
@@ -534,9 +498,12 @@ class SchedulerMetricsCollector:
    def increment_transfer_failed_reqs(self) -> None:
        self.num_transfer_failed_reqs.labels(**self.labels).inc(1)

    def observe_request_latency_seconds(self, stage: str, latency: float) -> None:
    def observe_per_stage_req_latency(self, stage: str, latency: float) -> None:
        labels_with_stage = {**self.labels, "stage": stage}
        self.request_latency_seconds.labels(**labels_with_stage).observe(latency)
        self.per_stage_req_latency_seconds.labels(**labels_with_stage).observe(latency)

    def observe_queue_time(self, latency: float) -> None:
        self._log_histogram(self.queue_time, latency)

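Note: observe_per_stage_req_latency folds a dynamic stage label into the collector's fixed labels at call time. The same merge, standalone (metric name and label values invented):

    from prometheus_client import Histogram

    per_stage = Histogram(
        name="demo:per_stage_req_latency_seconds",
        documentation="Illustrative per-stage latency histogram.",
        labelnames=["model_name", "stage"],
    )

    base_labels = {"model_name": "llama"}  # collector-wide labels
    # {**base_labels, "stage": ...} mirrors labels_with_stage above.
    per_stage.labels(**{**base_labels, "stage": "decode_bootstrap"}).observe(0.042)
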
    def log_stats(self, stats: SchedulerStats) -> None:
        self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
@@ -550,7 +517,6 @@ class SchedulerMetricsCollector:
            self.num_running_reqs_offline_batch, stats.num_running_reqs_offline_batch
        )
        self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
        self._log_gauge(self.avg_request_queue_latency, stats.avg_request_queue_latency)

        # Speculative decoding
        self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
@@ -572,7 +538,6 @@ class SchedulerMetricsCollector:
        self._log_gauge(self.kv_transfer_latency_ms, stats.kv_transfer_latency_ms)

        # Retract
        self._log_gauge(self.total_retracted_reqs, stats.total_retracted_reqs)
        self._log_gauge(self.num_retracted_reqs, stats.num_retracted_reqs)
        self._log_gauge(self.num_paused_reqs, stats.num_paused_reqs)

@@ -596,19 +561,19 @@ class SchedulerMetricsCollector:
    def log_grammar_stats(self, grammar_stats) -> None:
        # Duck-typed GrammarStats to avoid cross-package dependency
        if getattr(grammar_stats, "compilation_time", None) is not None:
            self.log_histogram(
            self._log_histogram(
                self.grammar_compilation_time, grammar_stats.compilation_time
            )
        if getattr(grammar_stats, "schema_count", None) is not None:
            self.log_histogram(self.grammar_schema_count, grammar_stats.schema_count)
            self._log_histogram(self.grammar_schema_count, grammar_stats.schema_count)
        if getattr(grammar_stats, "ebnf_size", None) is not None:
            self.log_histogram(self.grammar_ebnf_size, grammar_stats.ebnf_size)
            self._log_histogram(self.grammar_ebnf_size, grammar_stats.ebnf_size)
        tree_times = getattr(grammar_stats, "tree_traversal_time", None)
        if tree_times:
            max_time = max(tree_times)
            avg_time = sum(tree_times) / len(tree_times)
            self.log_histogram(self.grammar_tree_traversal_time_max, max_time)
            self.log_histogram(self.grammar_tree_traversal_time_avg, avg_time)
            self._log_histogram(self.grammar_tree_traversal_time_max, max_time)
            self._log_histogram(self.grammar_tree_traversal_time_avg, avg_time)
        if getattr(grammar_stats, "is_cache_hit", False):
            self.num_grammar_cache_hit.labels(**self.labels).inc(1)
        if getattr(grammar_stats, "is_grammar_aborted", False):
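Note: because log_grammar_stats only uses getattr, any object carrying these attribute names works; no import from the grammar package is needed. A hypothetical stub that exercises every branch:

    class FakeGrammarStats:
        # Attribute names taken from the getattr calls above; values invented.
        compilation_time = 0.25             # seconds
        schema_count = 3
        ebnf_size = 1024                    # bytes
        tree_traversal_time = [0.01, 0.03]  # per-traversal seconds
        is_cache_hit = True
        is_grammar_aborted = False

    # collector is assumed to be a SchedulerMetricsCollector instance.
    collector.log_grammar_stats(FakeGrammarStats())
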
@@ -714,7 +679,7 @@ class TokenizerMetricsCollector:
        )

        self.num_aborted_requests_total = Counter(
            name="sglang:num_aborted_requests",
            name="sglang:num_aborted_requests_total",
            documentation="Number of requests aborted.",
            labelnames=labels.keys(),
        )
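Note: the rename follows the Prometheus convention that counter sample names end in _total; declaring the suffix explicitly keeps the attribute name, the declared name, and the exposed name aligned. The pattern in isolation (metric name invented):

    from prometheus_client import Counter

    aborted = Counter(
        name="demo:num_aborted_requests_total",
        documentation="Number of requests aborted.",
        labelnames=["model_name"],
    )
    aborted.labels(model_name="llama").inc(1)
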
@@ -801,7 +766,7 @@ class TokenizerMetricsCollector:
            buckets=bucket_time_to_first_token,
        )

        self.histogram_inter_token_latency_seconds = Histogram(
        self.histogram_inter_token_latency = Histogram(
            name="sglang:inter_token_latency_seconds",
            documentation="Histogram of inter-token latency in seconds.",
            labelnames=labels.keys(),
@@ -815,14 +780,6 @@ class TokenizerMetricsCollector:
            buckets=bucket_e2e_request_latency,
        )

        # Offline batch specific TTFB histogram
        self.histogram_time_to_first_token_offline_batch = Histogram(
            name="sglang:time_to_first_token_seconds_offline_batch",
            documentation="Histogram of time to first token in seconds for offline batch requests.",
            labelnames=labels.keys(),
            buckets=bucket_time_to_first_token,
        )

    def observe_one_finished_request(
        self,
        labels: Dict[str, str],
@@ -846,15 +803,8 @@ class TokenizerMetricsCollector:
            float(generation_tokens)
        )

    def observe_time_to_first_token(
        self, labels: Dict[str, str], value: float, type: str = ""
    ):
        if type == "batch":
            self.histogram_time_to_first_token_offline_batch.labels(**labels).observe(
                value
            )
        else:
            self.histogram_time_to_first_token.labels(**labels).observe(value)
    def observe_time_to_first_token(self, labels: Dict[str, str], value: float):
        self.histogram_time_to_first_token.labels(**labels).observe(value)

    def check_time_to_first_token_straggler(self, value: float) -> bool:
        his = self.histogram_time_to_first_token.labels(**self.labels)
@@ -876,7 +826,7 @@ class TokenizerMetricsCollector:

        # A faster version of the Histogram::observe which observes multiple values at the same time.
        # reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639
        his = self.histogram_inter_token_latency_seconds.labels(**labels)
        his = self.histogram_inter_token_latency.labels(**labels)
        his._sum.inc(internval)

        for i, bound in enumerate(his._upper_bounds):
@@ -884,8 +834,8 @@ class TokenizerMetricsCollector:
                his._buckets[i].inc(num_new_tokens)
                break

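Note: the batched observe above reaches into prometheus_client internals (_sum, _upper_bounds, and _buckets are private, so this is version-sensitive; the referenced v0.21.1 source has this layout). The idea: num_new_tokens observations that share one average inter-token latency collapse into a single _sum increment plus a single bucket increment. A hedged standalone sketch:

    def observe_many(his, total_latency: float, count: int) -> None:
        # Equivalent to `count` observe() calls of (total_latency / count) each.
        his._sum.inc(total_latency)        # one sum update instead of `count`
        avg = total_latency / count
        for i, bound in enumerate(his._upper_bounds):
            if avg <= bound:
                his._buckets[i].inc(count) # one bucket update instead of `count`
                break
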
    def observe_one_aborted_request(self):
        self.num_aborted_requests_total.labels(**self.labels).inc(1)
    def observe_one_aborted_request(self, labels: Dict[str, str]):
        self.num_aborted_requests_total.labels(**labels).inc(1)


@dataclass

@@ -15,7 +15,6 @@

from __future__ import annotations

import ctypes
import logging
import os
import random
@@ -23,7 +22,10 @@ import threading
import time
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional

if TYPE_CHECKING:
    from sglang.srt.managers.scheduler import Req

logger = logging.getLogger(__name__)
opentelemetry_imported = False
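Note: the TYPE_CHECKING guard lets trace_slice_batch annotate reqs: List[Req] without importing the scheduler at runtime, avoiding a circular import; with `from __future__ import annotations` the annotation is never evaluated. The general pattern:

    from __future__ import annotations

    from typing import TYPE_CHECKING, List

    if TYPE_CHECKING:
        # Seen by type checkers only, never executed at runtime.
        from heavy_or_circular_module import Req  # hypothetical module path

    def process(reqs: List[Req]) -> None:  # stays an unevaluated string at runtime
        for req in reqs:
            ...
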
@@ -407,9 +409,11 @@ def trace_slice_start(
    ts: Optional[int] = None,
    anonymous: bool = False,
):
    if not tracing_enabled:
        return

    rid = str(rid)
    if not tracing_enabled or rid not in reqs_context:
    if rid not in reqs_context:
        return

    pid = threading.get_native_id()
@@ -458,8 +462,11 @@
    auto_next_anon: bool = False,
    thread_finish_flag: bool = False,
):
    if not tracing_enabled:
        return

    rid = str(rid)
    if not tracing_enabled or rid not in reqs_context:
    if rid not in reqs_context:
        return

    pid = threading.get_native_id()
@@ -512,10 +519,13 @@ trace_slice = trace_slice_end

# Add event to the current slice on the same thread with the same rid.
def trace_event(name: str, rid: str, ts: Optional[int] = None):
    if not tracing_enabled or rid not in reqs_context:
    if not tracing_enabled:
        return

    rid = str(rid)
    if rid not in reqs_context:
        return

    pid = threading.get_native_id()
    if pid not in reqs_context[rid].threads_context:
        return
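Note: splitting the guard also fixes a lookup subtlety: the membership test now always runs on str(rid), so a caller passing a non-string rid (e.g. an int or a UUID object) hits the context dict correctly. A minimal illustration of the difference (dict contents invented):

    reqs_context = {"42": "ctx"}   # keys are stringified request ids

    rid = 42                       # a caller passing a raw int rid
    assert rid not in reqs_context # old combined check: missed, traced nothing
    rid = str(rid)                 # new order: normalize first...
    assert rid in reqs_context     # ...then the lookup succeeds
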
@@ -534,10 +544,13 @@ def trace_event(name: str, rid: str, ts: Optional[int] = None):

# Add attrs to the current slice on the same thread with the same rid.
def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]):
    if not tracing_enabled or rid not in reqs_context:
    if not tracing_enabled:
        return

    rid = str(rid)
    if rid not in reqs_context:
        return

    pid = threading.get_native_id()
    if pid not in reqs_context[rid].threads_context:
        return
@@ -550,3 +563,16 @@ def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]):

    slice_info = thread_context.cur_slice_stack[-1]
    slice_info.span.set_attributes(attrs)


def trace_slice_batch(
    name: str,
    reqs: List[Req],
):
    for req in reqs:
        trace_slice(
            name,
            req.rid,
            auto_next_anon=not req.finished(),
            thread_finish_flag=req.finished(),
        )
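
Note: trace_slice_batch closes the current slice for every request in one call; unfinished requests immediately open an anonymous follow-up slice (auto_next_anon), finished ones tear down their per-thread trace context (thread_finish_flag). A hedged usage sketch (the Req stub is invented; only rid and finished() are required):

    class FakeReq:
        def __init__(self, rid, done):
            self.rid = rid
            self._done = done

        def finished(self):
            return self._done

    batch = [FakeReq("a", done=False), FakeReq("b", done=True)]
    # Ends the "decode step" slice for both; "a" continues with an anonymous
    # slice, "b" finishes its trace thread.
    trace_slice_batch("decode step", batch)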