Fix metrics and request tracing (TimeStats) (#11123)
This commit is contained in:
@@ -14,18 +14,17 @@ classifiers = [
|
|||||||
"License :: OSI Approved :: Apache Software License",
|
"License :: OSI Approved :: Apache Software License",
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aiohttp",
|
|
||||||
"requests",
|
|
||||||
"tqdm",
|
|
||||||
"numpy",
|
|
||||||
"IPython",
|
"IPython",
|
||||||
"setproctitle",
|
"aiohttp",
|
||||||
|
"anthropic>=0.20.0",
|
||||||
"blobfile==3.0.0",
|
"blobfile==3.0.0",
|
||||||
"build",
|
"build",
|
||||||
"compressed-tensors",
|
"compressed-tensors",
|
||||||
|
"cuda-python",
|
||||||
"datasets",
|
"datasets",
|
||||||
"einops",
|
"einops",
|
||||||
"fastapi",
|
"fastapi",
|
||||||
|
"flashinfer_python==0.4.0rc3",
|
||||||
"hf_transfer",
|
"hf_transfer",
|
||||||
"huggingface_hub",
|
"huggingface_hub",
|
||||||
"interegular",
|
"interegular",
|
||||||
@@ -33,8 +32,10 @@ dependencies = [
|
|||||||
"modelscope",
|
"modelscope",
|
||||||
"msgspec",
|
"msgspec",
|
||||||
"ninja",
|
"ninja",
|
||||||
"openai==1.99.1",
|
"numpy",
|
||||||
|
"nvidia-cutlass-dsl==4.2.1",
|
||||||
"openai-harmony==0.0.4",
|
"openai-harmony==0.0.4",
|
||||||
|
"openai==1.99.1",
|
||||||
"orjson",
|
"orjson",
|
||||||
"outlines==0.1.11",
|
"outlines==0.1.11",
|
||||||
"packaging",
|
"packaging",
|
||||||
@@ -42,32 +43,30 @@ dependencies = [
|
|||||||
"pillow",
|
"pillow",
|
||||||
"prometheus-client>=0.20.0",
|
"prometheus-client>=0.20.0",
|
||||||
"psutil",
|
"psutil",
|
||||||
|
"py-spy",
|
||||||
"pybase64",
|
"pybase64",
|
||||||
"pydantic",
|
"pydantic",
|
||||||
"pynvml",
|
"pynvml",
|
||||||
"python-multipart",
|
"python-multipart",
|
||||||
"pyzmq>=25.1.2",
|
"pyzmq>=25.1.2",
|
||||||
|
"requests",
|
||||||
"scipy",
|
"scipy",
|
||||||
"sentencepiece",
|
"sentencepiece",
|
||||||
|
"setproctitle",
|
||||||
|
"sgl-kernel==0.3.13",
|
||||||
"soundfile==0.13.1",
|
"soundfile==0.13.1",
|
||||||
"timm==1.0.16",
|
|
||||||
"tiktoken",
|
"tiktoken",
|
||||||
|
"timm==1.0.16",
|
||||||
|
"torch==2.8.0",
|
||||||
|
"torch_memory_saver==0.0.8",
|
||||||
"torchao==0.9.0",
|
"torchao==0.9.0",
|
||||||
|
"torchaudio==2.8.0",
|
||||||
|
"torchvision",
|
||||||
|
"tqdm",
|
||||||
"transformers==4.56.1",
|
"transformers==4.56.1",
|
||||||
"uvicorn",
|
"uvicorn",
|
||||||
"uvloop",
|
"uvloop",
|
||||||
"xgrammar==0.1.24",
|
"xgrammar==0.1.24"
|
||||||
"sgl-kernel==0.3.13",
|
|
||||||
"torch==2.8.0",
|
|
||||||
"torchaudio==2.8.0",
|
|
||||||
"torchvision",
|
|
||||||
"cuda-python",
|
|
||||||
"flashinfer_python==0.4.0rc3",
|
|
||||||
"openai==1.99.1",
|
|
||||||
"tiktoken",
|
|
||||||
"anthropic>=0.20.0",
|
|
||||||
"torch_memory_saver==0.0.8",
|
|
||||||
"nvidia-cutlass-dsl==4.2.1",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
@@ -79,15 +78,15 @@ test = [
|
|||||||
"matplotlib",
|
"matplotlib",
|
||||||
"pandas",
|
"pandas",
|
||||||
"peft",
|
"peft",
|
||||||
"sentence_transformers",
|
|
||||||
"pytest",
|
"pytest",
|
||||||
|
"sentence_transformers",
|
||||||
"tabulate",
|
"tabulate",
|
||||||
]
|
]
|
||||||
tracing = [
|
tracing = [
|
||||||
"opentelemetry-sdk",
|
|
||||||
"opentelemetry-api",
|
"opentelemetry-api",
|
||||||
"opentelemetry-exporter-otlp",
|
"opentelemetry-exporter-otlp",
|
||||||
"opentelemetry-exporter-otlp-proto-grpc",
|
"opentelemetry-exporter-otlp-proto-grpc",
|
||||||
|
"opentelemetry-sdk",
|
||||||
]
|
]
|
||||||
all = ["sglang[test]", "sglang[decord]"]
|
all = ["sglang[test]", "sglang[decord]"]
|
||||||
blackwell = ["sglang[test]", "sglang[decord]"]
|
blackwell = ["sglang[test]", "sglang[decord]"]
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
@@ -422,9 +423,13 @@ class DecodePreallocQueue:
|
|||||||
kv_indices, self.token_to_kv_pool_allocator.page_size
|
kv_indices, self.token_to_kv_pool_allocator.page_size
|
||||||
)
|
)
|
||||||
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
|
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
|
||||||
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
|
|
||||||
preallocated_reqs.append(decode_req)
|
preallocated_reqs.append(decode_req)
|
||||||
indices_to_remove.add(i)
|
indices_to_remove.add(i)
|
||||||
|
decode_req.req.time_stats.decode_transfer_queue_entry_time = (
|
||||||
|
time.perf_counter()
|
||||||
|
)
|
||||||
|
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
|
||||||
|
|
||||||
self.queue = [
|
self.queue = [
|
||||||
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
|
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
|
||||||
@@ -625,6 +630,7 @@ class DecodeTransferQueue:
|
|||||||
decode_req.req.output_topk_p = output_topk_p
|
decode_req.req.output_topk_p = output_topk_p
|
||||||
decode_req.req.output_topk_index = output_topk_index
|
decode_req.req.output_topk_index = output_topk_index
|
||||||
decode_req.req.hidden_states_tensor = output_hidden_states
|
decode_req.req.hidden_states_tensor = output_hidden_states
|
||||||
|
|
||||||
if decode_req.req.return_logprob:
|
if decode_req.req.return_logprob:
|
||||||
decode_req.req.output_token_logprobs_val.append(
|
decode_req.req.output_token_logprobs_val.append(
|
||||||
output_token_logprobs_val[0].item()
|
output_token_logprobs_val[0].item()
|
||||||
@@ -645,10 +651,17 @@ class DecodeTransferQueue:
|
|||||||
|
|
||||||
if hasattr(decode_req.kv_receiver, "clear"):
|
if hasattr(decode_req.kv_receiver, "clear"):
|
||||||
decode_req.kv_receiver.clear()
|
decode_req.kv_receiver.clear()
|
||||||
|
decode_req.kv_receiver = None
|
||||||
|
|
||||||
|
indices_to_remove.add(i)
|
||||||
|
decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter()
|
||||||
|
|
||||||
# special handling for sampling_params.max_new_tokens == 1
|
# special handling for sampling_params.max_new_tokens == 1
|
||||||
if decode_req.req.sampling_params.max_new_tokens == 1:
|
if decode_req.req.sampling_params.max_new_tokens == 1:
|
||||||
# finish immediately
|
# finish immediately
|
||||||
|
decode_req.req.time_stats.forward_entry_time = (
|
||||||
|
decode_req.req.time_stats.completion_time
|
||||||
|
) = time.perf_counter()
|
||||||
decode_req.req.check_finished()
|
decode_req.req.check_finished()
|
||||||
self.scheduler.stream_output(
|
self.scheduler.stream_output(
|
||||||
[decode_req.req], decode_req.req.return_logprob
|
[decode_req.req], decode_req.req.return_logprob
|
||||||
@@ -656,8 +669,6 @@ class DecodeTransferQueue:
|
|||||||
self.tree_cache.cache_finished_req(decode_req.req)
|
self.tree_cache.cache_finished_req(decode_req.req)
|
||||||
else:
|
else:
|
||||||
transferred_reqs.append(decode_req.req)
|
transferred_reqs.append(decode_req.req)
|
||||||
|
|
||||||
indices_to_remove.add(i)
|
|
||||||
elif poll in [
|
elif poll in [
|
||||||
KVPoll.Bootstrapping,
|
KVPoll.Bootstrapping,
|
||||||
KVPoll.WaitingForInput,
|
KVPoll.WaitingForInput,
|
||||||
@@ -877,6 +888,9 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
if len(can_run_list) == 0:
|
if len(can_run_list) == 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
for req in can_run_list:
|
||||||
|
req.time_stats.forward_entry_time = time.perf_counter()
|
||||||
|
|
||||||
# construct a schedule batch with those requests and mark as decode
|
# construct a schedule batch with those requests and mark as decode
|
||||||
new_batch = ScheduleBatch.init_new(
|
new_batch = ScheduleBatch.init_new(
|
||||||
can_run_list,
|
can_run_list,
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import TYPE_CHECKING, List, Optional, Type
|
from typing import TYPE_CHECKING, List, Optional, Type
|
||||||
@@ -263,9 +264,10 @@ class PrefillBootstrapQueue:
|
|||||||
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
|
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
|
||||||
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
|
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
|
||||||
|
|
||||||
req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
|
|
||||||
bootstrapped_reqs.append(req)
|
bootstrapped_reqs.append(req)
|
||||||
indices_to_remove.add(i)
|
indices_to_remove.add(i)
|
||||||
|
req.time_stats.wait_queue_entry_time = time.perf_counter()
|
||||||
|
req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
|
||||||
|
|
||||||
self.queue = [
|
self.queue = [
|
||||||
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
|
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
|
||||||
@@ -407,7 +409,6 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
for i, (req, next_token_id) in enumerate(
|
for i, (req, next_token_id) in enumerate(
|
||||||
zip(batch.reqs, next_token_ids, strict=True)
|
zip(batch.reqs, next_token_ids, strict=True)
|
||||||
):
|
):
|
||||||
req: Req
|
|
||||||
if req.is_chunked <= 0:
|
if req.is_chunked <= 0:
|
||||||
# There is no output_ids for prefill
|
# There is no output_ids for prefill
|
||||||
req.output_ids.append(next_token_id)
|
req.output_ids.append(next_token_id)
|
||||||
@@ -450,6 +451,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
)
|
)
|
||||||
logprob_pt += num_input_logprobs
|
logprob_pt += num_input_logprobs
|
||||||
self.send_kv_chunk(req, last_chunk=True)
|
self.send_kv_chunk(req, last_chunk=True)
|
||||||
|
req.time_stats.prefill_transfer_queue_entry_time = time.perf_counter()
|
||||||
|
|
||||||
if req.grammar is not None:
|
if req.grammar is not None:
|
||||||
# FIXME: this try-except block is for handling unexpected xgrammar issue.
|
# FIXME: this try-except block is for handling unexpected xgrammar issue.
|
||||||
@@ -547,6 +549,9 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
else:
|
else:
|
||||||
assert False, f"Unexpected polling state {poll=}"
|
assert False, f"Unexpected polling state {poll=}"
|
||||||
|
|
||||||
|
for req in done_reqs:
|
||||||
|
req.time_stats.completion_time = time.perf_counter()
|
||||||
|
|
||||||
# Stream requests which have finished transfer
|
# Stream requests which have finished transfer
|
||||||
self.stream_output(
|
self.stream_output(
|
||||||
done_reqs,
|
done_reqs,
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import random
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, List, Optional, Type, Union
|
from typing import TYPE_CHECKING, Optional, Type
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ import time
|
|||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
|
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -54,6 +54,7 @@ from sglang.srt.disaggregation.base import BaseKVSender
|
|||||||
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
|
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
|
||||||
ScheduleBatchDisaggregationDecodeMixin,
|
ScheduleBatchDisaggregationDecodeMixin,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
|
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
|
||||||
from sglang.srt.mem_cache.allocator import (
|
from sglang.srt.mem_cache.allocator import (
|
||||||
BaseTokenToKVPoolAllocator,
|
BaseTokenToKVPoolAllocator,
|
||||||
@@ -452,6 +453,7 @@ class Req:
|
|||||||
bootstrap_host: Optional[str] = None,
|
bootstrap_host: Optional[str] = None,
|
||||||
bootstrap_port: Optional[int] = None,
|
bootstrap_port: Optional[int] = None,
|
||||||
bootstrap_room: Optional[int] = None,
|
bootstrap_room: Optional[int] = None,
|
||||||
|
disagg_mode: Optional[DisaggregationMode] = None,
|
||||||
data_parallel_rank: Optional[int] = None,
|
data_parallel_rank: Optional[int] = None,
|
||||||
vocab_size: Optional[int] = None,
|
vocab_size: Optional[int] = None,
|
||||||
priority: Optional[int] = None,
|
priority: Optional[int] = None,
|
||||||
@@ -628,10 +630,8 @@ class Req:
|
|||||||
|
|
||||||
# For metrics
|
# For metrics
|
||||||
self.metrics_collector = metrics_collector
|
self.metrics_collector = metrics_collector
|
||||||
self.time_stats: TimeStats = TimeStats()
|
self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
|
||||||
self.has_log_time_stats: bool = False
|
self.has_log_time_stats: bool = False
|
||||||
self.queue_time_start = None
|
|
||||||
self.queue_time_end = None
|
|
||||||
self.last_tic = time.monotonic()
|
self.last_tic = time.monotonic()
|
||||||
|
|
||||||
# For disaggregation
|
# For disaggregation
|
||||||
@@ -668,9 +668,9 @@ class Req:
|
|||||||
def add_latency(self, stage: RequestStage):
|
def add_latency(self, stage: RequestStage):
|
||||||
if self.metrics_collector is None:
|
if self.metrics_collector is None:
|
||||||
return
|
return
|
||||||
assert stage.name in RequestStage.__members__, f"{stage=} is invalid"
|
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
self.metrics_collector.observe_request_latency_seconds(
|
self.metrics_collector.observe_per_stage_req_latency(
|
||||||
stage.value, now - self.last_tic
|
stage.value, now - self.last_tic
|
||||||
)
|
)
|
||||||
self.last_tic = now
|
self.last_tic = now
|
||||||
@@ -834,10 +834,10 @@ class Req:
|
|||||||
return
|
return
|
||||||
|
|
||||||
if self.bootstrap_room is not None:
|
if self.bootstrap_room is not None:
|
||||||
prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
|
prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
|
||||||
else:
|
else:
|
||||||
prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
|
prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
|
||||||
logger.info(f"{prefix}: {self.time_stats}")
|
logger.info(f"{prefix}: {self.time_stats.convert_to_duration()}")
|
||||||
self.has_log_time_stats = True
|
self.has_log_time_stats = True
|
||||||
|
|
||||||
def set_finish_with_abort(self, error_msg: str):
|
def set_finish_with_abort(self, error_msg: str):
|
||||||
@@ -1544,7 +1544,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
) / total_max_new_tokens
|
) / total_max_new_tokens
|
||||||
new_estimate_ratio = min(1.0, new_estimate_ratio)
|
new_estimate_ratio = min(1.0, new_estimate_ratio)
|
||||||
|
|
||||||
return retracted_reqs, new_estimate_ratio
|
return retracted_reqs, new_estimate_ratio, []
|
||||||
|
|
||||||
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
|
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
|
||||||
req = self.reqs[idx]
|
req = self.reqs[idx]
|
||||||
|
|||||||
@@ -276,9 +276,13 @@ class SchedulePolicy:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Sorts the waiting queue based on the request priority then received titmestamp."""
|
"""Sorts the waiting queue based on the request priority then received titmestamp."""
|
||||||
if schedule_low_priority_values_first:
|
if schedule_low_priority_values_first:
|
||||||
waiting_queue.sort(key=lambda x: (x.priority, x.queue_time_start))
|
waiting_queue.sort(
|
||||||
|
key=lambda x: (x.priority, x.time_stats.wait_queue_entry_time)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
waiting_queue.sort(key=lambda x: (-x.priority, x.queue_time_start))
|
waiting_queue.sort(
|
||||||
|
key=lambda x: (-x.priority, x.time_stats.wait_queue_entry_time)
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
|
def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
|
||||||
@@ -642,12 +646,12 @@ class PrefillAdder:
|
|||||||
if server_args.schedule_low_priority_values_first:
|
if server_args.schedule_low_priority_values_first:
|
||||||
sorted_running_reqs = sorted(
|
sorted_running_reqs = sorted(
|
||||||
self.running_batch.reqs,
|
self.running_batch.reqs,
|
||||||
key=lambda x: (-x.priority, -x.queue_time_start),
|
key=lambda x: (-x.priority, -x.time_stats.wait_queue_entry_time),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sorted_running_reqs = sorted(
|
sorted_running_reqs = sorted(
|
||||||
self.running_batch.reqs,
|
self.running_batch.reqs,
|
||||||
key=lambda x: (x.priority, -x.queue_time_start),
|
key=lambda x: (x.priority, -x.time_stats.wait_queue_entry_time),
|
||||||
)
|
)
|
||||||
preemptible_reqs = []
|
preemptible_reqs = []
|
||||||
min_tokens_to_remove = (
|
min_tokens_to_remove = (
|
||||||
|
|||||||
@@ -157,10 +157,9 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
|||||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
from sglang.srt.tracing.trace import (
|
from sglang.srt.tracing.trace import (
|
||||||
process_tracing_init,
|
process_tracing_init,
|
||||||
trace_event,
|
|
||||||
trace_set_proc_propagate_context,
|
trace_set_proc_propagate_context,
|
||||||
trace_set_thread_info,
|
trace_set_thread_info,
|
||||||
trace_slice,
|
trace_slice_batch,
|
||||||
trace_slice_end,
|
trace_slice_end,
|
||||||
trace_slice_start,
|
trace_slice_start,
|
||||||
)
|
)
|
||||||
@@ -263,6 +262,7 @@ class Scheduler(
|
|||||||
server_args.enable_metrics_for_all_schedulers
|
server_args.enable_metrics_for_all_schedulers
|
||||||
)
|
)
|
||||||
self.enable_kv_cache_events = server_args.kv_events_config and tp_rank == 0
|
self.enable_kv_cache_events = server_args.kv_events_config and tp_rank == 0
|
||||||
|
self.enable_trace = server_args.enable_trace
|
||||||
self.stream_interval = server_args.stream_interval
|
self.stream_interval = server_args.stream_interval
|
||||||
self.spec_algorithm = SpeculativeAlgorithm.from_string(
|
self.spec_algorithm = SpeculativeAlgorithm.from_string(
|
||||||
server_args.speculative_algorithm
|
server_args.speculative_algorithm
|
||||||
@@ -899,10 +899,6 @@ class Scheduler(
|
|||||||
batch = self.get_next_batch_to_run()
|
batch = self.get_next_batch_to_run()
|
||||||
self.cur_batch = batch
|
self.cur_batch = batch
|
||||||
|
|
||||||
if batch:
|
|
||||||
for req in batch.reqs:
|
|
||||||
trace_event("schedule", req.rid)
|
|
||||||
|
|
||||||
if batch:
|
if batch:
|
||||||
result = self.run_batch(batch)
|
result = self.run_batch(batch)
|
||||||
self.process_batch_result(batch, result)
|
self.process_batch_result(batch, result)
|
||||||
@@ -924,10 +920,6 @@ class Scheduler(
|
|||||||
batch = self.get_next_batch_to_run()
|
batch = self.get_next_batch_to_run()
|
||||||
self.cur_batch = batch
|
self.cur_batch = batch
|
||||||
|
|
||||||
if batch:
|
|
||||||
for req in batch.reqs:
|
|
||||||
trace_event("schedule", req.rid)
|
|
||||||
|
|
||||||
if batch:
|
if batch:
|
||||||
batch.launch_done = threading.Event()
|
batch.launch_done = threading.Event()
|
||||||
result = self.run_batch(batch)
|
result = self.run_batch(batch)
|
||||||
@@ -1192,8 +1184,11 @@ class Scheduler(
|
|||||||
src=self.tp_group.ranks[0],
|
src=self.tp_group.ranks[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.enable_trace:
|
||||||
for req in recv_reqs:
|
for req in recv_reqs:
|
||||||
if isinstance(req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)):
|
if isinstance(
|
||||||
|
req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
|
||||||
|
):
|
||||||
trace_set_proc_propagate_context(req.rid, req.trace_context)
|
trace_set_proc_propagate_context(req.rid, req.trace_context)
|
||||||
trace_slice_start("", req.rid, anonymous=True)
|
trace_slice_start("", req.rid, anonymous=True)
|
||||||
|
|
||||||
@@ -1277,6 +1272,7 @@ class Scheduler(
|
|||||||
bootstrap_host=recv_req.bootstrap_host,
|
bootstrap_host=recv_req.bootstrap_host,
|
||||||
bootstrap_port=recv_req.bootstrap_port,
|
bootstrap_port=recv_req.bootstrap_port,
|
||||||
bootstrap_room=recv_req.bootstrap_room,
|
bootstrap_room=recv_req.bootstrap_room,
|
||||||
|
disagg_mode=self.disaggregation_mode,
|
||||||
data_parallel_rank=recv_req.data_parallel_rank,
|
data_parallel_rank=recv_req.data_parallel_rank,
|
||||||
vocab_size=self.model_config.vocab_size,
|
vocab_size=self.model_config.vocab_size,
|
||||||
priority=recv_req.priority,
|
priority=recv_req.priority,
|
||||||
@@ -1403,7 +1399,6 @@ class Scheduler(
|
|||||||
req.set_finish_with_abort(error_msg)
|
req.set_finish_with_abort(error_msg)
|
||||||
|
|
||||||
if add_to_grammar_queue:
|
if add_to_grammar_queue:
|
||||||
req.queue_time_start = time.perf_counter()
|
|
||||||
self.grammar_queue.append(req)
|
self.grammar_queue.append(req)
|
||||||
else:
|
else:
|
||||||
self._add_request_to_queue(req)
|
self._add_request_to_queue(req)
|
||||||
@@ -1419,23 +1414,6 @@ class Scheduler(
|
|||||||
for tokenized_req in recv_req:
|
for tokenized_req in recv_req:
|
||||||
self.handle_generate_request(tokenized_req)
|
self.handle_generate_request(tokenized_req)
|
||||||
|
|
||||||
def _add_request_to_queue(self, req: Req):
|
|
||||||
req.queue_time_start = time.perf_counter()
|
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
|
||||||
self._prefetch_kvcache(req)
|
|
||||||
self.disagg_prefill_bootstrap_queue.add(
|
|
||||||
req, self.model_config.num_key_value_heads
|
|
||||||
)
|
|
||||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
|
||||||
self.disagg_decode_prealloc_queue.add(req)
|
|
||||||
else:
|
|
||||||
self._set_or_validate_priority(req)
|
|
||||||
if self._abort_on_queued_limit(req):
|
|
||||||
return
|
|
||||||
self._prefetch_kvcache(req)
|
|
||||||
self.waiting_queue.append(req)
|
|
||||||
trace_slice_end("process req", req.rid, auto_next_anon=True)
|
|
||||||
|
|
||||||
def _prefetch_kvcache(self, req: Req):
|
def _prefetch_kvcache(self, req: Req):
|
||||||
if self.enable_hicache_storage:
|
if self.enable_hicache_storage:
|
||||||
req.init_next_round_input(self.tree_cache)
|
req.init_next_round_input(self.tree_cache)
|
||||||
@@ -1449,19 +1427,27 @@ class Scheduler(
|
|||||||
req.rid, req.last_host_node, new_input_tokens, last_hash
|
req.rid, req.last_host_node, new_input_tokens, last_hash
|
||||||
)
|
)
|
||||||
|
|
||||||
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
|
def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.NULL:
|
||||||
self.disagg_prefill_bootstrap_queue.extend(
|
|
||||||
reqs, self.model_config.num_key_value_heads
|
|
||||||
)
|
|
||||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
|
||||||
# If this is a decode server, we put the request to the decode pending prealloc queue
|
|
||||||
self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
|
|
||||||
else:
|
|
||||||
for req in reqs:
|
|
||||||
self._set_or_validate_priority(req)
|
self._set_or_validate_priority(req)
|
||||||
if not self._abort_on_queued_limit(req):
|
if self._abort_on_queued_limit(req):
|
||||||
|
return
|
||||||
|
self._prefetch_kvcache(req)
|
||||||
self.waiting_queue.append(req)
|
self.waiting_queue.append(req)
|
||||||
|
req.time_stats.wait_queue_entry_time = time.perf_counter()
|
||||||
|
trace_slice_end("process req", req.rid, auto_next_anon=True)
|
||||||
|
elif self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
|
self._prefetch_kvcache(req)
|
||||||
|
self.disagg_prefill_bootstrap_queue.add(
|
||||||
|
req, self.model_config.num_key_value_heads
|
||||||
|
)
|
||||||
|
req.time_stats.prefill_bootstrap_queue_entry_time = time.perf_counter()
|
||||||
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
|
self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
|
||||||
|
if not is_retracted:
|
||||||
|
req.time_stats.decode_prealloc_queue_entry_time = time.perf_counter()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid {self.disaggregation_mode=}")
|
||||||
|
|
||||||
def _set_or_validate_priority(self, req: Req):
|
def _set_or_validate_priority(self, req: Req):
|
||||||
"""Set the default priority value, or abort the request based on the priority scheduling mode."""
|
"""Set the default priority value, or abort the request based on the priority scheduling mode."""
|
||||||
@@ -1500,7 +1486,7 @@ class Scheduler(
|
|||||||
direction = 1 if self.schedule_low_priority_values_first else -1
|
direction = 1 if self.schedule_low_priority_values_first else -1
|
||||||
key_fn = lambda item: (
|
key_fn = lambda item: (
|
||||||
direction * item[1].priority,
|
direction * item[1].priority,
|
||||||
item[1].queue_time_start,
|
item[1].time_stats.wait_queue_entry_time,
|
||||||
)
|
)
|
||||||
idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
|
idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
|
||||||
abort_existing_req = (
|
abort_existing_req = (
|
||||||
@@ -1902,14 +1888,14 @@ class Scheduler(
|
|||||||
if self.enable_metrics:
|
if self.enable_metrics:
|
||||||
# only record queue time when enable_metrics is True to avoid overhead
|
# only record queue time when enable_metrics is True to avoid overhead
|
||||||
for req in can_run_list:
|
for req in can_run_list:
|
||||||
req.queue_time_end = time.perf_counter()
|
|
||||||
req.add_latency(RequestStage.PREFILL_WAITING)
|
req.add_latency(RequestStage.PREFILL_WAITING)
|
||||||
|
|
||||||
self.waiting_queue = [
|
self.waiting_queue = [
|
||||||
x for x in self.waiting_queue if x not in set(can_run_list)
|
x for x in self.waiting_queue if x not in set(can_run_list)
|
||||||
]
|
]
|
||||||
if adder.preempt_list:
|
if adder.preempt_list:
|
||||||
self._extend_requests_to_queue(adder.preempt_list)
|
for req in adder.preempt_list:
|
||||||
|
self._add_request_to_queue(req)
|
||||||
|
|
||||||
if adder.new_chunked_req is not None:
|
if adder.new_chunked_req is not None:
|
||||||
assert self.chunked_req is None
|
assert self.chunked_req is None
|
||||||
@@ -1920,7 +1906,16 @@ class Scheduler(
|
|||||||
|
|
||||||
# Print stats
|
# Print stats
|
||||||
if self.current_scheduler_metrics_enabled():
|
if self.current_scheduler_metrics_enabled():
|
||||||
self.log_prefill_stats(adder, can_run_list, running_bs)
|
self.log_prefill_stats(adder, can_run_list, running_bs, 0)
|
||||||
|
|
||||||
|
for req in can_run_list:
|
||||||
|
if req.time_stats.forward_entry_time == 0:
|
||||||
|
# Avoid update chunked request many times
|
||||||
|
req.time_stats.forward_entry_time = time.perf_counter()
|
||||||
|
if self.enable_metrics:
|
||||||
|
self.metrics_collector.observe_queue_time(
|
||||||
|
req.time_stats.get_queueing_time(),
|
||||||
|
)
|
||||||
|
|
||||||
# Create a new batch
|
# Create a new batch
|
||||||
new_batch = ScheduleBatch.init_new(
|
new_batch = ScheduleBatch.init_new(
|
||||||
@@ -1975,19 +1970,25 @@ class Scheduler(
|
|||||||
TEST_RETRACT and batch.batch_size() > 10
|
TEST_RETRACT and batch.batch_size() > 10
|
||||||
):
|
):
|
||||||
old_ratio = self.new_token_ratio
|
old_ratio = self.new_token_ratio
|
||||||
|
retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
|
||||||
retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
|
self.server_args
|
||||||
num_retracted_reqs = len(retracted_reqs)
|
)
|
||||||
|
self.num_retracted_reqs = len(retracted_reqs)
|
||||||
self.new_token_ratio = new_token_ratio
|
self.new_token_ratio = new_token_ratio
|
||||||
|
for req in reqs_to_abort:
|
||||||
|
self.send_to_tokenizer.send_pyobj(
|
||||||
|
AbortReq(req.rid, abort_reason=req.to_abort_message)
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"KV cache pool is full. Retract requests. "
|
"KV cache pool is full. Retract requests. "
|
||||||
f"#retracted_reqs: {num_retracted_reqs}, "
|
f"#retracted_reqs: {len(retracted_reqs)}, "
|
||||||
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
|
f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
|
||||||
|
f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
|
for req in retracted_reqs:
|
||||||
self.total_retracted_reqs += num_retracted_reqs
|
self._add_request_to_queue(req, is_retracted=True)
|
||||||
else:
|
else:
|
||||||
self.new_token_ratio = max(
|
self.new_token_ratio = max(
|
||||||
self.new_token_ratio - self.new_token_ratio_decay,
|
self.new_token_ratio - self.new_token_ratio_decay,
|
||||||
@@ -2086,23 +2087,14 @@ class Scheduler(
|
|||||||
):
|
):
|
||||||
if batch.forward_mode.is_decode():
|
if batch.forward_mode.is_decode():
|
||||||
self.process_batch_result_decode(batch, result, launch_done)
|
self.process_batch_result_decode(batch, result, launch_done)
|
||||||
for req in batch.reqs:
|
if self.enable_trace:
|
||||||
trace_slice(
|
trace_slice_batch("decode loop", batch.reqs)
|
||||||
"decode loop",
|
|
||||||
req.rid,
|
|
||||||
auto_next_anon=not req.finished(),
|
|
||||||
thread_finish_flag=req.finished(),
|
|
||||||
)
|
|
||||||
|
|
||||||
elif batch.forward_mode.is_extend():
|
elif batch.forward_mode.is_extend():
|
||||||
self.process_batch_result_prefill(batch, result, launch_done)
|
self.process_batch_result_prefill(batch, result, launch_done)
|
||||||
for req in batch.reqs:
|
if self.enable_trace:
|
||||||
trace_slice(
|
trace_slice_batch("prefill", batch.reqs)
|
||||||
"prefill",
|
|
||||||
req.rid,
|
|
||||||
auto_next_anon=not req.finished(),
|
|
||||||
thread_finish_flag=req.finished(),
|
|
||||||
)
|
|
||||||
elif batch.forward_mode.is_idle():
|
elif batch.forward_mode.is_idle():
|
||||||
if self.enable_overlap:
|
if self.enable_overlap:
|
||||||
self.tp_worker.resolve_last_batch_result(launch_done)
|
self.tp_worker.resolve_last_batch_result(launch_done)
|
||||||
@@ -2261,12 +2253,13 @@ class Scheduler(
|
|||||||
if req.finished(): # It is aborted by AbortReq
|
if req.finished(): # It is aborted by AbortReq
|
||||||
num_ready_reqs += 1
|
num_ready_reqs += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
req.grammar = req.grammar.result(timeout=0.03)
|
req.grammar = req.grammar.result(timeout=0.03)
|
||||||
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
|
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
|
||||||
if req.grammar is INVALID_GRAMMAR_OBJ:
|
if req.grammar is INVALID_GRAMMAR_OBJ:
|
||||||
req.set_finish_with_abort(
|
error_msg = f"Invalid grammar request: {req.grammar_key=}"
|
||||||
f"Invalid grammar request: {req.grammar_key=}"
|
req.set_finish_with_abort(error_msg)
|
||||||
)
|
|
||||||
num_ready_reqs += 1
|
num_ready_reqs += 1
|
||||||
except futures._base.TimeoutError:
|
except futures._base.TimeoutError:
|
||||||
req.grammar_wait_ct += 1
|
req.grammar_wait_ct += 1
|
||||||
@@ -2298,9 +2291,8 @@ class Scheduler(
|
|||||||
req.grammar = req.grammar.result()
|
req.grammar = req.grammar.result()
|
||||||
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
|
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
|
||||||
if req.grammar is INVALID_GRAMMAR_OBJ:
|
if req.grammar is INVALID_GRAMMAR_OBJ:
|
||||||
req.set_finish_with_abort(
|
error_msg = f"Invalid grammar request: {req.grammar_key=}"
|
||||||
f"Invalid grammar request: {req.grammar_key=}"
|
req.set_finish_with_abort(error_msg)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
num_ready_reqs_max = num_ready_reqs
|
num_ready_reqs_max = num_ready_reqs
|
||||||
num_timeout_reqs_max = num_timeout_reqs
|
num_timeout_reqs_max = num_timeout_reqs
|
||||||
@@ -2308,12 +2300,14 @@ class Scheduler(
|
|||||||
for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
|
for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
|
||||||
req = self.grammar_queue[i]
|
req = self.grammar_queue[i]
|
||||||
req.grammar.cancel()
|
req.grammar.cancel()
|
||||||
|
self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
|
||||||
error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
|
error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
|
||||||
req.set_finish_with_abort(error_msg)
|
req.set_finish_with_abort(error_msg)
|
||||||
self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
|
|
||||||
num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max
|
num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max
|
||||||
|
|
||||||
self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
|
for req in self.grammar_queue[:num_ready_reqs]:
|
||||||
|
self._add_request_to_queue(req)
|
||||||
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
|
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
|
||||||
|
|
||||||
def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
|
def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
|
||||||
@@ -2795,17 +2789,11 @@ def run_scheduler_process(
|
|||||||
pipe_writer,
|
pipe_writer,
|
||||||
balance_meta: Optional[DPBalanceMeta] = None,
|
balance_meta: Optional[DPBalanceMeta] = None,
|
||||||
):
|
):
|
||||||
if server_args.enable_trace:
|
# Generate the logger prefix
|
||||||
process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
|
|
||||||
if server_args.disaggregation_mode == "null":
|
|
||||||
thread_label = "Scheduler"
|
|
||||||
trace_set_thread_info(thread_label, tp_rank, dp_rank)
|
|
||||||
|
|
||||||
if (numa_node := server_args.numa_node) is not None:
|
|
||||||
numa_bind_to_node(numa_node[gpu_id])
|
|
||||||
|
|
||||||
# Generate the prefix
|
|
||||||
prefix = ""
|
prefix = ""
|
||||||
|
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
|
||||||
|
# [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
|
||||||
|
dp_rank = int(os.environ["SGLANG_DP_RANK"])
|
||||||
if dp_rank is not None:
|
if dp_rank is not None:
|
||||||
prefix += f" DP{dp_rank}"
|
prefix += f" DP{dp_rank}"
|
||||||
if server_args.tp_size > 1:
|
if server_args.tp_size > 1:
|
||||||
@@ -2821,10 +2809,6 @@ def run_scheduler_process(
|
|||||||
kill_itself_when_parent_died()
|
kill_itself_when_parent_died()
|
||||||
parent_process = psutil.Process().parent()
|
parent_process = psutil.Process().parent()
|
||||||
|
|
||||||
# [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
|
|
||||||
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
|
|
||||||
dp_rank = int(os.environ["SGLANG_DP_RANK"])
|
|
||||||
|
|
||||||
# Configure the logger
|
# Configure the logger
|
||||||
configure_logger(server_args, prefix=prefix)
|
configure_logger(server_args, prefix=prefix)
|
||||||
suppress_other_loggers()
|
suppress_other_loggers()
|
||||||
@@ -2832,6 +2816,15 @@ def run_scheduler_process(
|
|||||||
# Set cpu affinity to this gpu process
|
# Set cpu affinity to this gpu process
|
||||||
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
|
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
|
||||||
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
|
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
|
||||||
|
if (numa_node := server_args.numa_node) is not None:
|
||||||
|
numa_bind_to_node(numa_node[gpu_id])
|
||||||
|
|
||||||
|
# Set up tracing
|
||||||
|
if server_args.enable_trace:
|
||||||
|
process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
|
||||||
|
if server_args.disaggregation_mode == "null":
|
||||||
|
thread_label = "Scheduler"
|
||||||
|
trace_set_thread_info(thread_label, tp_rank, dp_rank)
|
||||||
|
|
||||||
# Create a scheduler and run the event loop
|
# Create a scheduler and run the event loop
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -47,8 +47,11 @@ class SchedulerMetricsMixin:
|
|||||||
self.spec_num_total_forward_ct = 0
|
self.spec_num_total_forward_ct = 0
|
||||||
self.cum_spec_accept_length = 0
|
self.cum_spec_accept_length = 0
|
||||||
self.cum_spec_accept_count = 0
|
self.cum_spec_accept_count = 0
|
||||||
self.total_retracted_reqs = 0
|
self.kv_transfer_speed_gb_s: float = 0.0
|
||||||
|
self.kv_transfer_latency_ms: float = 0.0
|
||||||
|
|
||||||
self.stats = SchedulerStats()
|
self.stats = SchedulerStats()
|
||||||
|
|
||||||
if self.enable_metrics:
|
if self.enable_metrics:
|
||||||
engine_type = "unified"
|
engine_type = "unified"
|
||||||
labels = {
|
labels = {
|
||||||
@@ -82,12 +85,14 @@ class SchedulerMetricsMixin:
|
|||||||
adder: PrefillAdder,
|
adder: PrefillAdder,
|
||||||
can_run_list: List[Req],
|
can_run_list: List[Req],
|
||||||
running_bs: int,
|
running_bs: int,
|
||||||
|
running_bs_offline_batch: int,
|
||||||
):
|
):
|
||||||
gap_latency = time.perf_counter() - self.last_prefill_stats_tic
|
gap_latency = time.perf_counter() - self.last_prefill_stats_tic
|
||||||
self.last_prefill_stats_tic = time.perf_counter()
|
self.last_prefill_stats_tic = time.perf_counter()
|
||||||
self.last_input_throughput = self.last_prefill_tokens / gap_latency
|
self.last_input_throughput = self.last_prefill_tokens / gap_latency
|
||||||
self.last_prefill_tokens = adder.log_input_tokens
|
self.last_prefill_tokens = adder.log_input_tokens
|
||||||
|
|
||||||
|
# TODO: generalize this for various memory pools
|
||||||
if self.is_hybrid:
|
if self.is_hybrid:
|
||||||
(
|
(
|
||||||
full_num_used,
|
full_num_used,
|
||||||
@@ -101,51 +106,53 @@ class SchedulerMetricsMixin:
|
|||||||
) = self._get_swa_token_info()
|
) = self._get_swa_token_info()
|
||||||
num_used = max(full_num_used, swa_num_used)
|
num_used = max(full_num_used, swa_num_used)
|
||||||
token_usage = max(full_token_usage, swa_token_usage)
|
token_usage = max(full_token_usage, swa_token_usage)
|
||||||
token_msg = (
|
token_usage_msg = (
|
||||||
f"full token usage: {full_token_usage:.2f}, "
|
f"full token usage: {full_token_usage:.2f}, "
|
||||||
f"swa token usage: {swa_token_usage:.2f}, "
|
f"swa token usage: {swa_token_usage:.2f}, "
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
num_used, token_usage, _, _ = self._get_token_info()
|
num_used, token_usage, _, _ = self._get_token_info()
|
||||||
token_msg = f"token usage: {token_usage:.2f}, "
|
token_usage_msg = f"token usage: {token_usage:.2f}, "
|
||||||
|
|
||||||
num_new_seq = len(can_run_list)
|
|
||||||
f = (
|
f = (
|
||||||
f"Prefill batch. "
|
f"Prefill batch. "
|
||||||
f"#new-seq: {num_new_seq}, "
|
f"#new-seq: {len(can_run_list)}, "
|
||||||
f"#new-token: {adder.log_input_tokens}, "
|
f"#new-token: {adder.log_input_tokens}, "
|
||||||
f"#cached-token: {adder.log_hit_tokens}, "
|
f"#cached-token: {adder.log_hit_tokens}, "
|
||||||
f"{token_msg}"
|
f"{token_usage_msg}"
|
||||||
|
f"#running-req: {running_bs}, "
|
||||||
|
f"#queue-req: {len(self.waiting_queue)}, "
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
|
f += f"#prealloc-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
|
||||||
f += f"#queue-req: {len(self.waiting_queue)}, "
|
f += f"#inflight-req: {len(self.disagg_prefill_inflight_queue)}, "
|
||||||
f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
|
|
||||||
f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
|
|
||||||
else:
|
|
||||||
f += f"#running-req: {running_bs}, "
|
|
||||||
f += f"#queue-req: {len(self.waiting_queue)}, "
|
|
||||||
|
|
||||||
logger.info(f)
|
logger.info(f)
|
||||||
|
|
||||||
if self.enable_metrics:
|
if self.enable_metrics:
|
||||||
|
# Basics
|
||||||
total_tokens = adder.log_input_tokens + adder.log_hit_tokens
|
total_tokens = adder.log_input_tokens + adder.log_hit_tokens
|
||||||
|
|
||||||
cache_hit_rate = (
|
cache_hit_rate = (
|
||||||
adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
|
adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
|
||||||
)
|
)
|
||||||
|
|
||||||
self.stats.num_running_reqs = running_bs
|
self.stats.num_running_reqs = running_bs
|
||||||
|
self.stats.num_running_reqs_offline_batch = running_bs_offline_batch
|
||||||
self.stats.num_used_tokens = num_used
|
self.stats.num_used_tokens = num_used
|
||||||
self.stats.token_usage = round(token_usage, 2)
|
self.stats.token_usage = token_usage
|
||||||
|
if self.is_hybrid:
|
||||||
|
self.stats.swa_token_usage = swa_token_usage
|
||||||
self.stats.num_queue_reqs = len(self.waiting_queue)
|
self.stats.num_queue_reqs = len(self.waiting_queue)
|
||||||
|
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
|
||||||
self.stats.cache_hit_rate = cache_hit_rate
|
self.stats.cache_hit_rate = cache_hit_rate
|
||||||
|
|
||||||
total_queue_latency = 0
|
# Retract
|
||||||
for req in can_run_list:
|
self.stats.num_retracted_reqs = self.num_retracted_reqs
|
||||||
total_queue_latency += req.queue_time_end - req.queue_time_start
|
self.stats.num_paused_reqs = self.num_paused_reqs
|
||||||
self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
|
self.num_retracted_reqs = self.num_paused_reqs = 0
|
||||||
|
|
||||||
|
# PD disaggregation
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
self.stats.num_prefill_prealloc_queue_reqs = len(
|
self.stats.num_prefill_prealloc_queue_reqs = len(
|
||||||
self.disagg_prefill_bootstrap_queue.queue
|
self.disagg_prefill_bootstrap_queue.queue
|
||||||
@@ -153,7 +160,18 @@ class SchedulerMetricsMixin:
|
|||||||
self.stats.num_prefill_inflight_queue_reqs = len(
|
self.stats.num_prefill_inflight_queue_reqs = len(
|
||||||
self.disagg_prefill_inflight_queue
|
self.disagg_prefill_inflight_queue
|
||||||
)
|
)
|
||||||
|
self.stats.kv_transfer_speed_gb_s = self.kv_transfer_speed_gb_s
|
||||||
|
self.stats.kv_transfer_latency_ms = self.kv_transfer_latency_ms
|
||||||
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
|
self.stats.num_decode_prealloc_queue_reqs = len(
|
||||||
|
self.disagg_decode_prealloc_queue.queue
|
||||||
|
)
|
||||||
|
self.stats.num_decode_transfer_queue_reqs = len(
|
||||||
|
self.disagg_decode_transfer_queue.queue
|
||||||
|
)
|
||||||
|
|
||||||
|
# Others
|
||||||
|
self.calculate_utilization()
|
||||||
self.metrics_collector.log_stats(self.stats)
|
self.metrics_collector.log_stats(self.stats)
|
||||||
self._emit_kv_metrics()
|
self._emit_kv_metrics()
|
||||||
self._publish_kv_events()
|
self._publish_kv_events()
|
||||||
@@ -166,8 +184,12 @@ class SchedulerMetricsMixin:
|
|||||||
gap_latency = time.perf_counter() - self.last_decode_stats_tic
|
gap_latency = time.perf_counter() - self.last_decode_stats_tic
|
||||||
self.last_decode_stats_tic = time.perf_counter()
|
self.last_decode_stats_tic = time.perf_counter()
|
||||||
self.last_gen_throughput = self.num_generated_tokens / gap_latency
|
self.last_gen_throughput = self.num_generated_tokens / gap_latency
|
||||||
|
|
||||||
self.num_generated_tokens = 0
|
self.num_generated_tokens = 0
|
||||||
num_running_reqs = len(batch.reqs)
|
num_running_reqs = len(batch.reqs)
|
||||||
|
num_running_reqs_offline_batch = 0
|
||||||
|
|
||||||
|
# TODO: generalize this for various memory pools
|
||||||
if self.is_hybrid:
|
if self.is_hybrid:
|
||||||
(
|
(
|
||||||
full_num_used,
|
full_num_used,
|
||||||
@@ -181,7 +203,7 @@ class SchedulerMetricsMixin:
|
|||||||
) = self._get_swa_token_info()
|
) = self._get_swa_token_info()
|
||||||
num_used = max(full_num_used, swa_num_used)
|
num_used = max(full_num_used, swa_num_used)
|
||||||
token_usage = max(full_token_usage, swa_token_usage)
|
token_usage = max(full_token_usage, swa_token_usage)
|
||||||
token_msg = (
|
token_usage_msg = (
|
||||||
f"#full token: {full_num_used}, "
|
f"#full token: {full_num_used}, "
|
||||||
f"full token usage: {full_token_usage:.2f}, "
|
f"full token usage: {full_token_usage:.2f}, "
|
||||||
f"#swa token: {swa_num_used}, "
|
f"#swa token: {swa_num_used}, "
|
||||||
@@ -189,14 +211,14 @@ class SchedulerMetricsMixin:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
num_used, token_usage, _, _ = self._get_token_info()
|
num_used, token_usage, _, _ = self._get_token_info()
|
||||||
token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
|
token_usage_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, "
|
||||||
|
|
||||||
if RECORD_STEP_TIME:
|
if RECORD_STEP_TIME:
|
||||||
self.step_time_dict[num_running_reqs].append(
|
self.step_time_dict[num_running_reqs].append(
|
||||||
gap_latency / self.server_args.decode_log_interval
|
gap_latency / self.server_args.decode_log_interval
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
|
msg = f"Decode batch. #running-req: {num_running_reqs}, {token_usage_msg}"
|
||||||
|
|
||||||
if self.spec_algorithm.is_none():
|
if self.spec_algorithm.is_none():
|
||||||
spec_accept_length = 0
|
spec_accept_length = 0
|
||||||
@@ -208,41 +230,66 @@ class SchedulerMetricsMixin:
|
|||||||
self.cum_spec_accept_count += self.spec_num_total_forward_ct
|
self.cum_spec_accept_count += self.spec_num_total_forward_ct
|
||||||
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
|
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
|
||||||
msg += f"accept len: {spec_accept_length:.2f}, "
|
msg += f"accept len: {spec_accept_length:.2f}, "
|
||||||
|
cache_hit_rate = 0.0
|
||||||
|
|
||||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
|
msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
|
||||||
|
msg += f"#prealloc-req: {len(self.disagg_decode_prealloc_queue.queue)}, "
|
||||||
|
msg += f"#transfer-req: {len(self.disagg_decode_transfer_queue.queue)}, "
|
||||||
msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
|
msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
|
||||||
|
|
||||||
msg += (
|
msg += (
|
||||||
f"{'cpu graph' if self.device == 'cpu' else 'cuda graph'}: {can_run_cuda_graph}, "
|
f"{'cuda graph' if self.device == 'cuda' else 'cpu graph'}: {can_run_cuda_graph}, "
|
||||||
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
|
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
|
||||||
f"#queue-req: {len(self.waiting_queue)}, "
|
f"#queue-req: {len(self.waiting_queue)}, "
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
if self.enable_metrics:
|
if self.enable_metrics:
|
||||||
|
# Basics
|
||||||
self.stats.num_running_reqs = num_running_reqs
|
self.stats.num_running_reqs = num_running_reqs
|
||||||
|
self.stats.num_running_reqs_offline_batch = num_running_reqs_offline_batch
|
||||||
self.stats.num_used_tokens = num_used
|
self.stats.num_used_tokens = num_used
|
||||||
self.stats.token_usage = round(token_usage, 2)
|
self.stats.token_usage = token_usage
|
||||||
self.stats.cache_hit_rate = 0.0
|
if self.is_hybrid:
|
||||||
|
self.stats.swa_token_usage = swa_token_usage
|
||||||
self.stats.gen_throughput = self.last_gen_throughput
|
self.stats.gen_throughput = self.last_gen_throughput
|
||||||
self.stats.num_queue_reqs = len(self.waiting_queue)
|
self.stats.num_queue_reqs = len(self.waiting_queue)
|
||||||
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
|
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
|
||||||
|
self.stats.cache_hit_rate = cache_hit_rate
|
||||||
self.stats.spec_accept_length = spec_accept_length
|
self.stats.spec_accept_length = spec_accept_length
|
||||||
self.stats.total_retracted_reqs = self.total_retracted_reqs
|
|
||||||
self.stats.avg_request_queue_latency = 0.0
|
# Retract
|
||||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
self.stats.num_retracted_reqs = self.num_retracted_reqs
|
||||||
|
self.stats.num_paused_reqs = self.num_paused_reqs
|
||||||
|
self.num_retracted_reqs = self.num_paused_reqs = 0
|
||||||
|
|
||||||
|
# PD disaggregation
|
||||||
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
|
self.stats.num_prefill_prealloc_queue_reqs = len(
|
||||||
|
self.disagg_prefill_bootstrap_queue.queue
|
||||||
|
)
|
||||||
|
self.stats.num_prefill_inflight_queue_reqs = len(
|
||||||
|
self.disagg_prefill_inflight_queue
|
||||||
|
)
|
||||||
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
self.stats.num_decode_prealloc_queue_reqs = len(
|
self.stats.num_decode_prealloc_queue_reqs = len(
|
||||||
self.disagg_decode_prealloc_queue.queue
|
self.disagg_decode_prealloc_queue.queue
|
||||||
)
|
)
|
||||||
self.stats.num_decode_transfer_queue_reqs = len(
|
self.stats.num_decode_transfer_queue_reqs = len(
|
||||||
self.disagg_decode_transfer_queue.queue
|
self.disagg_decode_transfer_queue.queue
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Others
|
||||||
|
self.calculate_utilization()
|
||||||
self.metrics_collector.log_stats(self.stats)
|
self.metrics_collector.log_stats(self.stats)
|
||||||
self._emit_kv_metrics()
|
self._emit_kv_metrics()
|
||||||
self._publish_kv_events()
|
self._publish_kv_events()
|
||||||
|
|
||||||
def _emit_kv_metrics(self: Scheduler):
|
def _emit_kv_metrics(self: Scheduler):
|
||||||
|
if not self.enable_kv_cache_events:
|
||||||
|
return
|
||||||
|
|
||||||
kv_metrics = KvMetrics()
|
kv_metrics = KvMetrics()
|
||||||
kv_metrics.request_active_slots = self.stats.num_running_reqs
|
kv_metrics.request_active_slots = self.stats.num_running_reqs
|
||||||
kv_metrics.request_total_slots = self.max_running_requests
|
kv_metrics.request_total_slots = self.max_running_requests
|
||||||
@@ -259,7 +306,9 @@ class SchedulerMetricsMixin:
|
|||||||
self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
|
self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
|
||||||
|
|
||||||
def _publish_kv_events(self: Scheduler):
|
def _publish_kv_events(self: Scheduler):
|
||||||
if self.enable_kv_cache_events:
|
if not self.enable_kv_cache_events:
|
||||||
|
return
|
||||||
|
|
||||||
events = self.tree_cache.take_events()
|
events = self.tree_cache.take_events()
|
||||||
if events:
|
if events:
|
||||||
batch = KVEventBatch(ts=time.time(), events=events)
|
batch = KVEventBatch(ts=time.time(), events=events)
|
||||||
@@ -349,3 +398,17 @@ class SchedulerMetricsMixin:
|
|||||||
# 2. Atomically write local_tokens and onfly into shm under the mutex
|
# 2. Atomically write local_tokens and onfly into shm under the mutex
|
||||||
meta.set_shared_onfly_info(onfly_list)
|
meta.set_shared_onfly_info(onfly_list)
|
||||||
meta.set_shared_local_tokens(local_tokens)
|
meta.set_shared_local_tokens(local_tokens)
|
||||||
|
|
||||||
|
def calculate_utilization(self):
    """Refresh ``self.stats.utilization`` from the current scheduler stats.

    In prefill disaggregation mode the value is set to ``-1``. Otherwise,
    when ``max_running_requests_under_SLO`` is set and positive, utilization
    becomes the larger of the running-request ratio against that limit and
    ``token_usage / 0.9``. If no positive limit is configured the previous
    value is left untouched.
    """
    stats = self.stats

    if self.disaggregation_mode == DisaggregationMode.PREFILL:
        stats.utilization = -1
        return

    slo_limit = stats.max_running_requests_under_SLO
    if slo_limit is None or not slo_limit > 0:
        return

    request_ratio = stats.num_running_reqs / slo_limit
    # NOTE(review): 0.9 looks like a target token-usage ceiling -- confirm.
    token_ratio = stats.token_usage / 0.9
    stats.utilization = max(request_ratio, token_ratio)
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ class SchedulerOutputProcessorMixin:
|
|||||||
|
|
||||||
if req.finished():
|
if req.finished():
|
||||||
self.tree_cache.cache_finished_req(req)
|
self.tree_cache.cache_finished_req(req)
|
||||||
req.time_stats.completion_time = time.time()
|
req.time_stats.completion_time = time.perf_counter()
|
||||||
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
|
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
|
||||||
# This updates radix so others can match
|
# This updates radix so others can match
|
||||||
self.tree_cache.cache_unfinished_req(req)
|
self.tree_cache.cache_unfinished_req(req)
|
||||||
@@ -257,7 +257,7 @@ class SchedulerOutputProcessorMixin:
|
|||||||
else:
|
else:
|
||||||
self.tree_cache.cache_finished_req(req)
|
self.tree_cache.cache_finished_req(req)
|
||||||
|
|
||||||
req.time_stats.completion_time = time.time()
|
req.time_stats.completion_time = time.perf_counter()
|
||||||
|
|
||||||
if req.return_logprob and batch.spec_algorithm.is_none():
|
if req.return_logprob and batch.spec_algorithm.is_none():
|
||||||
# speculative worker handles logprob in speculative decoding
|
# speculative worker handles logprob in speculative decoding
|
||||||
@@ -707,6 +707,7 @@ class SchedulerOutputProcessorMixin:
|
|||||||
and self.tp_rank == 0
|
and self.tp_rank == 0
|
||||||
and self.server_args.enable_request_time_stats_logging
|
and self.server_args.enable_request_time_stats_logging
|
||||||
):
|
):
|
||||||
|
print(f"{req.finished_reason=}")
|
||||||
req.log_time_stats()
|
req.log_time_stats()
|
||||||
|
|
||||||
# Send to detokenizer
|
# Send to detokenizer
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import copy
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
@@ -24,6 +25,7 @@ import zmq
|
|||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
ClearHiCacheReqInput,
|
ClearHiCacheReqInput,
|
||||||
ClearHiCacheReqOutput,
|
ClearHiCacheReqOutput,
|
||||||
|
CloseSessionReqInput,
|
||||||
DestroyWeightsUpdateGroupReqInput,
|
DestroyWeightsUpdateGroupReqInput,
|
||||||
DestroyWeightsUpdateGroupReqOutput,
|
DestroyWeightsUpdateGroupReqOutput,
|
||||||
ExpertDistributionReq,
|
ExpertDistributionReq,
|
||||||
@@ -44,6 +46,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
LoadLoRAAdapterReqOutput,
|
LoadLoRAAdapterReqOutput,
|
||||||
LoRAUpdateResult,
|
LoRAUpdateResult,
|
||||||
MultiTokenizerWrapper,
|
MultiTokenizerWrapper,
|
||||||
|
OpenSessionReqInput,
|
||||||
ProfileReq,
|
ProfileReq,
|
||||||
ProfileReqOutput,
|
ProfileReqOutput,
|
||||||
ProfileReqType,
|
ProfileReqType,
|
||||||
@@ -588,3 +591,81 @@ class TokenizerCommunicatorMixin:
|
|||||||
async def get_load(self: TokenizerManager) -> List[GetLoadReqOutput]:
|
async def get_load(self: TokenizerManager) -> List[GetLoadReqOutput]:
|
||||||
req = GetLoadReqInput()
|
req = GetLoadReqInput()
|
||||||
return await self.get_load_communicator(req)
|
return await self.get_load_communicator(req)
|
||||||
|
|
||||||
|
async def open_session(
    self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
):
    """Ask the scheduler to open a session and wait for its reply.

    Args:
        obj: The open-session request. When ``obj.session_id`` is ``None`` a
            random hex id is generated and stored on it.
        request: Unused here; accepted for signature compatibility.

    Returns:
        The value the pending future is resolved with, or ``None`` when a
        request for the same session id is already pending.
    """
    self.auto_create_handle_loop()

    if obj.session_id is None:
        obj.session_id = uuid.uuid4().hex
    elif obj.session_id in self.session_futures:
        return None

    # Capture the id now: `obj` may be rebound to a MultiTokenizerWrapper
    # below, and the wrapper is not guaranteed to expose `session_id`.
    session_key = obj.session_id

    if self.server_args.tokenizer_worker_num > 1:
        obj = MultiTokenizerWrapper(self.worker_id, obj)
    self.send_to_scheduler.send_pyobj(obj)

    self.session_futures[session_key] = asyncio.Future()
    try:
        return await self.session_futures[session_key]
    finally:
        # Always drop the pending entry, including on cancellation, so a
        # later open with the same id is not rejected as a duplicate.
        self.session_futures.pop(session_key, None)
|
||||||
|
|
||||||
|
async def close_session(
    self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
):
    """Forward a close-session request to the scheduler.

    ``request`` is unused; it is accepted for signature compatibility.
    """
    pending_send = self.send_to_scheduler.send_pyobj(obj)
    await pending_send
|
||||||
|
|
||||||
|
def get_log_request_metadata(self):
    """Return the request-logging settings for the configured log level.

    Returns:
        A ``(max_length, skip_names, out_skip_names)`` tuple. All three are
        ``None`` when request logging is disabled. Levels 0 and 1 return a
        ``1 << 30`` length limit plus sets of input and output field names to
        skip (level 0 additionally skips ``"sampling_params"``). Levels 2 and
        3 return only a length limit (2048 and ``1 << 30`` respectively).

    Raises:
        ValueError: If logging is enabled and the level is not 0, 1, 2 or 3.
    """
    if not self.log_requests:
        return None, None, None

    level = self.log_requests_level
    if level in (0, 1):
        # Fresh sets on every call so callers may mutate the result freely.
        skip_names = {
            "text",
            "input_ids",
            "input_embeds",
            "image_data",
            "audio_data",
            "lora_path",
        }
        if level == 0:
            skip_names.add("sampling_params")
        out_skip_names = {"text", "output_ids", "embedding"}
        return 1 << 30, skip_names, out_skip_names
    if level == 2:
        return 2048, None, None
    if level == 3:
        return 1 << 30, None, None

    raise ValueError(f"Invalid --log-requests-level: {self.log_requests_level=}")
|
||||||
|
|||||||
@@ -164,6 +164,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
self.crash_dump_folder = server_args.crash_dump_folder
|
self.crash_dump_folder = server_args.crash_dump_folder
|
||||||
|
self.enable_trace = server_args.enable_trace
|
||||||
|
|
||||||
# Read model args
|
# Read model args
|
||||||
self.model_path = server_args.model_path
|
self.model_path = server_args.model_path
|
||||||
@@ -381,23 +382,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
# If it's a single value, add worker_id prefix
|
# If it's a single value, add worker_id prefix
|
||||||
obj.rid = f"{self.worker_id}_{obj.rid}"
|
obj.rid = f"{self.worker_id}_{obj.rid}"
|
||||||
|
|
||||||
if obj.is_single:
|
if self.enable_trace:
|
||||||
bootstrap_room = (
|
self._trace_request_start(obj, created_time)
|
||||||
obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
|
|
||||||
)
|
|
||||||
trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
|
|
||||||
trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
|
|
||||||
else:
|
|
||||||
for i in range(len(obj.rid)):
|
|
||||||
bootstrap_room = (
|
|
||||||
obj.bootstrap_room[i]
|
|
||||||
if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
|
|
||||||
trace_slice_start(
|
|
||||||
"", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.log_requests:
|
if self.log_requests:
|
||||||
max_length, skip_names, _ = self.log_request_metadata
|
max_length, skip_names, _ = self.log_request_metadata
|
||||||
@@ -1055,7 +1041,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
req = AbortReq(rid, abort_all)
|
req = AbortReq(rid, abort_all)
|
||||||
self.send_to_scheduler.send_pyobj(req)
|
self.send_to_scheduler.send_pyobj(req)
|
||||||
if self.enable_metrics:
|
if self.enable_metrics:
|
||||||
self.metrics_collector.observe_one_aborted_request()
|
# TODO: also use custom_labels from the request
|
||||||
|
self.metrics_collector.observe_one_aborted_request(
|
||||||
|
self.metrics_collector.labels
|
||||||
|
)
|
||||||
|
|
||||||
async def pause_generation(self):
|
async def pause_generation(self):
|
||||||
async with self.is_pause_cond:
|
async with self.is_pause_cond:
|
||||||
@@ -1117,84 +1106,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
all_paused_requests = [r.num_paused_requests for r in result]
|
all_paused_requests = [r.num_paused_requests for r in result]
|
||||||
return all_success, all_message, all_paused_requests
|
return all_success, all_message, all_paused_requests
|
||||||
|
|
||||||
async def open_session(
|
|
||||||
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
|
|
||||||
):
|
|
||||||
self.auto_create_handle_loop()
|
|
||||||
|
|
||||||
if obj.session_id is None:
|
|
||||||
obj.session_id = uuid.uuid4().hex
|
|
||||||
elif obj.session_id in self.session_futures:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if self.server_args.tokenizer_worker_num > 1:
|
|
||||||
obj = MultiTokenizerWrapper(self.worker_id, obj)
|
|
||||||
self.send_to_scheduler.send_pyobj(obj)
|
|
||||||
|
|
||||||
self.session_futures[obj.session_id] = asyncio.Future()
|
|
||||||
session_id = await self.session_futures[obj.session_id]
|
|
||||||
del self.session_futures[obj.session_id]
|
|
||||||
return session_id
|
|
||||||
|
|
||||||
async def close_session(
|
|
||||||
self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
|
|
||||||
):
|
|
||||||
await self.send_to_scheduler.send_pyobj(obj)
|
|
||||||
|
|
||||||
def get_log_request_metadata(self):
|
|
||||||
max_length = None
|
|
||||||
skip_names = None
|
|
||||||
out_skip_names = None
|
|
||||||
if self.log_requests:
|
|
||||||
if self.log_requests_level == 0:
|
|
||||||
max_length = 1 << 30
|
|
||||||
skip_names = set(
|
|
||||||
[
|
|
||||||
"text",
|
|
||||||
"input_ids",
|
|
||||||
"input_embeds",
|
|
||||||
"image_data",
|
|
||||||
"audio_data",
|
|
||||||
"lora_path",
|
|
||||||
"sampling_params",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
out_skip_names = set(
|
|
||||||
[
|
|
||||||
"text",
|
|
||||||
"output_ids",
|
|
||||||
"embedding",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
elif self.log_requests_level == 1:
|
|
||||||
max_length = 1 << 30
|
|
||||||
skip_names = set(
|
|
||||||
[
|
|
||||||
"text",
|
|
||||||
"input_ids",
|
|
||||||
"input_embeds",
|
|
||||||
"image_data",
|
|
||||||
"audio_data",
|
|
||||||
"lora_path",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
out_skip_names = set(
|
|
||||||
[
|
|
||||||
"text",
|
|
||||||
"output_ids",
|
|
||||||
"embedding",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
elif self.log_requests_level == 2:
|
|
||||||
max_length = 2048
|
|
||||||
elif self.log_requests_level == 3:
|
|
||||||
max_length = 1 << 30
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid --log-requests-level: {self.log_requests_level=}"
|
|
||||||
)
|
|
||||||
return max_length, skip_names, out_skip_names
|
|
||||||
|
|
||||||
def configure_logging(self, obj: ConfigureLoggingReq):
|
def configure_logging(self, obj: ConfigureLoggingReq):
|
||||||
if obj.log_requests is not None:
|
if obj.log_requests is not None:
|
||||||
self.log_requests = obj.log_requests
|
self.log_requests = obj.log_requests
|
||||||
@@ -1353,12 +1264,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
# Drain requests
|
# Drain requests
|
||||||
while True:
|
while True:
|
||||||
remain_num_req = len(self.rid_to_state)
|
remain_num_req = len(self.rid_to_state)
|
||||||
|
remaining_rids = list(self.rid_to_state.keys())
|
||||||
|
|
||||||
if self.server_status == ServerStatus.UnHealthy:
|
if self.server_status == ServerStatus.UnHealthy:
|
||||||
# if health check failed, we should exit immediately
|
# if health check failed, we should exit immediately
|
||||||
logger.error(
|
logger.error(
|
||||||
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
|
"Signal SIGTERM received while health check failed. Force exiting."
|
||||||
remain_num_req,
|
|
||||||
)
|
)
|
||||||
self.dump_requests_before_crash()
|
self.dump_requests_before_crash()
|
||||||
break
|
break
|
||||||
@@ -1366,13 +1277,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
|
elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
|
||||||
# if force shutdown flag set, exit immediately
|
# if force shutdown flag set, exit immediately
|
||||||
logger.error(
|
logger.error(
|
||||||
"Signal SIGTERM received while force shutdown flag set. Force exiting... remaining number of requests: %d",
|
"Signal SIGTERM received while force shutdown flag set. Force exiting."
|
||||||
remain_num_req,
|
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Gracefully exiting... remaining number of requests {remain_num_req}"
|
f"Gracefully exiting... Remaining number of requests {remain_num_req}. Remaining requests {remaining_rids=}."
|
||||||
)
|
)
|
||||||
if remain_num_req > 0:
|
if remain_num_req > 0:
|
||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
@@ -1888,6 +1798,29 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
load_udpate_req = WatchLoadUpdateReq(loads=loads)
|
load_udpate_req = WatchLoadUpdateReq(loads=loads)
|
||||||
self.send_to_scheduler.send_pyobj(load_udpate_req)
|
self.send_to_scheduler.send_pyobj(load_udpate_req)
|
||||||
|
|
||||||
|
def _trace_request_start(
    self,
    obj: Union[GenerateReqInput, EmbeddingReqInput],
    created_time: Optional[float] = None,
):
    """Start the request trace and its first anonymous slice for ``obj``.

    For a single request (``obj.is_single``) one trace is started for
    ``obj.rid``. Otherwise ``obj.rid`` is iterated and one trace is started
    per entry. Each trace gets the matching ``bootstrap_room`` when ``obj``
    has that attribute (and, for batches, when it is non-empty), else ``None``.

    Args:
        obj: The incoming generate/embedding request (single or batched).
        created_time: Request creation time, converted to nanoseconds via
            ``* 1e9``. When omitted, the current ``time.time()`` is used.
            NOTE(review): this assumes callers pass epoch seconds; confirm.
    """
    if created_time is None:
        # The declared default used to reach `created_time * 1e9` and raise
        # TypeError; fall back to "now" so the default is actually usable.
        import time

        created_time = time.time()
    start_ns = int(created_time * 1e9)

    if obj.is_single:
        trace_req_start(obj.rid, getattr(obj, "bootstrap_room", None), ts=start_ns)
        trace_slice_start("", obj.rid, ts=start_ns, anonymous=True)
        return

    rooms = getattr(obj, "bootstrap_room", None)
    for index, rid in enumerate(obj.rid):
        # A missing or empty bootstrap_room means no room for any entry.
        room = rooms[index] if rooms else None
        trace_req_start(rid, room, ts=start_ns)
        trace_slice_start("", rid, ts=start_ns, anonymous=True)
|
||||||
|
|
||||||
|
|
||||||
class ServerStatus(Enum):
|
class ServerStatus(Enum):
|
||||||
Up = "Up"
|
Up = "Up"
|
||||||
@@ -1933,7 +1866,7 @@ class SignalHandler:
|
|||||||
|
|
||||||
def running_phase_sigquit_handler(self, signum=None, frame=None):
|
def running_phase_sigquit_handler(self, signum=None, frame=None):
|
||||||
logger.error(
|
logger.error(
|
||||||
"Received sigquit from a child process. It usually means the child failed."
|
f"SIGQUIT received. {signum=}, {frame=}. It usually means one child failed."
|
||||||
)
|
)
|
||||||
self.tokenizer_manager.dump_requests_before_crash()
|
self.tokenizer_manager.dump_requests_before_crash()
|
||||||
kill_process_tree(os.getpid())
|
kill_process_tree(os.getpid())
|
||||||
|
|||||||
@@ -14,9 +14,9 @@
|
|||||||
"""Utilities for Prometheus Metrics Collection."""
|
"""Utilities for Prometheus Metrics Collection."""
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
from sglang.srt.metrics.utils import exponential_buckets, generate_buckets
|
from sglang.srt.metrics.utils import exponential_buckets, generate_buckets
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import get_bool_env_var
|
from sglang.srt.utils import get_bool_env_var
|
||||||
@@ -34,6 +34,7 @@ class TimeStats:
|
|||||||
Decode: prealloc_queue -> transfer_queue -> wait_queue -> forward -> completion
|
Decode: prealloc_queue -> transfer_queue -> wait_queue -> forward -> completion
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
disagg_mode: DisaggregationMode = DisaggregationMode.NULL
|
||||||
lb_entry_time: float = 0.0
|
lb_entry_time: float = 0.0
|
||||||
wait_queue_entry_time: float = 0.0
|
wait_queue_entry_time: float = 0.0
|
||||||
forward_entry_time: float = 0.0
|
forward_entry_time: float = 0.0
|
||||||
@@ -43,20 +44,11 @@ class TimeStats:
|
|||||||
decode_prealloc_queue_entry_time: float = 0.0
|
decode_prealloc_queue_entry_time: float = 0.0
|
||||||
decode_transfer_queue_entry_time: float = 0.0
|
decode_transfer_queue_entry_time: float = 0.0
|
||||||
|
|
||||||
class RequestType(Enum):
    """Kinds of request a ``TimeStats`` record can be classified as."""

    # Each member's value is its lowercase name.
    UNIFIED = "unified"
    PREFILL = "prefill"
    DECODE = "decode"
    # Used when the record matches none of the kinds above.
    INVALID = "invalid"
|
|
||||||
|
|
||||||
def get_queueing_time(self) -> float:
    """Return ``forward_entry_time`` minus ``wait_queue_entry_time``."""
    entered_wait_queue = self.wait_queue_entry_time
    entered_forward = self.forward_entry_time
    return entered_forward - entered_wait_queue
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def convert_to_duration(self) -> str:
|
||||||
# if unified
|
if self.disagg_mode == DisaggregationMode.NULL:
|
||||||
_type = self.get_type()
|
|
||||||
|
|
||||||
if _type == self.RequestType.UNIFIED:
|
|
||||||
queue_duration = self.forward_entry_time - self.wait_queue_entry_time
|
queue_duration = self.forward_entry_time - self.wait_queue_entry_time
|
||||||
forward_duration = self.completion_time - self.forward_entry_time
|
forward_duration = self.completion_time - self.forward_entry_time
|
||||||
|
|
||||||
@@ -65,30 +57,28 @@ class TimeStats:
|
|||||||
queue_duration >= 0 and forward_duration >= 0
|
queue_duration >= 0 and forward_duration >= 0
|
||||||
), f"queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
|
), f"queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
|
||||||
|
|
||||||
return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time}"
|
return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time:.3f}"
|
||||||
elif _type == self.RequestType.PREFILL:
|
elif self.disagg_mode == DisaggregationMode.PREFILL:
|
||||||
bootstrap_duration = (
|
bootstrap_duration = (
|
||||||
self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time
|
self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time
|
||||||
)
|
)
|
||||||
|
|
||||||
queue_duration = self.forward_entry_time - self.wait_queue_entry_time
|
queue_duration = self.forward_entry_time - self.wait_queue_entry_time
|
||||||
|
|
||||||
forward_duration = self.completion_time - self.forward_entry_time
|
forward_duration = self.completion_time - self.forward_entry_time
|
||||||
|
|
||||||
if SGLANG_TEST_REQUEST_TIME_STATS:
|
if SGLANG_TEST_REQUEST_TIME_STATS:
|
||||||
|
if self.wait_queue_entry_time > 0:
|
||||||
assert (
|
assert (
|
||||||
bootstrap_duration >= 0
|
bootstrap_duration >= 0
|
||||||
and queue_duration >= 0
|
and queue_duration >= 0
|
||||||
and forward_duration >= 0
|
and forward_duration >= 0
|
||||||
), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
|
), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
|
||||||
return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time}"
|
|
||||||
# if decode
|
return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time:.3f}"
|
||||||
elif _type == self.RequestType.DECODE:
|
elif self.disagg_mode == DisaggregationMode.DECODE:
|
||||||
prealloc_duration = (
|
prealloc_duration = (
|
||||||
self.decode_transfer_queue_entry_time
|
self.decode_transfer_queue_entry_time
|
||||||
- self.decode_prealloc_queue_entry_time
|
- self.decode_prealloc_queue_entry_time
|
||||||
)
|
)
|
||||||
|
|
||||||
transfer_duration = (
|
transfer_duration = (
|
||||||
self.wait_queue_entry_time - self.decode_transfer_queue_entry_time
|
self.wait_queue_entry_time - self.decode_transfer_queue_entry_time
|
||||||
)
|
)
|
||||||
@@ -96,42 +86,30 @@ class TimeStats:
|
|||||||
forward_duration = self.completion_time - self.forward_entry_time
|
forward_duration = self.completion_time - self.forward_entry_time
|
||||||
|
|
||||||
if SGLANG_TEST_REQUEST_TIME_STATS:
|
if SGLANG_TEST_REQUEST_TIME_STATS:
|
||||||
|
if self.wait_queue_entry_time > 0:
|
||||||
assert (
|
assert (
|
||||||
prealloc_duration >= 0
|
prealloc_duration >= 0
|
||||||
and transfer_duration >= 0
|
and transfer_duration >= 0
|
||||||
and queue_duration >= 0
|
and queue_duration >= 0
|
||||||
and forward_duration >= 0
|
and forward_duration >= 0
|
||||||
), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
|
), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0. {self=}"
|
||||||
|
|
||||||
return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time}"
|
return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time:.3f}"
|
||||||
else:
|
else:
|
||||||
return "Invalid Time Stats"
|
return "Unknown Time Stats"
|
||||||
|
|
||||||
def format_duration(self, duration: float) -> str:
    """Render ``duration`` scaled by 1e3 with two decimals and an ``ms`` suffix.

    The scaling implies ``duration`` is given in seconds.
    """
    millis = duration * 1e3
    return f"{millis:.2f}ms"
|
||||||
|
|
||||||
def get_type(self) -> RequestType:
|
def disagg_mode_str(self) -> str:
|
||||||
"""Determine the type of request based on timestamp values."""
|
if self.disagg_mode == DisaggregationMode.NULL:
|
||||||
if (
|
return "unified"
|
||||||
self.prefill_bootstrap_queue_entry_time == 0.0
|
elif self.disagg_mode == DisaggregationMode.DECODE:
|
||||||
and self.prefill_transfer_queue_entry_time == 0.0
|
return "decode"
|
||||||
and self.decode_prealloc_queue_entry_time == 0.0
|
elif self.disagg_mode == DisaggregationMode.PREFILL:
|
||||||
and self.decode_transfer_queue_entry_time == 0.0
|
return "prefill"
|
||||||
):
|
|
||||||
return self.RequestType.UNIFIED
|
|
||||||
elif (
|
|
||||||
self.prefill_bootstrap_queue_entry_time > 0.0
|
|
||||||
and self.prefill_transfer_queue_entry_time > 0.0
|
|
||||||
):
|
|
||||||
return self.RequestType.PREFILL
|
|
||||||
elif (
|
|
||||||
self.decode_prealloc_queue_entry_time > 0.0
|
|
||||||
and self.decode_transfer_queue_entry_time > 0.0
|
|
||||||
and self.wait_queue_entry_time > 0.0
|
|
||||||
):
|
|
||||||
return self.RequestType.DECODE
|
|
||||||
else:
|
else:
|
||||||
return self.RequestType.INVALID
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -145,12 +123,15 @@ class SchedulerStats:
|
|||||||
num_queue_reqs: int = 0
|
num_queue_reqs: int = 0
|
||||||
num_grammar_queue_reqs: int = 0
|
num_grammar_queue_reqs: int = 0
|
||||||
num_running_reqs_offline_batch: int = 0
|
num_running_reqs_offline_batch: int = 0
|
||||||
avg_request_queue_latency: float = 0.0
|
|
||||||
cache_hit_rate: float = 0.0
|
cache_hit_rate: float = 0.0
|
||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
spec_accept_length: float = 0.0
|
spec_accept_length: float = 0.0
|
||||||
|
|
||||||
|
# Retract
|
||||||
|
num_retracted_reqs: int = 0
|
||||||
|
num_paused_reqs: int = 0
|
||||||
|
|
||||||
# PD disaggregation
|
# PD disaggregation
|
||||||
num_prefill_prealloc_queue_reqs: int = 0
|
num_prefill_prealloc_queue_reqs: int = 0
|
||||||
num_prefill_inflight_queue_reqs: int = 0
|
num_prefill_inflight_queue_reqs: int = 0
|
||||||
@@ -159,11 +140,6 @@ class SchedulerStats:
|
|||||||
kv_transfer_speed_gb_s: float = 0.0
|
kv_transfer_speed_gb_s: float = 0.0
|
||||||
kv_transfer_latency_ms: float = 0.0
|
kv_transfer_latency_ms: float = 0.0
|
||||||
|
|
||||||
# Retract
|
|
||||||
total_retracted_reqs: int = 0
|
|
||||||
num_retracted_reqs: int = 0
|
|
||||||
num_paused_reqs: int = 0
|
|
||||||
|
|
||||||
# Utilization
|
# Utilization
|
||||||
utilization: float = 0.0
|
utilization: float = 0.0
|
||||||
max_running_requests_under_SLO: Optional[int] = None
|
max_running_requests_under_SLO: Optional[int] = None
|
||||||
@@ -230,12 +206,6 @@ class SchedulerMetricsCollector:
|
|||||||
labelnames=labels.keys(),
|
labelnames=labels.keys(),
|
||||||
multiprocess_mode="mostrecent",
|
multiprocess_mode="mostrecent",
|
||||||
)
|
)
|
||||||
self.avg_request_queue_latency = Gauge(
|
|
||||||
name="sglang:avg_request_queue_latency",
|
|
||||||
documentation="The average request queue latency for the last batch of requests in seconds.",
|
|
||||||
labelnames=labels.keys(),
|
|
||||||
multiprocess_mode="mostrecent",
|
|
||||||
)
|
|
||||||
self.cache_hit_rate = Gauge(
|
self.cache_hit_rate = Gauge(
|
||||||
name="sglang:cache_hit_rate",
|
name="sglang:cache_hit_rate",
|
||||||
documentation="The prefix cache hit rate.",
|
documentation="The prefix cache hit rate.",
|
||||||
@@ -251,6 +221,18 @@ class SchedulerMetricsCollector:
|
|||||||
multiprocess_mode="mostrecent",
|
multiprocess_mode="mostrecent",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Retract
|
||||||
|
self.num_retracted_reqs = Gauge(
|
||||||
|
name="sglang:num_retracted_reqs",
|
||||||
|
documentation="The number of retracted requests.",
|
||||||
|
labelnames=labels.keys(),
|
||||||
|
)
|
||||||
|
self.num_paused_reqs = Gauge(
|
||||||
|
name="sglang:num_paused_reqs",
|
||||||
|
documentation="The number of paused requests by async weight sync.",
|
||||||
|
labelnames=labels.keys(),
|
||||||
|
)
|
||||||
|
|
||||||
# PD disaggregation
|
# PD disaggregation
|
||||||
self.num_prefill_prealloc_queue_reqs = Gauge(
|
self.num_prefill_prealloc_queue_reqs = Gauge(
|
||||||
name="sglang:num_prefill_prealloc_queue_reqs",
|
name="sglang:num_prefill_prealloc_queue_reqs",
|
||||||
@@ -299,24 +281,6 @@ class SchedulerMetricsCollector:
|
|||||||
multiprocess_mode="mostrecent",
|
multiprocess_mode="mostrecent",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Retract
|
|
||||||
self.total_retracted_reqs = Gauge(
|
|
||||||
name="sglang:total_retracted_reqs",
|
|
||||||
documentation="The total number of retracted requests due to kvcache full.",
|
|
||||||
labelnames=labels.keys(),
|
|
||||||
multiprocess_mode="mostrecent",
|
|
||||||
)
|
|
||||||
self.num_retracted_reqs = Gauge(
|
|
||||||
name="sglang:num_retracted_reqs",
|
|
||||||
documentation="The number of retracted requests.",
|
|
||||||
labelnames=labels.keys(),
|
|
||||||
)
|
|
||||||
self.num_paused_reqs = Gauge(
|
|
||||||
name="sglang:num_paused_reqs",
|
|
||||||
documentation="The number of paused requests by async weight sync.",
|
|
||||||
labelnames=labels.keys(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Utilization
|
# Utilization
|
||||||
self.utilization = Gauge(
|
self.utilization = Gauge(
|
||||||
name="sglang:utilization",
|
name="sglang:utilization",
|
||||||
@@ -347,7 +311,7 @@ class SchedulerMetricsCollector:
|
|||||||
|
|
||||||
# Additional queueing time histogram
|
# Additional queueing time histogram
|
||||||
self.queue_time = Histogram(
|
self.queue_time = Histogram(
|
||||||
name="sglang:queue_time_s",
|
name="sglang:queue_time_seconds",
|
||||||
documentation="Histogram of queueing time in seconds.",
|
documentation="Histogram of queueing time in seconds.",
|
||||||
labelnames=labels.keys(),
|
labelnames=labels.keys(),
|
||||||
buckets=[
|
buckets=[
|
||||||
@@ -513,8 +477,8 @@ class SchedulerMetricsCollector:
|
|||||||
buckets=tree_traversal_time_buckets,
|
buckets=tree_traversal_time_buckets,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.request_latency_seconds = Histogram(
|
self.per_stage_req_latency_seconds = Histogram(
|
||||||
name="sglang:request_latency_seconds",
|
name="sglang:per_stage_req_latency_seconds",
|
||||||
documentation="The latency of each stage of requests.",
|
documentation="The latency of each stage of requests.",
|
||||||
# captures latency in range [1ms - ~1191s]
|
# captures latency in range [1ms - ~1191s]
|
||||||
buckets=exponential_buckets(start=0.001, width=1.62, length=30),
|
buckets=exponential_buckets(start=0.001, width=1.62, length=30),
|
||||||
@@ -525,7 +489,7 @@ class SchedulerMetricsCollector:
|
|||||||
# Convenience function for logging to gauge.
|
# Convenience function for logging to gauge.
|
||||||
gauge.labels(**self.labels).set(data)
|
gauge.labels(**self.labels).set(data)
|
||||||
|
|
||||||
def log_histogram(self, histogram, data: Union[int, float]) -> None:
|
def _log_histogram(self, histogram, data: Union[int, float]) -> None:
|
||||||
histogram.labels(**self.labels).observe(data)
|
histogram.labels(**self.labels).observe(data)
|
||||||
|
|
||||||
def increment_bootstrap_failed_reqs(self) -> None:
|
def increment_bootstrap_failed_reqs(self) -> None:
|
||||||
@@ -534,9 +498,12 @@ class SchedulerMetricsCollector:
|
|||||||
def increment_transfer_failed_reqs(self) -> None:
    """Increment the ``num_transfer_failed_reqs`` metric by one.

    The metric child is selected with this collector's ``self.labels``.
    """
    labelled_counter = self.num_transfer_failed_reqs.labels(**self.labels)
    labelled_counter.inc(1)
|
||||||
|
|
||||||
def observe_request_latency_seconds(self, stage: str, latency: float) -> None:
|
def observe_per_stage_req_latency(self, stage: str, latency: float) -> None:
|
||||||
labels_with_stage = {**self.labels, "stage": stage}
|
labels_with_stage = {**self.labels, "stage": stage}
|
||||||
self.request_latency_seconds.labels(**labels_with_stage).observe(latency)
|
self.per_stage_req_latency_seconds.labels(**labels_with_stage).observe(latency)
|
||||||
|
|
||||||
|
def observe_queue_time(self, latency: float) -> None:
    """Record ``latency`` in the ``queue_time`` histogram.

    Delegates to ``self._log_histogram``.
    """
    queue_histogram = self.queue_time
    self._log_histogram(queue_histogram, latency)
|
||||||
|
|
||||||
def log_stats(self, stats: SchedulerStats) -> None:
|
def log_stats(self, stats: SchedulerStats) -> None:
|
||||||
self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
|
self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
|
||||||
@@ -550,7 +517,6 @@ class SchedulerMetricsCollector:
|
|||||||
self.num_running_reqs_offline_batch, stats.num_running_reqs_offline_batch
|
self.num_running_reqs_offline_batch, stats.num_running_reqs_offline_batch
|
||||||
)
|
)
|
||||||
self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
|
self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
|
||||||
self._log_gauge(self.avg_request_queue_latency, stats.avg_request_queue_latency)
|
|
||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
|
self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
|
||||||
@@ -572,7 +538,6 @@ class SchedulerMetricsCollector:
|
|||||||
self._log_gauge(self.kv_transfer_latency_ms, stats.kv_transfer_latency_ms)
|
self._log_gauge(self.kv_transfer_latency_ms, stats.kv_transfer_latency_ms)
|
||||||
|
|
||||||
# Retract
|
# Retract
|
||||||
self._log_gauge(self.total_retracted_reqs, stats.total_retracted_reqs)
|
|
||||||
self._log_gauge(self.num_retracted_reqs, stats.num_retracted_reqs)
|
self._log_gauge(self.num_retracted_reqs, stats.num_retracted_reqs)
|
||||||
self._log_gauge(self.num_paused_reqs, stats.num_paused_reqs)
|
self._log_gauge(self.num_paused_reqs, stats.num_paused_reqs)
|
||||||
|
|
||||||
@@ -596,19 +561,19 @@ class SchedulerMetricsCollector:
|
|||||||
def log_grammar_stats(self, grammar_stats) -> None:
|
def log_grammar_stats(self, grammar_stats) -> None:
|
||||||
# Duck-typed GrammarStats to avoid cross-package dependency
|
# Duck-typed GrammarStats to avoid cross-package dependency
|
||||||
if getattr(grammar_stats, "compilation_time", None) is not None:
|
if getattr(grammar_stats, "compilation_time", None) is not None:
|
||||||
self.log_histogram(
|
self._log_histogram(
|
||||||
self.grammar_compilation_time, grammar_stats.compilation_time
|
self.grammar_compilation_time, grammar_stats.compilation_time
|
||||||
)
|
)
|
||||||
if getattr(grammar_stats, "schema_count", None) is not None:
|
if getattr(grammar_stats, "schema_count", None) is not None:
|
||||||
self.log_histogram(self.grammar_schema_count, grammar_stats.schema_count)
|
self._log_histogram(self.grammar_schema_count, grammar_stats.schema_count)
|
||||||
if getattr(grammar_stats, "ebnf_size", None) is not None:
|
if getattr(grammar_stats, "ebnf_size", None) is not None:
|
||||||
self.log_histogram(self.grammar_ebnf_size, grammar_stats.ebnf_size)
|
self._log_histogram(self.grammar_ebnf_size, grammar_stats.ebnf_size)
|
||||||
tree_times = getattr(grammar_stats, "tree_traversal_time", None)
|
tree_times = getattr(grammar_stats, "tree_traversal_time", None)
|
||||||
if tree_times:
|
if tree_times:
|
||||||
max_time = max(tree_times)
|
max_time = max(tree_times)
|
||||||
avg_time = sum(tree_times) / len(tree_times)
|
avg_time = sum(tree_times) / len(tree_times)
|
||||||
self.log_histogram(self.grammar_tree_traversal_time_max, max_time)
|
self._log_histogram(self.grammar_tree_traversal_time_max, max_time)
|
||||||
self.log_histogram(self.grammar_tree_traversal_time_avg, avg_time)
|
self._log_histogram(self.grammar_tree_traversal_time_avg, avg_time)
|
||||||
if getattr(grammar_stats, "is_cache_hit", False):
|
if getattr(grammar_stats, "is_cache_hit", False):
|
||||||
self.num_grammar_cache_hit.labels(**self.labels).inc(1)
|
self.num_grammar_cache_hit.labels(**self.labels).inc(1)
|
||||||
if getattr(grammar_stats, "is_grammar_aborted", False):
|
if getattr(grammar_stats, "is_grammar_aborted", False):
|
||||||
@@ -714,7 +679,7 @@ class TokenizerMetricsCollector:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.num_aborted_requests_total = Counter(
|
self.num_aborted_requests_total = Counter(
|
||||||
name="sglang:num_aborted_requests",
|
name="sglang:num_aborted_requests_total",
|
||||||
documentation="Number of requests aborted.",
|
documentation="Number of requests aborted.",
|
||||||
labelnames=labels.keys(),
|
labelnames=labels.keys(),
|
||||||
)
|
)
|
||||||
@@ -801,7 +766,7 @@ class TokenizerMetricsCollector:
|
|||||||
buckets=bucket_time_to_first_token,
|
buckets=bucket_time_to_first_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.histogram_inter_token_latency_seconds = Histogram(
|
self.histogram_inter_token_latency = Histogram(
|
||||||
name="sglang:inter_token_latency_seconds",
|
name="sglang:inter_token_latency_seconds",
|
||||||
documentation="Histogram of inter-token latency in seconds.",
|
documentation="Histogram of inter-token latency in seconds.",
|
||||||
labelnames=labels.keys(),
|
labelnames=labels.keys(),
|
||||||
@@ -815,14 +780,6 @@ class TokenizerMetricsCollector:
|
|||||||
buckets=bucket_e2e_request_latency,
|
buckets=bucket_e2e_request_latency,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Offline batch specific TTFB histogram
|
|
||||||
self.histogram_time_to_first_token_offline_batch = Histogram(
|
|
||||||
name="sglang:time_to_first_token_seconds_offline_batch",
|
|
||||||
documentation="Histogram of time to first token in seconds for offline batch requests.",
|
|
||||||
labelnames=labels.keys(),
|
|
||||||
buckets=bucket_time_to_first_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
def observe_one_finished_request(
|
def observe_one_finished_request(
|
||||||
self,
|
self,
|
||||||
labels: Dict[str, str],
|
labels: Dict[str, str],
|
||||||
@@ -846,14 +803,7 @@ class TokenizerMetricsCollector:
|
|||||||
float(generation_tokens)
|
float(generation_tokens)
|
||||||
)
|
)
|
||||||
|
|
||||||
def observe_time_to_first_token(
|
def observe_time_to_first_token(self, labels: Dict[str, str], value: float):
|
||||||
self, labels: Dict[str, str], value: float, type: str = ""
|
|
||||||
):
|
|
||||||
if type == "batch":
|
|
||||||
self.histogram_time_to_first_token_offline_batch.labels(**labels).observe(
|
|
||||||
value
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.histogram_time_to_first_token.labels(**labels).observe(value)
|
self.histogram_time_to_first_token.labels(**labels).observe(value)
|
||||||
|
|
||||||
def check_time_to_first_token_straggler(self, value: float) -> bool:
|
def check_time_to_first_token_straggler(self, value: float) -> bool:
|
||||||
@@ -876,7 +826,7 @@ class TokenizerMetricsCollector:
|
|||||||
|
|
||||||
# A faster version of the Histogram::observe which observes multiple values at the same time.
|
# A faster version of the Histogram::observe which observes multiple values at the same time.
|
||||||
# reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639
|
# reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639
|
||||||
his = self.histogram_inter_token_latency_seconds.labels(**labels)
|
his = self.histogram_inter_token_latency.labels(**labels)
|
||||||
his._sum.inc(internval)
|
his._sum.inc(internval)
|
||||||
|
|
||||||
for i, bound in enumerate(his._upper_bounds):
|
for i, bound in enumerate(his._upper_bounds):
|
||||||
@@ -884,8 +834,8 @@ class TokenizerMetricsCollector:
|
|||||||
his._buckets[i].inc(num_new_tokens)
|
his._buckets[i].inc(num_new_tokens)
|
||||||
break
|
break
|
||||||
|
|
||||||
def observe_one_aborted_request(self):
|
def observe_one_aborted_request(self, labels: Dict[str, str]):
|
||||||
self.num_aborted_requests_total.labels(**self.labels).inc(1)
|
self.num_aborted_requests_total.labels(**labels).inc(1)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import ctypes
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
@@ -23,7 +22,10 @@ import threading
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.managers.scheduler import Req
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
opentelemetry_imported = False
|
opentelemetry_imported = False
|
||||||
@@ -407,9 +409,11 @@ def trace_slice_start(
|
|||||||
ts: Optional[int] = None,
|
ts: Optional[int] = None,
|
||||||
anonymous: bool = False,
|
anonymous: bool = False,
|
||||||
):
|
):
|
||||||
|
if not tracing_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
rid = str(rid)
|
rid = str(rid)
|
||||||
if not tracing_enabled or rid not in reqs_context:
|
if rid not in reqs_context:
|
||||||
return
|
return
|
||||||
|
|
||||||
pid = threading.get_native_id()
|
pid = threading.get_native_id()
|
||||||
@@ -458,8 +462,11 @@ def trace_slice_end(
|
|||||||
auto_next_anon: bool = False,
|
auto_next_anon: bool = False,
|
||||||
thread_finish_flag: bool = False,
|
thread_finish_flag: bool = False,
|
||||||
):
|
):
|
||||||
|
if not tracing_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
rid = str(rid)
|
rid = str(rid)
|
||||||
if not tracing_enabled or rid not in reqs_context:
|
if rid not in reqs_context:
|
||||||
return
|
return
|
||||||
|
|
||||||
pid = threading.get_native_id()
|
pid = threading.get_native_id()
|
||||||
@@ -512,10 +519,13 @@ trace_slice = trace_slice_end
|
|||||||
|
|
||||||
# Add event to the current slice on the same thread with the same rid.
|
# Add event to the current slice on the same thread with the same rid.
|
||||||
def trace_event(name: str, rid: str, ts: Optional[int] = None):
|
def trace_event(name: str, rid: str, ts: Optional[int] = None):
|
||||||
if not tracing_enabled or rid not in reqs_context:
|
if not tracing_enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
rid = str(rid)
|
rid = str(rid)
|
||||||
|
if rid not in reqs_context:
|
||||||
|
return
|
||||||
|
|
||||||
pid = threading.get_native_id()
|
pid = threading.get_native_id()
|
||||||
if pid not in reqs_context[rid].threads_context:
|
if pid not in reqs_context[rid].threads_context:
|
||||||
return
|
return
|
||||||
@@ -534,10 +544,13 @@ def trace_event(name: str, rid: str, ts: Optional[int] = None):
|
|||||||
|
|
||||||
# Add attrs to the current slice on the same thread with the same rid.
|
# Add attrs to the current slice on the same thread with the same rid.
|
||||||
def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]):
|
def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]):
|
||||||
if not tracing_enabled or rid not in reqs_context:
|
if not tracing_enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
rid = str(rid)
|
rid = str(rid)
|
||||||
|
if rid not in reqs_context:
|
||||||
|
return
|
||||||
|
|
||||||
pid = threading.get_native_id()
|
pid = threading.get_native_id()
|
||||||
if pid not in reqs_context[rid].threads_context:
|
if pid not in reqs_context[rid].threads_context:
|
||||||
return
|
return
|
||||||
@@ -550,3 +563,16 @@ def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]):
|
|||||||
|
|
||||||
slice_info = thread_context.cur_slice_stack[-1]
|
slice_info = thread_context.cur_slice_stack[-1]
|
||||||
slice_info.span.set_attributes(attrs)
|
slice_info.span.set_attributes(attrs)
|
||||||
|
|
||||||
|
|
||||||
|
def trace_slice_batch(
    name: str,
    reqs: List[Req],
):
    """Call ``trace_slice`` with ``name`` for every request in ``reqs``.

    For each request, ``auto_next_anon`` is set when the request is not
    finished and ``thread_finish_flag`` is set when it is finished.

    Args:
        name: Slice name passed to ``trace_slice``.
        reqs: Requests to process; each must expose ``rid`` and ``finished()``.
    """
    for req in reqs:
        # Evaluate finished() once so the two flags are exact complements.
        is_finished = req.finished()
        trace_slice(
            name,
            req.rid,
            auto_next_anon=not is_finished,
            thread_finish_flag=is_finished,
        )
|
||||||
|
|||||||
Reference in New Issue
Block a user