metrics: support customer buckets for prompt/generation_tokens_histogram (#9634)
This commit is contained in:
@@ -121,21 +121,23 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
|||||||
|
|
||||||
## Logging
|
## Logging
|
||||||
|
|
||||||
| Arguments | Description | Defaults |
|
| Arguments | Description | Defaults |
|
||||||
|-----------|-------------|----------|
|
|---------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------|
|
||||||
| `--log-level` | The logging level of all loggers. | info |
|
| `--log-level` | The logging level of all loggers. | info |
|
||||||
| `--log-level-http` | The logging level of HTTP server. If not set, reuse --log-level by default. | None |
|
| `--log-level-http` | The logging level of HTTP server. If not set, reuse --log-level by default. | None |
|
||||||
| `--log-requests` | Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level. | False |
|
| `--log-requests` | Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level. | False |
|
||||||
| `--log-requests-level` | 0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output. | 0 |
|
| `--log-requests-level` | 0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output. | 0 |
|
||||||
| `--show-time-cost` | Show time cost of custom marks. | False |
|
| `--show-time-cost` | Show time cost of custom marks. | False |
|
||||||
| `--enable-metrics` | Enable log prometheus metrics. | False |
|
| `--enable-metrics` | Enable log prometheus metrics. | False |
|
||||||
| `--bucket-time-to-first-token` | The buckets of time to first token, specified as a list of floats. | None |
|
| `--bucket-time-to-first-token` | The buckets of time to first token, specified as a list of floats. | None |
|
||||||
| `--bucket-inter-token-latency` | The buckets of inter-token latency, specified as a list of floats. | None |
|
| `--bucket-inter-token-latency` | The buckets of inter-token latency, specified as a list of floats. | None |
|
||||||
| `--bucket-e2e-request-latency` | The buckets of end-to-end request latency, specified as a list of floats. | None |
|
| `--bucket-e2e-request-latency` | The buckets of end-to-end request latency, specified as a list of floats. | None |
|
||||||
| `--collect-tokens-histogram` | Collect prompt/generation tokens histogram. | False |
|
| `--collect-tokens-histogram` | Collect prompt/generation tokens histogram. | False |
|
||||||
| `--kv-events-config` | Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used. | None |
|
| `--kv-events-config` | Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used. | None |
|
||||||
| `--decode-log-interval` | The log interval of decode batch. | 40 |
|
| `--decode-log-interval` | The log interval of decode batch. | 40 |
|
||||||
| `--enable-request-time-stats-logging` | Enable per request time stats logging. | False |
|
| `--enable-request-time-stats-logging` | Enable per request time stats logging. | False |
|
||||||
|
| `--prompt-tokens-buckets` | The buckets rule of prompt tokens. Supports 3 rule types: 'default' uses predefined buckets; 'tse <middle> <base> <count>' generates two sides exponential distributed buckets (e.g., 'tse 1000 2 8' generates buckets [984.0, 992.0, 996.0, 998.0, 1000.0, 1002.0, 1004.0, 1008.0, 1016.0]).); 'customer <value1> <value2> ...' uses custom bucket values (e.g., 'customer 10 50 100 500'). | None |
|
||||||
|
| `--generation-tokens-buckets` | The buckets rule of prompt tokens. Supports 3 rule types: 'default' uses predefined buckets; 'tse <middle> <base> <count>' generates two sides exponential distributed buckets (e.g., 'tse 1000 2 8' generates buckets [984.0, 992.0, 996.0, 998.0, 1000.0, 1002.0, 1004.0, 1008.0, 1016.0]).); 'customer <value1> <value2> ...' uses custom bucket values (e.g., 'customer 10 50 100 500'). | None |
|
||||||
|
|
||||||
## API related
|
## API related
|
||||||
|
|
||||||
|
|||||||
@@ -329,6 +329,7 @@ class TokenizerManager:
|
|||||||
# Metrics
|
# Metrics
|
||||||
if self.enable_metrics:
|
if self.enable_metrics:
|
||||||
self.metrics_collector = TokenizerMetricsCollector(
|
self.metrics_collector = TokenizerMetricsCollector(
|
||||||
|
server_args=server_args,
|
||||||
labels={
|
labels={
|
||||||
"model_name": self.server_args.served_model_name,
|
"model_name": self.server_args.served_model_name,
|
||||||
# TODO: Add lora name/path in the future,
|
# TODO: Add lora name/path in the future,
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ from dataclasses import dataclass
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from sglang.srt.metrics.utils import generate_buckets
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import get_bool_env_var
|
from sglang.srt.utils import get_bool_env_var
|
||||||
|
|
||||||
SGLANG_TEST_REQUEST_TIME_STATS = get_bool_env_var("SGLANG_TEST_REQUEST_TIME_STATS")
|
SGLANG_TEST_REQUEST_TIME_STATS = get_bool_env_var("SGLANG_TEST_REQUEST_TIME_STATS")
|
||||||
@@ -309,6 +311,7 @@ class SchedulerMetricsCollector:
|
|||||||
class TokenizerMetricsCollector:
|
class TokenizerMetricsCollector:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
server_args: ServerArgs,
|
||||||
labels: Dict[str, str],
|
labels: Dict[str, str],
|
||||||
bucket_time_to_first_token: Optional[List[float]] = None,
|
bucket_time_to_first_token: Optional[List[float]] = None,
|
||||||
bucket_inter_token_latency: Optional[List[float]] = None,
|
bucket_inter_token_latency: Optional[List[float]] = None,
|
||||||
@@ -334,7 +337,7 @@ class TokenizerMetricsCollector:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if collect_tokens_histogram:
|
if collect_tokens_histogram:
|
||||||
bucket_prompt_tokens = [
|
default_bucket_prompt_tokens = [
|
||||||
100,
|
100,
|
||||||
300,
|
300,
|
||||||
500,
|
500,
|
||||||
@@ -363,9 +366,11 @@ class TokenizerMetricsCollector:
|
|||||||
name="sglang:prompt_tokens_histogram",
|
name="sglang:prompt_tokens_histogram",
|
||||||
documentation="Histogram of prompt token length.",
|
documentation="Histogram of prompt token length.",
|
||||||
labelnames=labels.keys(),
|
labelnames=labels.keys(),
|
||||||
buckets=bucket_prompt_tokens,
|
buckets=generate_buckets(
|
||||||
|
server_args.prompt_tokens_buckets, default_bucket_prompt_tokens
|
||||||
|
),
|
||||||
)
|
)
|
||||||
bucket_generation_tokens = [
|
default_bucket_generation_tokens = [
|
||||||
100,
|
100,
|
||||||
300,
|
300,
|
||||||
500,
|
500,
|
||||||
@@ -390,7 +395,10 @@ class TokenizerMetricsCollector:
|
|||||||
name="sglang:generation_tokens_histogram",
|
name="sglang:generation_tokens_histogram",
|
||||||
documentation="Histogram of generation token length.",
|
documentation="Histogram of generation token length.",
|
||||||
labelnames=labels.keys(),
|
labelnames=labels.keys(),
|
||||||
buckets=bucket_generation_tokens,
|
buckets=generate_buckets(
|
||||||
|
server_args.generation_tokens_buckets,
|
||||||
|
default_bucket_generation_tokens,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.cached_tokens_total = Counter(
|
self.cached_tokens_total = Counter(
|
||||||
|
|||||||
48
python/sglang/srt/metrics/utils.py
Normal file
48
python/sglang/srt/metrics/utils.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
# Copyright 2023-2025 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Utilities for Prometheus Metrics."""
|
||||||
|
import math
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
def two_sides_exponential_buckets(
|
||||||
|
middle: float, base: float, count: int
|
||||||
|
) -> List[float]:
|
||||||
|
buckets = []
|
||||||
|
half_count = math.ceil(count / 2)
|
||||||
|
distance = 1
|
||||||
|
buckets.append(middle)
|
||||||
|
for i in range(half_count):
|
||||||
|
distance *= base
|
||||||
|
buckets.append(middle + distance)
|
||||||
|
buckets.append(max(0, middle - distance))
|
||||||
|
return sorted(set(buckets))
|
||||||
|
|
||||||
|
|
||||||
|
def generate_buckets(
|
||||||
|
buckets_rule: List[str], default_buckets: List[float]
|
||||||
|
) -> List[float]:
|
||||||
|
if not buckets_rule:
|
||||||
|
buckets_rule = ["default"]
|
||||||
|
|
||||||
|
assert len(buckets_rule) > 0
|
||||||
|
rule = buckets_rule[0]
|
||||||
|
if rule == "tse":
|
||||||
|
middle, base, count = buckets_rule[1:]
|
||||||
|
assert float(base) > 1.0, "Base must be greater than 1.0"
|
||||||
|
return two_sides_exponential_buckets(float(middle), float(base), int(count))
|
||||||
|
if rule == "default":
|
||||||
|
return sorted(set(default_buckets))
|
||||||
|
assert rule == "customer"
|
||||||
|
return sorted(set([float(x) for x in buckets_rule[1:]]))
|
||||||
@@ -195,6 +195,8 @@ class ServerArgs:
|
|||||||
bucket_inter_token_latency: Optional[List[float]] = None
|
bucket_inter_token_latency: Optional[List[float]] = None
|
||||||
bucket_e2e_request_latency: Optional[List[float]] = None
|
bucket_e2e_request_latency: Optional[List[float]] = None
|
||||||
collect_tokens_histogram: bool = False
|
collect_tokens_histogram: bool = False
|
||||||
|
prompt_tokens_buckets: Optional[List[str]] = None
|
||||||
|
generation_tokens_buckets: Optional[List[str]] = None
|
||||||
decode_log_interval: int = 40
|
decode_log_interval: int = 40
|
||||||
enable_request_time_stats_logging: bool = False
|
enable_request_time_stats_logging: bool = False
|
||||||
kv_events_config: Optional[str] = None
|
kv_events_config: Optional[str] = None
|
||||||
@@ -1234,6 +1236,26 @@ class ServerArgs:
|
|||||||
default=ServerArgs.collect_tokens_histogram,
|
default=ServerArgs.collect_tokens_histogram,
|
||||||
help="Collect prompt/generation tokens histogram.",
|
help="Collect prompt/generation tokens histogram.",
|
||||||
)
|
)
|
||||||
|
bucket_rule = (
|
||||||
|
"Supports 3 rule types: 'default' uses predefined buckets; 'tse <middle> <base> <count>' "
|
||||||
|
"generates two sides exponential distributed buckets (e.g., 'tse 1000 2 8' generates buckets "
|
||||||
|
"[984.0, 992.0, 996.0, 998.0, 1000.0, 1002.0, 1004.0, 1008.0, 1016.0]).); 'customer <value1> "
|
||||||
|
"<value2> ...' uses custom bucket values (e.g., 'customer 10 50 100 500')."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt-tokens-buckets",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
default=ServerArgs.prompt_tokens_buckets,
|
||||||
|
help=f"The buckets rule of prompt tokens. {bucket_rule}",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--generation-tokens-buckets",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
default=ServerArgs.generation_tokens_buckets,
|
||||||
|
help=f"The buckets rule for generation tokens histogram. {bucket_rule}",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--gc-warning-threshold-secs",
|
"--gc-warning-threshold-secs",
|
||||||
type=float,
|
type=float,
|
||||||
@@ -2185,6 +2207,12 @@ class ServerArgs:
|
|||||||
|
|
||||||
# Check multi tokenizer
|
# Check multi tokenizer
|
||||||
assert self.tokenizer_worker_num > 0, "Tokenizer worker num must >= 1"
|
assert self.tokenizer_worker_num > 0, "Tokenizer worker num must >= 1"
|
||||||
|
self.validate_buckets_rule(
|
||||||
|
"--prompt-tokens-buckets", self.prompt_tokens_buckets
|
||||||
|
)
|
||||||
|
self.validate_buckets_rule(
|
||||||
|
"--generation-tokens-buckets", self.generation_tokens_buckets
|
||||||
|
)
|
||||||
|
|
||||||
def check_lora_server_args(self):
|
def check_lora_server_args(self):
|
||||||
assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
|
assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
|
||||||
@@ -2277,6 +2305,54 @@ class ServerArgs:
|
|||||||
f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
|
f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def validate_buckets_rule(self, arg_name: str, buckets_rule: List[str]):
|
||||||
|
if not buckets_rule:
|
||||||
|
return
|
||||||
|
|
||||||
|
assert len(buckets_rule) > 0, f"{arg_name} cannot be empty list"
|
||||||
|
rule = buckets_rule[0]
|
||||||
|
assert rule in [
|
||||||
|
"tse",
|
||||||
|
"default",
|
||||||
|
"customer",
|
||||||
|
], f"Unsupported {arg_name} rule type: '{rule}'. Must be one of: 'tse', 'default', 'customer'"
|
||||||
|
|
||||||
|
if rule == "tse":
|
||||||
|
assert (
|
||||||
|
len(buckets_rule) == 4
|
||||||
|
), f"{arg_name} TSE rule requires exactly 4 parameters: ['tse', middle, base, count], got {len(buckets_rule)}"
|
||||||
|
try:
|
||||||
|
middle = float(buckets_rule[1])
|
||||||
|
base = float(buckets_rule[2])
|
||||||
|
count = int(buckets_rule[3])
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
assert (
|
||||||
|
False
|
||||||
|
), f"{arg_name} TSE rule parameters must be: ['tse', <float:middle>, <float:base>, <int:count>]"
|
||||||
|
assert base > 1, f"{arg_name} TSE base must be larger than 1, got: {base}"
|
||||||
|
assert count > 0, f"{arg_name} TSE count must be positive, got: {count}"
|
||||||
|
assert middle > 0, f"{arg_name} TSE middle must be positive, got: {middle}"
|
||||||
|
|
||||||
|
elif rule == "default":
|
||||||
|
assert (
|
||||||
|
len(buckets_rule) == 1
|
||||||
|
), f"{arg_name} default rule should only have one parameter: ['default'], got {len(buckets_rule)}"
|
||||||
|
|
||||||
|
elif rule == "customer":
|
||||||
|
assert (
|
||||||
|
len(buckets_rule) >= 2
|
||||||
|
), f"{arg_name} customer rule requires at least one bucket value: ['customer', value1, ...]"
|
||||||
|
try:
|
||||||
|
bucket_values = [float(x) for x in buckets_rule[1:]]
|
||||||
|
except ValueError:
|
||||||
|
assert False, f"{arg_name} customer rule bucket values must be numeric"
|
||||||
|
assert len(set(bucket_values)) == len(
|
||||||
|
bucket_values
|
||||||
|
), f"{arg_name} customer rule bucket values should not contain duplicates"
|
||||||
|
assert all(
|
||||||
|
val >= 0 for val in bucket_values
|
||||||
|
), f"{arg_name} customer rule bucket values should be non-negative"
|
||||||
|
|
||||||
def model_specific_adjustments(self):
|
def model_specific_adjustments(self):
|
||||||
hf_config = self.get_hf_config()
|
hf_config = self.get_hf_config()
|
||||||
model_arch = hf_config.architectures[0]
|
model_arch = hf_config.architectures[0]
|
||||||
|
|||||||
@@ -80,6 +80,7 @@ suites = {
|
|||||||
TestFile("test_io_struct.py", 8),
|
TestFile("test_io_struct.py", 8),
|
||||||
TestFile("test_jinja_template_utils.py", 1),
|
TestFile("test_jinja_template_utils.py", 1),
|
||||||
TestFile("test_metrics.py", 32),
|
TestFile("test_metrics.py", 32),
|
||||||
|
TestFile("test_metrics_utils.py", 1),
|
||||||
TestFile("test_mla.py", 167),
|
TestFile("test_mla.py", 167),
|
||||||
TestFile("test_mla_deepseek_v3.py", 700),
|
TestFile("test_mla_deepseek_v3.py", 700),
|
||||||
TestFile("test_mla_int8_deepseek_v3.py", 429),
|
TestFile("test_mla_int8_deepseek_v3.py", 429),
|
||||||
@@ -214,6 +215,7 @@ suite_amd = {
|
|||||||
TestFile("test_io_struct.py", 8),
|
TestFile("test_io_struct.py", 8),
|
||||||
TestFile("test_jinja_template_utils.py", 1),
|
TestFile("test_jinja_template_utils.py", 1),
|
||||||
TestFile("test_metrics.py", 32),
|
TestFile("test_metrics.py", 32),
|
||||||
|
TestFile("test_metrics_utils.py", 1),
|
||||||
TestFile("test_mla.py", 242),
|
TestFile("test_mla.py", 242),
|
||||||
TestFile("test_mla_deepseek_v3.py", 221),
|
TestFile("test_mla_deepseek_v3.py", 221),
|
||||||
TestFile("test_no_chunked_prefill.py", 108),
|
TestFile("test_no_chunked_prefill.py", 108),
|
||||||
|
|||||||
137
test/srt/test_metrics_utils.py
Normal file
137
test/srt/test_metrics_utils.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from sglang.srt.metrics.utils import generate_buckets, two_sides_exponential_buckets
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricsUtils(unittest.TestCase):
|
||||||
|
"""Test cases for metrics utility functions."""
|
||||||
|
|
||||||
|
def test_two_sides_exponential_buckets_basic(self):
|
||||||
|
"""Test basic functionality of two_sides_exponential_buckets."""
|
||||||
|
# Test with simple parameters
|
||||||
|
count = 5
|
||||||
|
buckets = two_sides_exponential_buckets(middle=10.0, base=2.0, count=count)
|
||||||
|
|
||||||
|
# Should contain the middle value
|
||||||
|
self.assertIn(10.0, buckets)
|
||||||
|
|
||||||
|
# Should be sorted
|
||||||
|
self.assertEqual(buckets, sorted(buckets))
|
||||||
|
|
||||||
|
# Should have unique values (no duplicates)
|
||||||
|
self.assertEqual(len(buckets), len(set(buckets)))
|
||||||
|
|
||||||
|
# Should have reasonable number of buckets (not exactly count due to ceiling and deduplication)
|
||||||
|
self.assertGreaterEqual(len(buckets), 3)
|
||||||
|
self.assertLessEqual(len(buckets), count + 2)
|
||||||
|
|
||||||
|
def test_two_sides_exponential_buckets_specific_values(self):
|
||||||
|
"""Test specific values for two_sides_exponential_buckets."""
|
||||||
|
buckets = two_sides_exponential_buckets(middle=100.0, base=2.0, count=4)
|
||||||
|
expected_values = [96.0, 98.0, 100.0, 102.0, 104.0]
|
||||||
|
self.assertEqual(buckets, expected_values)
|
||||||
|
|
||||||
|
def test_two_sides_exponential_buckets_negative_values(self):
|
||||||
|
"""Test two_sides_exponential_buckets with values that could go negative."""
|
||||||
|
buckets = two_sides_exponential_buckets(middle=5.0, base=3.0, count=4)
|
||||||
|
|
||||||
|
# Should not contain negative values (max(0, middle - distance))
|
||||||
|
for bucket in buckets:
|
||||||
|
self.assertGreaterEqual(bucket, 0.0)
|
||||||
|
|
||||||
|
# Should contain the middle value
|
||||||
|
self.assertIn(5.0, buckets)
|
||||||
|
|
||||||
|
def test_two_sides_exponential_buckets_edge_cases(self):
|
||||||
|
"""Test edge cases for two_sides_exponential_buckets."""
|
||||||
|
# Count = 1
|
||||||
|
buckets = two_sides_exponential_buckets(middle=10.0, base=2.0, count=1)
|
||||||
|
self.assertIn(10.0, buckets)
|
||||||
|
|
||||||
|
# Very small middle value
|
||||||
|
buckets = two_sides_exponential_buckets(middle=0.1, base=2.0, count=2)
|
||||||
|
self.assertIn(0.1, buckets)
|
||||||
|
for bucket in buckets:
|
||||||
|
self.assertGreaterEqual(bucket, 0.0)
|
||||||
|
|
||||||
|
def test_generate_buckets_default(self):
|
||||||
|
"""Test generate_buckets with default rule."""
|
||||||
|
default_buckets = [1.0, 5.0, 10.0, 50.0, 100.0]
|
||||||
|
|
||||||
|
# Test with "default" rule
|
||||||
|
result = generate_buckets(["default"], default_buckets)
|
||||||
|
self.assertEqual(result, default_buckets)
|
||||||
|
|
||||||
|
# Test with None (should default to "default")
|
||||||
|
result = generate_buckets(None, default_buckets)
|
||||||
|
self.assertEqual(result, default_buckets)
|
||||||
|
|
||||||
|
# Test with empty (should default to "default")
|
||||||
|
result = generate_buckets(None, default_buckets)
|
||||||
|
self.assertEqual(result, default_buckets)
|
||||||
|
|
||||||
|
def test_generate_buckets_tse(self):
|
||||||
|
"""Test generate_buckets with tse (two sides exponential) rule."""
|
||||||
|
default_buckets = [1.0, 5.0, 10.0]
|
||||||
|
|
||||||
|
# Test with "tse" rule
|
||||||
|
result = generate_buckets(["tse", "10", "2.0", "4"], default_buckets)
|
||||||
|
|
||||||
|
# Should return the same as calling two_sides_exponential_buckets directly
|
||||||
|
expected = two_sides_exponential_buckets(10.0, 2.0, 4)
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
def test_generate_buckets_customer(self):
|
||||||
|
"""Test generate_buckets with customer rule."""
|
||||||
|
default_buckets = [1.0, 5.0, 10.0]
|
||||||
|
|
||||||
|
# Test with "customer" rule
|
||||||
|
result = generate_buckets(
|
||||||
|
["customer", "1.5", "3.2", "7.8", "15.6"], default_buckets
|
||||||
|
)
|
||||||
|
expected = [1.5, 3.2, 7.8, 15.6]
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
def test_generate_buckets_customer_with_integers(self):
|
||||||
|
"""Test generate_buckets with customer rule using integer strings."""
|
||||||
|
default_buckets = [1.0, 5.0, 10.0]
|
||||||
|
|
||||||
|
# Test with integer strings
|
||||||
|
result = generate_buckets(["customer", "1", "5", "10", "50"], default_buckets)
|
||||||
|
expected = [1.0, 5.0, 10.0, 50.0]
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
def test_generate_buckets_preserves_order_and_type(self):
|
||||||
|
"""Test that generate_buckets preserves order and returns floats."""
|
||||||
|
default_buckets = [1, 5, 10, 50, 100] # integers
|
||||||
|
|
||||||
|
# Test default rule
|
||||||
|
result = generate_buckets(["default"], default_buckets)
|
||||||
|
self.assertEqual(result, default_buckets)
|
||||||
|
self.assertIsInstance(result, list)
|
||||||
|
|
||||||
|
# Test customer rule with proper float conversion
|
||||||
|
result = generate_buckets(
|
||||||
|
["customer", "100", "50", "10", "5", "1"], default_buckets
|
||||||
|
)
|
||||||
|
expected = [1.0, 5.0, 10.0, 50.0, 100.0]
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
# All values should be floats
|
||||||
|
for value in result:
|
||||||
|
self.assertIsInstance(value, float)
|
||||||
|
|
||||||
|
def test_integration_tse_through_generate_buckets(self):
|
||||||
|
"""Test integration of TSE buckets through generate_buckets function."""
|
||||||
|
default_buckets = [1.0, 10.0, 100.0]
|
||||||
|
|
||||||
|
# Generate buckets using both methods
|
||||||
|
direct_result = two_sides_exponential_buckets(50.0, 1.5, 6)
|
||||||
|
indirect_result = generate_buckets(["tse", "50.0", "1.5", "6"], default_buckets)
|
||||||
|
|
||||||
|
# Results should be identical
|
||||||
|
self.assertEqual(direct_result, indirect_result)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user