metrics: support customer labels specified in request header (#10143)

This commit is contained in:
Yingchun Lai
2025-09-15 11:00:08 +08:00
committed by GitHub
parent 8f6a175803
commit fc2c3a3d8e
11 changed files with 126 additions and 28 deletions

View File

@@ -229,6 +229,9 @@ class CompletionRequest(BaseModel):
# For request id # For request id
rid: Optional[Union[List[str], str]] = None rid: Optional[Union[List[str], str]] = None
# For customer metric labels
customer_labels: Optional[Dict[str, str]] = None
@field_validator("max_tokens") @field_validator("max_tokens")
@classmethod @classmethod
def validate_max_tokens_positive(cls, v): def validate_max_tokens_positive(cls, v):

View File

@@ -11,6 +11,7 @@ from fastapi.responses import ORJSONResponse, StreamingResponse
from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -24,6 +25,14 @@ class OpenAIServingBase(ABC):
def __init__(self, tokenizer_manager: TokenizerManager): def __init__(self, tokenizer_manager: TokenizerManager):
self.tokenizer_manager = tokenizer_manager self.tokenizer_manager = tokenizer_manager
self.allowed_custom_labels = (
set(
self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
)
if isinstance(self.tokenizer_manager.server_args, ServerArgs)
and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
else None
)
async def handle_request( async def handle_request(
self, request: OpenAIServingRequest, raw_request: Request self, request: OpenAIServingRequest, raw_request: Request
@@ -37,7 +46,7 @@ class OpenAIServingBase(ABC):
# Convert to internal format # Convert to internal format
adapted_request, processed_request = self._convert_to_internal_request( adapted_request, processed_request = self._convert_to_internal_request(
request request, raw_request
) )
# Note(Xinyuan): raw_request below is only used for detecting the connection of the client # Note(Xinyuan): raw_request below is only used for detecting the connection of the client
@@ -81,6 +90,7 @@ class OpenAIServingBase(ABC):
def _convert_to_internal_request( def _convert_to_internal_request(
self, self,
request: OpenAIServingRequest, request: OpenAIServingRequest,
raw_request: Request = None,
) -> tuple[GenerateReqInput, OpenAIServingRequest]: ) -> tuple[GenerateReqInput, OpenAIServingRequest]:
"""Convert OpenAI request to internal format""" """Convert OpenAI request to internal format"""
pass pass
@@ -154,3 +164,32 @@ class OpenAIServingBase(ABC):
code=status_code, code=status_code,
) )
return json.dumps({"error": error.model_dump()}) return json.dumps({"error": error.model_dump()})
def extract_customer_labels(self, raw_request):
if (
not self.allowed_custom_labels
or not self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
):
return None
customer_labels = None
header = (
self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
)
try:
raw_labels = (
json.loads(raw_request.headers.get(header))
if raw_request and raw_request.headers.get(header)
else None
)
except json.JSONDecodeError as e:
logger.exception(f"Error in request: {e}")
raw_labels = None
if isinstance(raw_labels, dict):
customer_labels = {
label: value
for label, value in raw_labels.items()
if label in self.allowed_custom_labels
}
return customer_labels

View File

@@ -96,6 +96,7 @@ class OpenAIServingChat(OpenAIServingBase):
def _convert_to_internal_request( def _convert_to_internal_request(
self, self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
raw_request: Request = None,
) -> tuple[GenerateReqInput, ChatCompletionRequest]: ) -> tuple[GenerateReqInput, ChatCompletionRequest]:
reasoning_effort = ( reasoning_effort = (
request.chat_template_kwargs.pop("reasoning_effort", None) request.chat_template_kwargs.pop("reasoning_effort", None)
@@ -127,6 +128,9 @@ class OpenAIServingChat(OpenAIServingBase):
else: else:
prompt_kwargs = {"input_ids": processed_messages.prompt_ids} prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
# Extract customer labels from raw request headers
customer_labels = self.extract_customer_labels(raw_request)
adapted_request = GenerateReqInput( adapted_request = GenerateReqInput(
**prompt_kwargs, **prompt_kwargs,
image_data=processed_messages.image_data, image_data=processed_messages.image_data,
@@ -145,6 +149,7 @@ class OpenAIServingChat(OpenAIServingBase):
bootstrap_room=request.bootstrap_room, bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states, return_hidden_states=request.return_hidden_states,
rid=request.rid, rid=request.rid,
customer_labels=customer_labels,
) )
return adapted_request, request return adapted_request, request

View File

@@ -59,6 +59,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
def _convert_to_internal_request( def _convert_to_internal_request(
self, self,
request: CompletionRequest, request: CompletionRequest,
raw_request: Request = None,
) -> tuple[GenerateReqInput, CompletionRequest]: ) -> tuple[GenerateReqInput, CompletionRequest]:
"""Convert OpenAI completion request to internal format""" """Convert OpenAI completion request to internal format"""
# NOTE: with openai API, the prompt's logprobs are always not computed # NOTE: with openai API, the prompt's logprobs are always not computed
@@ -89,6 +90,9 @@ class OpenAIServingCompletion(OpenAIServingBase):
else: else:
prompt_kwargs = {"input_ids": prompt} prompt_kwargs = {"input_ids": prompt}
# Extract customer labels from raw request headers
customer_labels = self.extract_customer_labels(raw_request)
adapted_request = GenerateReqInput( adapted_request = GenerateReqInput(
**prompt_kwargs, **prompt_kwargs,
sampling_params=sampling_params, sampling_params=sampling_params,
@@ -103,6 +107,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
bootstrap_room=request.bootstrap_room, bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states, return_hidden_states=request.return_hidden_states,
rid=request.rid, rid=request.rid,
customer_labels=customer_labels,
) )
return adapted_request, request return adapted_request, request

View File

@@ -74,6 +74,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
def _convert_to_internal_request( def _convert_to_internal_request(
self, self,
request: EmbeddingRequest, request: EmbeddingRequest,
raw_request: Request = None,
) -> tuple[EmbeddingReqInput, EmbeddingRequest]: ) -> tuple[EmbeddingReqInput, EmbeddingRequest]:
"""Convert OpenAI embedding request to internal format""" """Convert OpenAI embedding request to internal format"""
prompt = request.input prompt = request.input

View File

@@ -45,7 +45,9 @@ class OpenAIServingRerank(OpenAIServingBase):
return None return None
def _convert_to_internal_request( def _convert_to_internal_request(
self, request: V1RerankReqInput self,
request: V1RerankReqInput,
raw_request: Request = None,
) -> tuple[EmbeddingReqInput, V1RerankReqInput]: ) -> tuple[EmbeddingReqInput, V1RerankReqInput]:
"""Convert OpenAI rerank request to internal embedding format""" """Convert OpenAI rerank request to internal embedding format"""
# Create pairs of [query, document] for each document # Create pairs of [query, document] for each document

View File

@@ -25,6 +25,7 @@ class OpenAIServingScore(OpenAIServingBase):
def _convert_to_internal_request( def _convert_to_internal_request(
self, self,
request: ScoringRequest, request: ScoringRequest,
raw_request: Request = None,
) -> tuple[ScoringRequest, ScoringRequest]: ) -> tuple[ScoringRequest, ScoringRequest]:
"""Convert OpenAI scoring request to internal format""" """Convert OpenAI scoring request to internal format"""
# For scoring, we pass the request directly as the tokenizer_manager # For scoring, we pass the request directly as the tokenizer_manager

View File

@@ -141,6 +141,9 @@ class GenerateReqInput:
# Image gen grpc migration # Image gen grpc migration
return_bytes: bool = False return_bytes: bool = False
# For customer metric labels
customer_labels: Optional[Dict[str, str]] = None
def contains_mm_input(self) -> bool: def contains_mm_input(self) -> bool:
return ( return (
has_valid_data(self.image_data) has_valid_data(self.image_data)

View File

@@ -306,12 +306,16 @@ class TokenizerManager(TokenizerCommunicatorMixin):
# Metrics # Metrics
if self.enable_metrics: if self.enable_metrics:
labels = {
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
}
if server_args.tokenizer_metrics_allowed_customer_labels:
for label in server_args.tokenizer_metrics_allowed_customer_labels:
labels[label] = ""
self.metrics_collector = TokenizerMetricsCollector( self.metrics_collector = TokenizerMetricsCollector(
server_args=server_args, server_args=server_args,
labels={ labels=labels,
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
},
bucket_time_to_first_token=self.server_args.bucket_time_to_first_token, bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency, bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
bucket_inter_token_latency=self.server_args.bucket_inter_token_latency, bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
@@ -1036,7 +1040,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
return return
req = AbortReq(rid, abort_all) req = AbortReq(rid, abort_all)
self.send_to_scheduler.send_pyobj(req) self.send_to_scheduler.send_pyobj(req)
if self.enable_metrics: if self.enable_metrics:
self.metrics_collector.observe_one_aborted_request() self.metrics_collector.observe_one_aborted_request()
@@ -1616,6 +1619,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
else 0 else 0
) )
customer_labels = getattr(state.obj, "customer_labels", None)
labels = (
{**self.metrics_collector.labels, **customer_labels}
if customer_labels
else self.metrics_collector.labels
)
if ( if (
state.first_token_time == 0.0 state.first_token_time == 0.0
and self.disaggregation_mode != DisaggregationMode.PREFILL and self.disaggregation_mode != DisaggregationMode.PREFILL
@@ -1623,7 +1632,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
state.first_token_time = state.last_time = time.time() state.first_token_time = state.last_time = time.time()
state.last_completion_tokens = completion_tokens state.last_completion_tokens = completion_tokens
self.metrics_collector.observe_time_to_first_token( self.metrics_collector.observe_time_to_first_token(
state.first_token_time - state.created_time labels, state.first_token_time - state.created_time
) )
else: else:
num_new_tokens = completion_tokens - state.last_completion_tokens num_new_tokens = completion_tokens - state.last_completion_tokens
@@ -1631,6 +1640,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
new_time = time.time() new_time = time.time()
interval = new_time - state.last_time interval = new_time - state.last_time
self.metrics_collector.observe_inter_token_latency( self.metrics_collector.observe_inter_token_latency(
labels,
interval, interval,
num_new_tokens, num_new_tokens,
) )
@@ -1645,6 +1655,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
or state.obj.sampling_params.get("structural_tag", None) or state.obj.sampling_params.get("structural_tag", None)
) )
self.metrics_collector.observe_one_finished_request( self.metrics_collector.observe_one_finished_request(
labels,
recv_obj.prompt_tokens[i], recv_obj.prompt_tokens[i],
completion_tokens, completion_tokens,
recv_obj.cached_tokens[i], recv_obj.cached_tokens[i],

View File

@@ -12,7 +12,6 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Utilities for Prometheus Metrics Collection.""" """Utilities for Prometheus Metrics Collection."""
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
@@ -812,36 +811,38 @@ class TokenizerMetricsCollector:
buckets=bucket_time_to_first_token, buckets=bucket_time_to_first_token,
) )
def _log_histogram(self, histogram, data: Union[int, float]) -> None:
histogram.labels(**self.labels).observe(data)
def observe_one_finished_request( def observe_one_finished_request(
self, self,
labels: Dict[str, str],
prompt_tokens: int, prompt_tokens: int,
generation_tokens: int, generation_tokens: int,
cached_tokens: int, cached_tokens: int,
e2e_latency: float, e2e_latency: float,
has_grammar: bool, has_grammar: bool,
): ):
self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens) self.prompt_tokens_total.labels(**labels).inc(prompt_tokens)
self.generation_tokens_total.labels(**self.labels).inc(generation_tokens) self.generation_tokens_total.labels(**labels).inc(generation_tokens)
if cached_tokens > 0: if cached_tokens > 0:
self.cached_tokens_total.labels(**self.labels).inc(cached_tokens) self.cached_tokens_total.labels(**labels).inc(cached_tokens)
self.num_requests_total.labels(**self.labels).inc(1) self.num_requests_total.labels(**labels).inc(1)
if has_grammar: if has_grammar:
self.num_so_requests_total.labels(**self.labels).inc(1) self.num_so_requests_total.labels(**labels).inc(1)
self._log_histogram(self.histogram_e2e_request_latency, e2e_latency) self.histogram_e2e_request_latency.labels(**labels).observe(float(e2e_latency))
if self.collect_tokens_histogram: if self.collect_tokens_histogram:
self._log_histogram(self.prompt_tokens_histogram, prompt_tokens) self.prompt_tokens_histogram.labels(**labels).observe(float(prompt_tokens))
self._log_histogram(self.generation_tokens_histogram, generation_tokens) self.generation_tokens_histogram.labels(**labels).observe(
float(generation_tokens)
)
def observe_time_to_first_token(self, value: float, label: str = ""): def observe_time_to_first_token(
if label == "batch": self, labels: Dict[str, str], value: float, type: str = ""
self.histogram_time_to_first_token_offline_batch.labels( ):
**self.labels if type == "batch":
).observe(value) self.histogram_time_to_first_token_offline_batch.labels(**labels).observe(
value
)
else: else:
self.histogram_time_to_first_token.labels(**self.labels).observe(value) self.histogram_time_to_first_token.labels(**labels).observe(value)
def check_time_to_first_token_straggler(self, value: float) -> bool: def check_time_to_first_token_straggler(self, value: float) -> bool:
his = self.histogram_time_to_first_token.labels(**self.labels) his = self.histogram_time_to_first_token.labels(**self.labels)
@@ -856,12 +857,14 @@ class TokenizerMetricsCollector:
return value >= his._upper_bounds[i] return value >= his._upper_bounds[i]
return False return False
def observe_inter_token_latency(self, internval: float, num_new_tokens: int): def observe_inter_token_latency(
self, labels: Dict[str, str], internval: float, num_new_tokens: int
):
adjusted_interval = internval / num_new_tokens adjusted_interval = internval / num_new_tokens
# A faster version of the Histogram::observe which observes multiple values at the same time. # A faster version of the Histogram::observe which observes multiple values at the same time.
# reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639 # reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639
his = self.histogram_inter_token_latency_seconds.labels(**self.labels) his = self.histogram_inter_token_latency_seconds.labels(**labels)
his._sum.inc(internval) his._sum.inc(internval)
for i, bound in enumerate(his._upper_bounds): for i, bound in enumerate(his._upper_bounds):

View File

@@ -205,6 +205,8 @@ class ServerArgs:
show_time_cost: bool = False show_time_cost: bool = False
enable_metrics: bool = False enable_metrics: bool = False
enable_metrics_for_all_schedulers: bool = False enable_metrics_for_all_schedulers: bool = False
tokenizer_metrics_custom_labels_header: str = "x-customer-labels"
tokenizer_metrics_allowed_customer_labels: Optional[List[str]] = None
bucket_time_to_first_token: Optional[List[float]] = None bucket_time_to_first_token: Optional[List[float]] = None
bucket_inter_token_latency: Optional[List[float]] = None bucket_inter_token_latency: Optional[List[float]] = None
bucket_e2e_request_latency: Optional[List[float]] = None bucket_e2e_request_latency: Optional[List[float]] = None
@@ -911,6 +913,14 @@ class ServerArgs:
"and cannot be used at the same time. Please use only one of them." "and cannot be used at the same time. Please use only one of them."
) )
if (
not self.tokenizer_metrics_custom_labels_header
and self.tokenizer_metrics_allowed_customer_labels
):
raise ValueError(
"Please set --tokenizer-metrics-custom-labels-header when setting --tokenizer-metrics-allowed-customer-labels."
)
@staticmethod @staticmethod
def add_cli_args(parser: argparse.ArgumentParser): def add_cli_args(parser: argparse.ArgumentParser):
# Model and tokenizer # Model and tokenizer
@@ -1324,6 +1334,21 @@ class ServerArgs:
"to record request metrics separately. This is especially useful when dp_attention is enabled, as " "to record request metrics separately. This is especially useful when dp_attention is enabled, as "
"otherwise all metrics appear to come from TP 0.", "otherwise all metrics appear to come from TP 0.",
) )
parser.add_argument(
"--tokenizer-metrics-custom-labels-header",
type=str,
default=ServerArgs.tokenizer_metrics_custom_labels_header,
help="Specify the HTTP header for passing customer labels for tokenizer metrics.",
)
parser.add_argument(
"--tokenizer-metrics-allowed-customer-labels",
type=str,
nargs="+",
default=ServerArgs.tokenizer_metrics_allowed_customer_labels,
help="The customer labels allowed for tokenizer metrics. The labels are specified via a dict in "
"'--tokenizer-metrics-custom-labels-header' field in HTTP requests, e.g., {'label1': 'value1', 'label2': "
"'value2'} is allowed if '--tokenizer-metrics-allowed-labels label1 label2' is set.",
)
parser.add_argument( parser.add_argument(
"--bucket-time-to-first-token", "--bucket-time-to-first-token",
type=float, type=float,