metrics: support customer buckets for prompt/generation_tokens_histogram (#9634)

This commit is contained in:
Yingchun Lai
2025-09-04 22:22:08 +08:00
committed by GitHub
parent 75ee00112d
commit b32ab0705e
7 changed files with 293 additions and 19 deletions

View File

@@ -329,6 +329,7 @@ class TokenizerManager:
# Metrics
if self.enable_metrics:
self.metrics_collector = TokenizerMetricsCollector(
server_args=server_args,
labels={
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,

View File

@@ -18,6 +18,8 @@ from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Union
from sglang.srt.metrics.utils import generate_buckets
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var
SGLANG_TEST_REQUEST_TIME_STATS = get_bool_env_var("SGLANG_TEST_REQUEST_TIME_STATS")
@@ -309,6 +311,7 @@ class SchedulerMetricsCollector:
class TokenizerMetricsCollector:
def __init__(
self,
server_args: ServerArgs,
labels: Dict[str, str],
bucket_time_to_first_token: Optional[List[float]] = None,
bucket_inter_token_latency: Optional[List[float]] = None,
@@ -334,7 +337,7 @@ class TokenizerMetricsCollector:
)
if collect_tokens_histogram:
bucket_prompt_tokens = [
default_bucket_prompt_tokens = [
100,
300,
500,
@@ -363,9 +366,11 @@ class TokenizerMetricsCollector:
name="sglang:prompt_tokens_histogram",
documentation="Histogram of prompt token length.",
labelnames=labels.keys(),
buckets=bucket_prompt_tokens,
buckets=generate_buckets(
server_args.prompt_tokens_buckets, default_bucket_prompt_tokens
),
)
bucket_generation_tokens = [
default_bucket_generation_tokens = [
100,
300,
500,
@@ -390,7 +395,10 @@ class TokenizerMetricsCollector:
name="sglang:generation_tokens_histogram",
documentation="Histogram of generation token length.",
labelnames=labels.keys(),
buckets=bucket_generation_tokens,
buckets=generate_buckets(
server_args.generation_tokens_buckets,
default_bucket_generation_tokens,
),
)
self.cached_tokens_total = Counter(

View File

@@ -0,0 +1,48 @@
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for Prometheus Metrics."""
import math
from typing import List
def two_sides_exponential_buckets(
middle: float, base: float, count: int
) -> List[float]:
buckets = []
half_count = math.ceil(count / 2)
distance = 1
buckets.append(middle)
for i in range(half_count):
distance *= base
buckets.append(middle + distance)
buckets.append(max(0, middle - distance))
return sorted(set(buckets))
def generate_buckets(
buckets_rule: List[str], default_buckets: List[float]
) -> List[float]:
if not buckets_rule:
buckets_rule = ["default"]
assert len(buckets_rule) > 0
rule = buckets_rule[0]
if rule == "tse":
middle, base, count = buckets_rule[1:]
assert float(base) > 1.0, "Base must be greater than 1.0"
return two_sides_exponential_buckets(float(middle), float(base), int(count))
if rule == "default":
return sorted(set(default_buckets))
assert rule == "customer"
return sorted(set([float(x) for x in buckets_rule[1:]]))

View File

@@ -195,6 +195,8 @@ class ServerArgs:
bucket_inter_token_latency: Optional[List[float]] = None
bucket_e2e_request_latency: Optional[List[float]] = None
collect_tokens_histogram: bool = False
prompt_tokens_buckets: Optional[List[str]] = None
generation_tokens_buckets: Optional[List[str]] = None
decode_log_interval: int = 40
enable_request_time_stats_logging: bool = False
kv_events_config: Optional[str] = None
@@ -1234,6 +1236,26 @@ class ServerArgs:
default=ServerArgs.collect_tokens_histogram,
help="Collect prompt/generation tokens histogram.",
)
bucket_rule = (
"Supports 3 rule types: 'default' uses predefined buckets; 'tse <middle> <base> <count>' "
"generates two sides exponential distributed buckets (e.g., 'tse 1000 2 8' generates buckets "
"[984.0, 992.0, 996.0, 998.0, 1000.0, 1002.0, 1004.0, 1008.0, 1016.0]).); 'customer <value1> "
"<value2> ...' uses custom bucket values (e.g., 'customer 10 50 100 500')."
)
parser.add_argument(
"--prompt-tokens-buckets",
type=str,
nargs="+",
default=ServerArgs.prompt_tokens_buckets,
help=f"The buckets rule of prompt tokens. {bucket_rule}",
)
parser.add_argument(
"--generation-tokens-buckets",
type=str,
nargs="+",
default=ServerArgs.generation_tokens_buckets,
help=f"The buckets rule for generation tokens histogram. {bucket_rule}",
)
parser.add_argument(
"--gc-warning-threshold-secs",
type=float,
@@ -2185,6 +2207,12 @@ class ServerArgs:
# Check multi tokenizer
assert self.tokenizer_worker_num > 0, "Tokenizer worker num must >= 1"
self.validate_buckets_rule(
"--prompt-tokens-buckets", self.prompt_tokens_buckets
)
self.validate_buckets_rule(
"--generation-tokens-buckets", self.generation_tokens_buckets
)
def check_lora_server_args(self):
assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
@@ -2277,6 +2305,54 @@ class ServerArgs:
f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
)
def validate_buckets_rule(self, arg_name: str, buckets_rule: List[str]):
if not buckets_rule:
return
assert len(buckets_rule) > 0, f"{arg_name} cannot be empty list"
rule = buckets_rule[0]
assert rule in [
"tse",
"default",
"customer",
], f"Unsupported {arg_name} rule type: '{rule}'. Must be one of: 'tse', 'default', 'customer'"
if rule == "tse":
assert (
len(buckets_rule) == 4
), f"{arg_name} TSE rule requires exactly 4 parameters: ['tse', middle, base, count], got {len(buckets_rule)}"
try:
middle = float(buckets_rule[1])
base = float(buckets_rule[2])
count = int(buckets_rule[3])
except (ValueError, IndexError):
assert (
False
), f"{arg_name} TSE rule parameters must be: ['tse', <float:middle>, <float:base>, <int:count>]"
assert base > 1, f"{arg_name} TSE base must be larger than 1, got: {base}"
assert count > 0, f"{arg_name} TSE count must be positive, got: {count}"
assert middle > 0, f"{arg_name} TSE middle must be positive, got: {middle}"
elif rule == "default":
assert (
len(buckets_rule) == 1
), f"{arg_name} default rule should only have one parameter: ['default'], got {len(buckets_rule)}"
elif rule == "customer":
assert (
len(buckets_rule) >= 2
), f"{arg_name} customer rule requires at least one bucket value: ['customer', value1, ...]"
try:
bucket_values = [float(x) for x in buckets_rule[1:]]
except ValueError:
assert False, f"{arg_name} customer rule bucket values must be numeric"
assert len(set(bucket_values)) == len(
bucket_values
), f"{arg_name} customer rule bucket values should not contain duplicates"
assert all(
val >= 0 for val in bucket_values
), f"{arg_name} customer rule bucket values should be non-negative"
def model_specific_adjustments(self):
hf_config = self.get_hf_config()
model_arch = hf_config.architectures[0]