Files
sglang/python/sglang/srt/managers/scheduler_metrics_mixin.py
2025-10-01 13:03:07 -07:00

415 lines
17 KiB
Python

from __future__ import annotations
import logging
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import torch
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
from sglang.srt.managers.schedule_policy import PrefillAdder
from sglang.srt.managers.scheduler import Req, ScheduleBatch
from sglang.srt.managers.utils import DPBalanceMeta
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.utils import get_bool_env_var
if TYPE_CHECKING:
from sglang.srt.managers.scheduler import Scheduler
# Module-level logger; the prefill/decode summaries below are emitted through it.
logger = logging.getLogger(__name__)
# Boolean flag read from the SGLANG_RECORD_STEP_TIME environment variable.
# When true, log_decode_stats appends gap_latency / decode_log_interval to
# step_time_dict, keyed by the number of running requests.
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
class KvMetrics:
    """Plain attribute bag describing request-slot and KV-cache occupancy.

    Every field starts as ``None``; the scheduler fills the fields in
    before sending the object out.
    """

    def __init__(self):
        # Request slots: in use / capacity.
        self.request_active_slots = self.request_total_slots = None
        # KV blocks: in use / capacity.
        self.kv_active_blocks = self.kv_total_blocks = None
        # Number of requests still waiting.
        self.num_requests_waiting = None
        # Cache usage fraction and prefix-cache hit rate.
        self.gpu_cache_usage_perc = self.gpu_prefix_cache_hit_rate = None
        # Data-parallel rank this snapshot was taken on.
        self.data_parallel_rank = None
class SchedulerMetricsMixin:
def init_metrics(
    self: Scheduler, tp_rank: int, pp_rank: int, dp_rank: Optional[int]
):
    """Reset the metric counters and, if metrics are enabled, build the collector.

    Args:
        tp_rank: Value stored under the ``tp_rank`` collector label.
        pp_rank: Value stored under the ``pp_rank`` collector label.
        dp_rank: Stored under the ``dp_rank`` label only when not ``None``.
    """
    # Throughput figures, refreshed by the periodic prefill/decode log calls.
    self.last_gen_throughput: float = 0.0
    self.last_input_throughput: float = 0.0
    self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]

    # Speculative-decoding counters: per-interval totals and running sums.
    self.spec_num_total_accepted_tokens = 0
    self.spec_num_total_forward_ct = 0
    self.cum_spec_accept_length = 0
    self.cum_spec_accept_count = 0

    # KV transfer figures copied into the stats by log_prefill_stats.
    self.kv_transfer_speed_gb_s: float = 0.0
    self.kv_transfer_latency_ms: float = 0.0

    self.stats = SchedulerStats()

    if not self.enable_metrics:
        return

    collector_labels = dict(
        model_name=self.server_args.served_model_name,
        engine_type="unified",
        tp_rank=tp_rank,
        pp_rank=pp_rank,
    )
    if dp_rank is not None:
        collector_labels["dp_rank"] = dp_rank
    self.metrics_collector = SchedulerMetricsCollector(labels=collector_labels)
def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]):
    """Store the DP balance metadata and prepare the per-term id buffer.

    The id buffer is created only when DP attention is enabled and the load
    balance method is "minimum_tokens"; in that case the metadata must be
    provided (enforced with an assertion).
    """
    self.balance_meta = dp_balance_meta

    args = self.server_args
    balancing_on_tokens = (
        args.enable_dp_attention and args.load_balance_method == "minimum_tokens"
    )
    if not balancing_on_tokens:
        return

    assert dp_balance_meta is not None
    self.recv_dp_balance_id_this_term = []
def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
    """Create the KV event publisher when KV cache events are enabled.

    Does nothing (and leaves ``kv_event_publisher`` unset) otherwise.
    """
    if not self.enable_kv_cache_events:
        return
    publisher = EventPublisherFactory.create(kv_events_config, self.attn_dp_rank)
    self.kv_event_publisher = publisher
def log_prefill_stats(
    self: Scheduler,
    adder: PrefillAdder,
    can_run_list: List[Req],
    running_bs: int,
    running_bs_offline_batch: int,
):
    """Log a prefill-batch summary and, when metrics are on, refresh ``self.stats``.

    Args:
        adder: Source of the new-token (``log_input_tokens``) and cached-token
            (``log_hit_tokens``) counts.
        can_run_list: Requests in this prefill batch; only its length is used.
        running_bs: Reported as the number of running requests.
        running_bs_offline_batch: Reported as the offline-batch running count.
    """
    elapsed = time.perf_counter() - self.last_prefill_stats_tic
    self.last_prefill_stats_tic = time.perf_counter()
    # The throughput uses the token count remembered from the previous call;
    # this call's count is stored afterwards for the next one.
    self.last_input_throughput = self.last_prefill_tokens / elapsed
    self.last_prefill_tokens = adder.log_input_tokens

    # TODO: generalize this for various memory pools
    if self.is_hybrid:
        full_used, swa_used, full_usage, swa_usage, _, _, _, _ = (
            self._get_swa_token_info()
        )
        # Report whichever of the two pools is fuller.
        num_used = max(full_used, swa_used)
        token_usage = max(full_usage, swa_usage)
        usage_text = (
            f"full token usage: {full_usage:.2f}, "
            f"swa token usage: {swa_usage:.2f}, "
        )
    else:
        num_used, token_usage, _, _ = self._get_token_info()
        usage_text = f"token usage: {token_usage:.2f}, "

    pieces = [
        "Prefill batch. ",
        f"#new-seq: {len(can_run_list)}, ",
        f"#new-token: {adder.log_input_tokens}, ",
        f"#cached-token: {adder.log_hit_tokens}, ",
        f"{usage_text}",
        f"#running-req: {running_bs}, ",
        f"#queue-req: {len(self.waiting_queue)}, ",
    ]
    if self.disaggregation_mode == DisaggregationMode.PREFILL:
        pieces.append(
            f"#prealloc-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
        )
        pieces.append(f"#inflight-req: {len(self.disagg_prefill_inflight_queue)}, ")
    logger.info("".join(pieces))

    if self.enable_metrics:
        stats = self.stats

        # Basics
        seen_tokens = adder.log_input_tokens + adder.log_hit_tokens
        if seen_tokens > 0:
            hit_rate = adder.log_hit_tokens / seen_tokens
        else:
            hit_rate = 0.0
        stats.num_running_reqs = running_bs
        stats.num_running_reqs_offline_batch = running_bs_offline_batch
        stats.num_used_tokens = num_used
        stats.token_usage = token_usage
        if self.is_hybrid:
            stats.swa_token_usage = swa_usage
        stats.num_queue_reqs = len(self.waiting_queue)
        stats.num_grammar_queue_reqs = len(self.grammar_queue)
        stats.cache_hit_rate = hit_rate

        # Retract: publish the counters, then reset them.
        stats.num_retracted_reqs = self.num_retracted_reqs
        stats.num_paused_reqs = self.num_paused_reqs
        self.num_retracted_reqs = 0
        self.num_paused_reqs = 0

        # PD disaggregation
        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            stats.num_prefill_prealloc_queue_reqs = len(
                self.disagg_prefill_bootstrap_queue.queue
            )
            stats.num_prefill_inflight_queue_reqs = len(
                self.disagg_prefill_inflight_queue
            )
            stats.kv_transfer_speed_gb_s = self.kv_transfer_speed_gb_s
            stats.kv_transfer_latency_ms = self.kv_transfer_latency_ms
        elif mode == DisaggregationMode.DECODE:
            stats.num_decode_prealloc_queue_reqs = len(
                self.disagg_decode_prealloc_queue.queue
            )
            stats.num_decode_transfer_queue_reqs = len(
                self.disagg_decode_transfer_queue.queue
            )

        # Others
        self.calculate_utilization()
        self.metrics_collector.log_stats(stats)
        self._emit_kv_metrics()

    self._publish_kv_events()
def log_decode_stats(
    self: Scheduler,
    can_run_cuda_graph: bool,
    running_batch: Optional[ScheduleBatch] = None,
):
    """Log a decode summary and, when metrics are on, refresh ``self.stats``.

    Args:
        can_run_cuda_graph: Value printed after the "cuda graph" / "cpu graph"
            label in the log line.
        running_batch: Batch whose request count is reported. Falls back to
            ``self.running_batch`` when omitted (or falsy).
    """
    batch = running_batch or self.running_batch

    gap_latency = time.perf_counter() - self.last_decode_stats_tic
    self.last_decode_stats_tic = time.perf_counter()
    # Tokens generated since the previous call, divided by the elapsed time.
    self.last_gen_throughput = self.num_generated_tokens / gap_latency
    self.num_generated_tokens = 0
    num_running_reqs = len(batch.reqs)
    num_running_reqs_offline_batch = 0

    # TODO: generalize this for various memory pools
    if self.is_hybrid:
        (
            full_num_used,
            swa_num_used,
            full_token_usage,
            swa_token_usage,
            _,
            _,
            _,
            _,
        ) = self._get_swa_token_info()
        # Report whichever of the two pools is fuller.
        num_used = max(full_num_used, swa_num_used)
        token_usage = max(full_token_usage, swa_token_usage)
        token_usage_msg = (
            f"#full token: {full_num_used}, "
            f"full token usage: {full_token_usage:.2f}, "
            f"#swa token: {swa_num_used}, "
            f"swa token usage: {swa_token_usage:.2f}, "
        )
    else:
        num_used, token_usage, _, _ = self._get_token_info()
        token_usage_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, "

    if RECORD_STEP_TIME:
        # Average step time over the logging interval, keyed by batch size.
        self.step_time_dict[num_running_reqs].append(
            gap_latency / self.server_args.decode_log_interval
        )

    msg = f"Decode batch. #running-req: {num_running_reqs}, {token_usage_msg}"

    if self.spec_algorithm.is_none():
        spec_accept_length = 0
    else:
        forward_ct = self.spec_num_total_forward_ct
        # No forward pass counted since the last reset: report 0 instead of
        # raising ZeroDivisionError.
        spec_accept_length = (
            self.spec_num_total_accepted_tokens / forward_ct if forward_ct else 0
        )
        # Fold the interval counters into the running sums, then reset them.
        self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
        self.cum_spec_accept_count += forward_ct
        self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
        msg += f"accept len: {spec_accept_length:.2f}, "
    cache_hit_rate = 0.0

    if self.disaggregation_mode == DisaggregationMode.DECODE:
        prealloc_queue = self.disagg_decode_prealloc_queue
        prealloc_usage = (
            prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens
        )
        msg += f"pre-allocated usage: {prealloc_usage:.2f}, "
        msg += f"#prealloc-req: {len(prealloc_queue.queue)}, "
        msg += f"#transfer-req: {len(self.disagg_decode_transfer_queue.queue)}, "
        msg += f"#retracted-req: {len(prealloc_queue.retracted_queue)}, "

    msg += (
        f"{'cuda graph' if self.device == 'cuda' else 'cpu graph'}: {can_run_cuda_graph}, "
        f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
        f"#queue-req: {len(self.waiting_queue)}, "
    )

    logger.info(msg)
    if self.enable_metrics:
        # Basics
        self.stats.num_running_reqs = num_running_reqs
        self.stats.num_running_reqs_offline_batch = num_running_reqs_offline_batch
        self.stats.num_used_tokens = num_used
        self.stats.token_usage = token_usage
        if self.is_hybrid:
            self.stats.swa_token_usage = swa_token_usage
        self.stats.gen_throughput = self.last_gen_throughput
        self.stats.num_queue_reqs = len(self.waiting_queue)
        self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
        self.stats.cache_hit_rate = cache_hit_rate
        self.stats.spec_accept_length = spec_accept_length

        # Retract: publish the counters, then reset them.
        self.stats.num_retracted_reqs = self.num_retracted_reqs
        self.stats.num_paused_reqs = self.num_paused_reqs
        self.num_retracted_reqs = self.num_paused_reqs = 0

        # PD disaggregation
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self.stats.num_prefill_prealloc_queue_reqs = len(
                self.disagg_prefill_bootstrap_queue.queue
            )
            self.stats.num_prefill_inflight_queue_reqs = len(
                self.disagg_prefill_inflight_queue
            )
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            self.stats.num_decode_prealloc_queue_reqs = len(
                self.disagg_decode_prealloc_queue.queue
            )
            self.stats.num_decode_transfer_queue_reqs = len(
                self.disagg_decode_transfer_queue.queue
            )

        # Others
        self.calculate_utilization()
        self.metrics_collector.log_stats(self.stats)
        self._emit_kv_metrics()
    self._publish_kv_events()
def _emit_kv_metrics(self: Scheduler):
    """Send a KvMetrics snapshot built from ``self.stats`` over the metrics socket.

    No-op when KV cache events are disabled; the snapshot is built but not
    sent when the socket reports itself as closed.
    """
    if not self.enable_kv_cache_events:
        return

    stats = self.stats
    snapshot = KvMetrics()
    snapshot.request_active_slots = stats.num_running_reqs
    snapshot.request_total_slots = self.max_running_requests
    # Active blocks are derived from the usage fraction and the token capacity.
    snapshot.kv_active_blocks = int(stats.token_usage * self.max_total_num_tokens)
    snapshot.kv_total_blocks = self.max_total_num_tokens
    snapshot.num_requests_waiting = stats.num_queue_reqs
    snapshot.gpu_cache_usage_perc = stats.token_usage
    snapshot.gpu_prefix_cache_hit_rate = stats.cache_hit_rate
    rank = self.dp_rank
    snapshot.data_parallel_rank = 0 if rank is None else rank

    channel = self.send_metrics_from_scheduler
    if channel.closed:
        return
    channel.send_pyobj(snapshot)
def _publish_kv_events(self: Scheduler):
    """Take pending events from the tree cache and publish them as one batch.

    No-op when KV cache events are disabled or when no events are pending.
    """
    if not self.enable_kv_cache_events:
        return
    pending = self.tree_cache.take_events()
    if not pending:
        return
    event_batch = KVEventBatch(ts=time.time(), events=pending)
    self.kv_event_publisher.publish(event_batch)
def maybe_update_dp_balance_data(
    self: Scheduler, recv_req: TokenizedGenerateReqInput
):
    """Remember the request's ``dp_balance_id`` for the current term.

    Only active when DP attention is enabled and the load balance method is
    "minimum_tokens"; otherwise nothing is recorded.
    """
    args = self.server_args
    if not args.enable_dp_attention:
        return
    if args.load_balance_method != "minimum_tokens":
        return
    self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
def maybe_handle_dp_balance_data(self: Scheduler):
    """Every 40th forward step, gather DP balance data and publish it from rank 0.

    Runs only when DP attention is enabled and the load balance method is
    "minimum_tokens" — the same condition under which ``init_dp_balance``
    creates ``recv_dp_balance_id_this_term``. Without the DP-attention check
    this method would touch state that was never initialised.
    """
    args = self.server_args
    if not (
        args.enable_dp_attention and args.load_balance_method == "minimum_tokens"
    ):
        return
    if self.forward_ct % 40 != 0:
        return

    holding_tokens = self.get_load().num_tokens

    new_recv_dp_balance_id_list, holding_token_list = self.gather_dp_balance_info(
        holding_tokens
    )

    # The ids of this term have been gathered; start a fresh term.
    self.recv_dp_balance_id_this_term.clear()
    if self.tp_rank == 0:  # only first worker write info
        self.write_shared_dp_balance_info(
            new_recv_dp_balance_id_list, holding_token_list
        )
def gather_dp_balance_info(
    self: Scheduler, holding_tokens_list
) -> tuple[Optional[List[List[int]]], Union[int, List[int]]]:
    """Gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance.

    Each worker packs its data into a fixed-size int32 tensor laid out as
    ``| holding_tokens | len(ids) | ids... |`` and the tensors are gathered
    over ``self.tp_cpu_group``.

    Args:
        holding_tokens_list: Despite the name, the single holding-token value
            of this worker; it is written into slot 0 of the tensor.

    Returns:
        On ``tp_rank == 0``: ``(ids_per_worker, holding_tokens_per_worker)``.
        On every other rank: ``(None, holding_tokens_list)``, i.e. the input
        value is passed back unchanged.

    Raises:
        AssertionError: If more ids were received this term than fit in the
            gather tensor.
    """
    # The maximum size of the tensor used for gathering data from all workers.
    gather_tensor_size = 512
    # Slots 0 and 1 are headers, so only the remaining slots can carry ids.
    header_size = 2
    max_recv_ids = gather_tensor_size - header_size

    recv_list = self.recv_dp_balance_id_this_term
    assert len(recv_list) <= max_recv_ids, (
        "The number of requests received this round is too large. "
        "Please increase gather_tensor_size and onfly_info_size."
    )

    # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
    recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
    recv_tensor[0] = holding_tokens_list
    recv_tensor[1] = len(recv_list)  # The second slot is the length of the id list.
    recv_tensor[header_size : header_size + len(recv_list)] = torch.tensor(
        recv_list, dtype=torch.int32
    )

    # Only rank 0 supplies receive buffers; the other ranks pass None.
    if self.tp_rank == 0:
        gathered_list = [
            torch.zeros(gather_tensor_size, dtype=torch.int32)
            for _ in range(self.balance_meta.num_workers)
        ]
    else:
        gathered_list = None

    torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)

    gathered_id_list_per_worker = None
    if self.tp_rank == 0:
        gathered_id_list_per_worker = []
        holding_tokens_list = []
        for tensor in gathered_list:
            holding_tokens_list.append(tensor[0].item())
            list_length = tensor[1].item()
            gathered_id_list_per_worker.append(
                tensor[header_size : header_size + list_length].tolist()
            )

    return gathered_id_list_per_worker, holding_tokens_list
def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens):
    """Drop the ids received this round from the shared on-fly records and
    store the result together with ``local_tokens``.

    Everything happens while holding ``self.balance_meta.mutex``.

    Args:
        new_recv_rid_lists: One list of received ids per worker.
        local_tokens: Value handed to ``set_shared_local_tokens``.

    Raises:
        AssertionError: If the number of workers differs from the shared
            on-fly list, or a received id is not present in its worker's
            on-fly record.
    """
    meta = self.balance_meta

    with meta.mutex:
        shared_onfly: List[Dict[int, int]] = meta.get_shared_onfly()
        assert len(new_recv_rid_lists) == len(shared_onfly), "num_worker not equal"

        # 1. Each id received by a worker this round must be present in that
        #    worker's on-fly record; remove it from there.
        per_worker = zip(new_recv_rid_lists, shared_onfly)
        for worker_id, (new_recv_rids, on_fly_reqs) in enumerate(per_worker):
            for new_recv_rid in new_recv_rids:
                assert (
                    new_recv_rid in on_fly_reqs
                ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
                del on_fly_reqs[new_recv_rid]

        # 2. Atomically write local_tokens and onfly into shm under the mutex
        meta.set_shared_onfly_info(shared_onfly)
        meta.set_shared_local_tokens(local_tokens)
def calculate_utilization(self):
    """Update ``self.stats.utilization``.

    In PREFILL disaggregation mode the value is set to -1. Otherwise, if
    ``max_running_requests_under_SLO`` is set and positive, utilization is
    the larger of the running-request ratio and ``token_usage / 0.9``;
    if it is unset or non-positive the previous value is left untouched.
    """
    if self.disaggregation_mode == DisaggregationMode.PREFILL:
        self.stats.utilization = -1
        return

    stats = self.stats
    slo_limit = stats.max_running_requests_under_SLO
    if slo_limit is None or not slo_limit > 0:
        return

    request_ratio = stats.num_running_reqs / slo_limit
    token_ratio = stats.token_usage / 0.9
    stats.utilization = max(request_ratio, token_ratio)