Fix metrics and request tracing (TimeStats) (#11123)

This commit is contained in:
Lianmin Zheng
2025-10-01 13:03:07 -07:00
committed by GitHub
parent a28b394fba
commit 2d62af6be5
13 changed files with 461 additions and 392 deletions

View File

@@ -157,10 +157,9 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.tracing.trace import (
process_tracing_init,
trace_event,
trace_set_proc_propagate_context,
trace_set_thread_info,
trace_slice,
trace_slice_batch,
trace_slice_end,
trace_slice_start,
)
@@ -263,6 +262,7 @@ class Scheduler(
server_args.enable_metrics_for_all_schedulers
)
self.enable_kv_cache_events = server_args.kv_events_config and tp_rank == 0
self.enable_trace = server_args.enable_trace
self.stream_interval = server_args.stream_interval
self.spec_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
@@ -899,10 +899,6 @@ class Scheduler(
batch = self.get_next_batch_to_run()
self.cur_batch = batch
if batch:
for req in batch.reqs:
trace_event("schedule", req.rid)
if batch:
result = self.run_batch(batch)
self.process_batch_result(batch, result)
@@ -924,10 +920,6 @@ class Scheduler(
batch = self.get_next_batch_to_run()
self.cur_batch = batch
if batch:
for req in batch.reqs:
trace_event("schedule", req.rid)
if batch:
batch.launch_done = threading.Event()
result = self.run_batch(batch)
@@ -1192,10 +1184,13 @@ class Scheduler(
src=self.tp_group.ranks[0],
)
for req in recv_reqs:
if isinstance(req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)):
trace_set_proc_propagate_context(req.rid, req.trace_context)
trace_slice_start("", req.rid, anonymous=True)
if self.enable_trace:
for req in recv_reqs:
if isinstance(
req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
):
trace_set_proc_propagate_context(req.rid, req.trace_context)
trace_slice_start("", req.rid, anonymous=True)
return recv_reqs
@@ -1277,6 +1272,7 @@ class Scheduler(
bootstrap_host=recv_req.bootstrap_host,
bootstrap_port=recv_req.bootstrap_port,
bootstrap_room=recv_req.bootstrap_room,
disagg_mode=self.disaggregation_mode,
data_parallel_rank=recv_req.data_parallel_rank,
vocab_size=self.model_config.vocab_size,
priority=recv_req.priority,
@@ -1403,7 +1399,6 @@ class Scheduler(
req.set_finish_with_abort(error_msg)
if add_to_grammar_queue:
req.queue_time_start = time.perf_counter()
self.grammar_queue.append(req)
else:
self._add_request_to_queue(req)
@@ -1419,23 +1414,6 @@ class Scheduler(
for tokenized_req in recv_req:
self.handle_generate_request(tokenized_req)
def _add_request_to_queue(self, req: Req):
req.queue_time_start = time.perf_counter()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self._prefetch_kvcache(req)
self.disagg_prefill_bootstrap_queue.add(
req, self.model_config.num_key_value_heads
)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.disagg_decode_prealloc_queue.add(req)
else:
self._set_or_validate_priority(req)
if self._abort_on_queued_limit(req):
return
self._prefetch_kvcache(req)
self.waiting_queue.append(req)
trace_slice_end("process req", req.rid, auto_next_anon=True)
def _prefetch_kvcache(self, req: Req):
if self.enable_hicache_storage:
req.init_next_round_input(self.tree_cache)
@@ -1449,19 +1427,27 @@ class Scheduler(
req.rid, req.last_host_node, new_input_tokens, last_hash
)
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.disagg_prefill_bootstrap_queue.extend(
reqs, self.model_config.num_key_value_heads
def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
if self.disaggregation_mode == DisaggregationMode.NULL:
self._set_or_validate_priority(req)
if self._abort_on_queued_limit(req):
return
self._prefetch_kvcache(req)
self.waiting_queue.append(req)
req.time_stats.wait_queue_entry_time = time.perf_counter()
trace_slice_end("process req", req.rid, auto_next_anon=True)
elif self.disaggregation_mode == DisaggregationMode.PREFILL:
self._prefetch_kvcache(req)
self.disagg_prefill_bootstrap_queue.add(
req, self.model_config.num_key_value_heads
)
req.time_stats.prefill_bootstrap_queue_entry_time = time.perf_counter()
elif self.disaggregation_mode == DisaggregationMode.DECODE:
# If this is a decode server, we put the request into the decode pending prealloc queue
self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
if not is_retracted:
req.time_stats.decode_prealloc_queue_entry_time = time.perf_counter()
else:
for req in reqs:
self._set_or_validate_priority(req)
if not self._abort_on_queued_limit(req):
self.waiting_queue.append(req)
raise ValueError(f"Invalid {self.disaggregation_mode=}")
def _set_or_validate_priority(self, req: Req):
"""Set the default priority value, or abort the request based on the priority scheduling mode."""
@@ -1500,7 +1486,7 @@ class Scheduler(
direction = 1 if self.schedule_low_priority_values_first else -1
key_fn = lambda item: (
direction * item[1].priority,
item[1].queue_time_start,
item[1].time_stats.wait_queue_entry_time,
)
idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
abort_existing_req = (
@@ -1902,14 +1888,14 @@ class Scheduler(
if self.enable_metrics:
# only record queue time when enable_metrics is True to avoid overhead
for req in can_run_list:
req.queue_time_end = time.perf_counter()
req.add_latency(RequestStage.PREFILL_WAITING)
self.waiting_queue = [
x for x in self.waiting_queue if x not in set(can_run_list)
]
if adder.preempt_list:
self._extend_requests_to_queue(adder.preempt_list)
for req in adder.preempt_list:
self._add_request_to_queue(req)
if adder.new_chunked_req is not None:
assert self.chunked_req is None
@@ -1920,7 +1906,16 @@ class Scheduler(
# Print stats
if self.current_scheduler_metrics_enabled():
self.log_prefill_stats(adder, can_run_list, running_bs)
self.log_prefill_stats(adder, can_run_list, running_bs, 0)
for req in can_run_list:
if req.time_stats.forward_entry_time == 0:
# Avoid update chunked request many times
req.time_stats.forward_entry_time = time.perf_counter()
if self.enable_metrics:
self.metrics_collector.observe_queue_time(
req.time_stats.get_queueing_time(),
)
# Create a new batch
new_batch = ScheduleBatch.init_new(
@@ -1975,19 +1970,25 @@ class Scheduler(
TEST_RETRACT and batch.batch_size() > 10
):
old_ratio = self.new_token_ratio
retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
num_retracted_reqs = len(retracted_reqs)
retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
self.server_args
)
self.num_retracted_reqs = len(retracted_reqs)
self.new_token_ratio = new_token_ratio
for req in reqs_to_abort:
self.send_to_tokenizer.send_pyobj(
AbortReq(req.rid, abort_reason=req.to_abort_message)
)
logger.info(
"KV cache pool is full. Retract requests. "
f"#retracted_reqs: {num_retracted_reqs}, "
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
f"#retracted_reqs: {len(retracted_reqs)}, "
f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
)
self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
self.total_retracted_reqs += num_retracted_reqs
for req in retracted_reqs:
self._add_request_to_queue(req, is_retracted=True)
else:
self.new_token_ratio = max(
self.new_token_ratio - self.new_token_ratio_decay,
@@ -2086,23 +2087,14 @@ class Scheduler(
):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result, launch_done)
for req in batch.reqs:
trace_slice(
"decode loop",
req.rid,
auto_next_anon=not req.finished(),
thread_finish_flag=req.finished(),
)
if self.enable_trace:
trace_slice_batch("decode loop", batch.reqs)
elif batch.forward_mode.is_extend():
self.process_batch_result_prefill(batch, result, launch_done)
for req in batch.reqs:
trace_slice(
"prefill",
req.rid,
auto_next_anon=not req.finished(),
thread_finish_flag=req.finished(),
)
if self.enable_trace:
trace_slice_batch("prefill", batch.reqs)
elif batch.forward_mode.is_idle():
if self.enable_overlap:
self.tp_worker.resolve_last_batch_result(launch_done)
@@ -2261,12 +2253,13 @@ class Scheduler(
if req.finished(): # It is aborted by AbortReq
num_ready_reqs += 1
continue
req.grammar = req.grammar.result(timeout=0.03)
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
if req.grammar is INVALID_GRAMMAR_OBJ:
req.set_finish_with_abort(
f"Invalid grammar request: {req.grammar_key=}"
)
error_msg = f"Invalid grammar request: {req.grammar_key=}"
req.set_finish_with_abort(error_msg)
num_ready_reqs += 1
except futures._base.TimeoutError:
req.grammar_wait_ct += 1
@@ -2298,9 +2291,8 @@ class Scheduler(
req.grammar = req.grammar.result()
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
if req.grammar is INVALID_GRAMMAR_OBJ:
req.set_finish_with_abort(
f"Invalid grammar request: {req.grammar_key=}"
)
error_msg = f"Invalid grammar request: {req.grammar_key=}"
req.set_finish_with_abort(error_msg)
else:
num_ready_reqs_max = num_ready_reqs
num_timeout_reqs_max = num_timeout_reqs
@@ -2308,12 +2300,14 @@ class Scheduler(
for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
req = self.grammar_queue[i]
req.grammar.cancel()
self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
req.set_finish_with_abort(error_msg)
self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max
self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
for req in self.grammar_queue[:num_ready_reqs]:
self._add_request_to_queue(req)
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
@@ -2795,17 +2789,11 @@ def run_scheduler_process(
pipe_writer,
balance_meta: Optional[DPBalanceMeta] = None,
):
if server_args.enable_trace:
process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
if server_args.disaggregation_mode == "null":
thread_label = "Scheduler"
trace_set_thread_info(thread_label, tp_rank, dp_rank)
if (numa_node := server_args.numa_node) is not None:
numa_bind_to_node(numa_node[gpu_id])
# Generate the prefix
# Generate the logger prefix
prefix = ""
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
# [For Router] if the env var "SGLANG_DP_RANK" exists, set dp_rank to the value of the env var
dp_rank = int(os.environ["SGLANG_DP_RANK"])
if dp_rank is not None:
prefix += f" DP{dp_rank}"
if server_args.tp_size > 1:
@@ -2821,10 +2809,6 @@ def run_scheduler_process(
kill_itself_when_parent_died()
parent_process = psutil.Process().parent()
# [For Router] if the env var "SGLANG_DP_RANK" exists, set dp_rank to the value of the env var
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
dp_rank = int(os.environ["SGLANG_DP_RANK"])
# Configure the logger
configure_logger(server_args, prefix=prefix)
suppress_other_loggers()
@@ -2832,6 +2816,15 @@ def run_scheduler_process(
# Set cpu affinity to this gpu process
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
if (numa_node := server_args.numa_node) is not None:
numa_bind_to_node(numa_node[gpu_id])
# Set up tracing
if server_args.enable_trace:
process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
if server_args.disaggregation_mode == "null":
thread_label = "Scheduler"
trace_set_thread_info(thread_label, tp_rank, dp_rank)
# Create a scheduler and run the event loop
try: