diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py
index f73e67d0b..7fed16703 100644
--- a/python/sglang/srt/entrypoints/openai/protocol.py
+++ b/python/sglang/srt/entrypoints/openai/protocol.py
@@ -229,6 +229,9 @@ class CompletionRequest(BaseModel):
     # For request id
     rid: Optional[Union[List[str], str]] = None
 
+    # For customer metric labels
+    customer_labels: Optional[Dict[str, str]] = None
+
     @field_validator("max_tokens")
     @classmethod
     def validate_max_tokens_positive(cls, v):
diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py
index 28b317e6d..5bc505108 100644
--- a/python/sglang/srt/entrypoints/openai/serving_base.py
+++ b/python/sglang/srt/entrypoints/openai/serving_base.py
@@ -11,6 +11,7 @@ from fastapi.responses import ORJSONResponse, StreamingResponse
 
 from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
 from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.server_args import ServerArgs
 
 if TYPE_CHECKING:
     from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -24,6 +25,14 @@ class OpenAIServingBase(ABC):
 
     def __init__(self, tokenizer_manager: TokenizerManager):
         self.tokenizer_manager = tokenizer_manager
+        self.allowed_custom_labels = (
+            set(
+                self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
+            )
+            if isinstance(self.tokenizer_manager.server_args, ServerArgs)
+            and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
+            else None
+        )
 
     async def handle_request(
         self, request: OpenAIServingRequest, raw_request: Request
@@ -37,7 +46,7 @@ class OpenAIServingBase(ABC):
 
             # Convert to internal format
             adapted_request, processed_request = self._convert_to_internal_request(
-                request
+                request, raw_request
             )
 
             # Note(Xinyuan): raw_request below is only used for detecting the connection of the client
@@ -81,6 +90,7 @@ class OpenAIServingBase(ABC):
     def _convert_to_internal_request(
         self,
         request: OpenAIServingRequest,
+        raw_request: Request = None,
     ) -> tuple[GenerateReqInput, OpenAIServingRequest]:
         """Convert OpenAI request to internal format"""
         pass
@@ -154,3 +164,32 @@ class OpenAIServingBase(ABC):
             code=status_code,
         )
         return json.dumps({"error": error.model_dump()})
+
+    def extract_customer_labels(self, raw_request):
+        if (
+            not self.allowed_custom_labels
+            or not self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
+        ):
+            return None
+
+        customer_labels = None
+        header = (
+            self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
+        )
+        try:
+            raw_labels = (
+                json.loads(raw_request.headers.get(header))
+                if raw_request and raw_request.headers.get(header)
+                else None
+            )
+        except json.JSONDecodeError as e:
+            logger.exception(f"Error in request: {e}")
+            raw_labels = None
+
+        if isinstance(raw_labels, dict):
+            customer_labels = {
+                label: value
+                for label, value in raw_labels.items()
+                if label in self.allowed_custom_labels
+            }
+        return customer_labels
diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py
index d67cbfde3..d132c7bed 100644
--- a/python/sglang/srt/entrypoints/openai/serving_chat.py
+++ b/python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -96,6 +96,7 @@ class OpenAIServingChat(OpenAIServingBase):
     def _convert_to_internal_request(
         self,
         request: ChatCompletionRequest,
+        raw_request: Request = None,
     ) -> tuple[GenerateReqInput, ChatCompletionRequest]:
         reasoning_effort = (
             request.chat_template_kwargs.pop("reasoning_effort", None)
@@ -127,6 +128,9 @@ class OpenAIServingChat(OpenAIServingBase):
         else:
             prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
 
+        # Extract customer labels from raw request headers
+        customer_labels = self.extract_customer_labels(raw_request)
+
         adapted_request = GenerateReqInput(
             **prompt_kwargs,
             image_data=processed_messages.image_data,
@@ -145,6 +149,7 @@ class OpenAIServingChat(OpenAIServingBase):
             bootstrap_room=request.bootstrap_room,
             return_hidden_states=request.return_hidden_states,
             rid=request.rid,
+            customer_labels=customer_labels,
         )
 
         return adapted_request, request
diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py
index 6fe02d325..68b4f97b4 100644
--- a/python/sglang/srt/entrypoints/openai/serving_completions.py
+++ b/python/sglang/srt/entrypoints/openai/serving_completions.py
@@ -59,6 +59,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
     def _convert_to_internal_request(
         self,
         request: CompletionRequest,
+        raw_request: Request = None,
     ) -> tuple[GenerateReqInput, CompletionRequest]:
         """Convert OpenAI completion request to internal format"""
         # NOTE: with openai API, the prompt's logprobs are always not computed
@@ -89,6 +90,9 @@ class OpenAIServingCompletion(OpenAIServingBase):
         else:
             prompt_kwargs = {"input_ids": prompt}
 
+        # Extract customer labels from raw request headers
+        customer_labels = self.extract_customer_labels(raw_request)
+
         adapted_request = GenerateReqInput(
             **prompt_kwargs,
             sampling_params=sampling_params,
@@ -103,6 +107,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
             bootstrap_room=request.bootstrap_room,
             return_hidden_states=request.return_hidden_states,
             rid=request.rid,
+            customer_labels=customer_labels,
         )
 
         return adapted_request, request
diff --git a/python/sglang/srt/entrypoints/openai/serving_embedding.py b/python/sglang/srt/entrypoints/openai/serving_embedding.py
index 63c4fc34a..6500915c1 100644
--- a/python/sglang/srt/entrypoints/openai/serving_embedding.py
+++ b/python/sglang/srt/entrypoints/openai/serving_embedding.py
@@ -74,6 +74,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
     def _convert_to_internal_request(
         self,
         request: EmbeddingRequest,
+        raw_request: Request = None,
     ) -> tuple[EmbeddingReqInput, EmbeddingRequest]:
         """Convert OpenAI embedding request to internal format"""
         prompt = request.input
diff --git a/python/sglang/srt/entrypoints/openai/serving_rerank.py b/python/sglang/srt/entrypoints/openai/serving_rerank.py
index b053c55b3..128215896 100644
--- a/python/sglang/srt/entrypoints/openai/serving_rerank.py
+++ b/python/sglang/srt/entrypoints/openai/serving_rerank.py
@@ -45,7 +45,9 @@ class OpenAIServingRerank(OpenAIServingBase):
         return None
 
     def _convert_to_internal_request(
-        self, request: V1RerankReqInput
+        self,
+        request: V1RerankReqInput,
+        raw_request: Request = None,
     ) -> tuple[EmbeddingReqInput, V1RerankReqInput]:
         """Convert OpenAI rerank request to internal embedding format"""
         # Create pairs of [query, document] for each document
diff --git a/python/sglang/srt/entrypoints/openai/serving_score.py b/python/sglang/srt/entrypoints/openai/serving_score.py
index fc8ce5dca..19f788ad8 100644
--- a/python/sglang/srt/entrypoints/openai/serving_score.py
+++ b/python/sglang/srt/entrypoints/openai/serving_score.py
@@ -25,6 +25,7 @@ class OpenAIServingScore(OpenAIServingBase):
     def _convert_to_internal_request(
         self,
         request: ScoringRequest,
+        raw_request: Request = None,
     ) -> tuple[ScoringRequest, ScoringRequest]:
         """Convert OpenAI scoring request to internal format"""
         # For scoring, we pass the request directly as the tokenizer_manager
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index cf5406660..16b87e164 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -141,6 +141,9 @@ class GenerateReqInput:
     # Image gen grpc migration
     return_bytes: bool = False
 
+    # For customer metric labels
+    customer_labels: Optional[Dict[str, str]] = None
+
     def contains_mm_input(self) -> bool:
         return (
             has_valid_data(self.image_data)
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 40f21b17d..5674bb475 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -306,12 +306,16 @@ class TokenizerManager(TokenizerCommunicatorMixin):
 
         # Metrics
         if self.enable_metrics:
+            labels = {
+                "model_name": self.server_args.served_model_name,
+                # TODO: Add lora name/path in the future,
+            }
+            if server_args.tokenizer_metrics_allowed_customer_labels:
+                for label in server_args.tokenizer_metrics_allowed_customer_labels:
+                    labels[label] = ""
             self.metrics_collector = TokenizerMetricsCollector(
                 server_args=server_args,
-                labels={
-                    "model_name": self.server_args.served_model_name,
-                    # TODO: Add lora name/path in the future,
-                },
+                labels=labels,
                 bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
                 bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
                 bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
@@ -1036,7 +1040,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             return
         req = AbortReq(rid, abort_all)
         self.send_to_scheduler.send_pyobj(req)
-
         if self.enable_metrics:
             self.metrics_collector.observe_one_aborted_request()
 
@@ -1616,6 +1619,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             else 0
         )
 
+        customer_labels = getattr(state.obj, "customer_labels", None)
+        labels = (
+            {**self.metrics_collector.labels, **customer_labels}
+            if customer_labels
+            else self.metrics_collector.labels
+        )
         if (
             state.first_token_time == 0.0
             and self.disaggregation_mode != DisaggregationMode.PREFILL
@@ -1623,7 +1632,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             state.first_token_time = state.last_time = time.time()
             state.last_completion_tokens = completion_tokens
             self.metrics_collector.observe_time_to_first_token(
-                state.first_token_time - state.created_time
+                labels, state.first_token_time - state.created_time
             )
         else:
             num_new_tokens = completion_tokens - state.last_completion_tokens
@@ -1631,6 +1640,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 new_time = time.time()
                 interval = new_time - state.last_time
                 self.metrics_collector.observe_inter_token_latency(
+                    labels,
                     interval,
                     num_new_tokens,
                 )
@@ -1645,6 +1655,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     or state.obj.sampling_params.get("structural_tag", None)
                 )
                 self.metrics_collector.observe_one_finished_request(
+                    labels,
                     recv_obj.prompt_tokens[i],
                     completion_tokens,
                     recv_obj.cached_tokens[i],
diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py
index adc70e8cd..551d51184 100644
--- a/python/sglang/srt/metrics/collector.py
+++ b/python/sglang/srt/metrics/collector.py
@@ -12,7 +12,6 @@
 # limitations under the License.
 # ==============================================================================
 """Utilities for Prometheus Metrics Collection."""
-
 import time
 from dataclasses import dataclass, field
 from enum import Enum
@@ -812,36 +811,38 @@ class TokenizerMetricsCollector:
             buckets=bucket_time_to_first_token,
         )
 
-    def _log_histogram(self, histogram, data: Union[int, float]) -> None:
-        histogram.labels(**self.labels).observe(data)
-
     def observe_one_finished_request(
         self,
+        labels: Dict[str, str],
         prompt_tokens: int,
         generation_tokens: int,
         cached_tokens: int,
         e2e_latency: float,
         has_grammar: bool,
     ):
-        self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
-        self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
+        self.prompt_tokens_total.labels(**labels).inc(prompt_tokens)
+        self.generation_tokens_total.labels(**labels).inc(generation_tokens)
         if cached_tokens > 0:
-            self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
-        self.num_requests_total.labels(**self.labels).inc(1)
+            self.cached_tokens_total.labels(**labels).inc(cached_tokens)
+        self.num_requests_total.labels(**labels).inc(1)
         if has_grammar:
-            self.num_so_requests_total.labels(**self.labels).inc(1)
-        self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
+            self.num_so_requests_total.labels(**labels).inc(1)
+        self.histogram_e2e_request_latency.labels(**labels).observe(float(e2e_latency))
         if self.collect_tokens_histogram:
-            self._log_histogram(self.prompt_tokens_histogram, prompt_tokens)
-            self._log_histogram(self.generation_tokens_histogram, generation_tokens)
+            self.prompt_tokens_histogram.labels(**labels).observe(float(prompt_tokens))
+            self.generation_tokens_histogram.labels(**labels).observe(
+                float(generation_tokens)
+            )
 
-    def observe_time_to_first_token(self, value: float, label: str = ""):
-        if label == "batch":
-            self.histogram_time_to_first_token_offline_batch.labels(
-                **self.labels
-            ).observe(value)
+    def observe_time_to_first_token(
+        self, labels: Dict[str, str], value: float, type: str = ""
+    ):
+        if type == "batch":
+            self.histogram_time_to_first_token_offline_batch.labels(**labels).observe(
+                value
+            )
         else:
-            self.histogram_time_to_first_token.labels(**self.labels).observe(value)
+            self.histogram_time_to_first_token.labels(**labels).observe(value)
 
     def check_time_to_first_token_straggler(self, value: float) -> bool:
         his = self.histogram_time_to_first_token.labels(**self.labels)
@@ -856,12 +857,14 @@ class TokenizerMetricsCollector:
                 return value >= his._upper_bounds[i]
         return False
 
-    def observe_inter_token_latency(self, internval: float, num_new_tokens: int):
+    def observe_inter_token_latency(
+        self, labels: Dict[str, str], internval: float, num_new_tokens: int
+    ):
         adjusted_interval = internval / num_new_tokens
 
         # A faster version of the Histogram::observe which observes multiple values at the same time.
         # reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639
-        his = self.histogram_inter_token_latency_seconds.labels(**self.labels)
+        his = self.histogram_inter_token_latency_seconds.labels(**labels)
         his._sum.inc(internval)
 
         for i, bound in enumerate(his._upper_bounds):
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index a556febaa..b846be5a1 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -205,6 +205,8 @@ class ServerArgs:
     show_time_cost: bool = False
     enable_metrics: bool = False
     enable_metrics_for_all_schedulers: bool = False
+    tokenizer_metrics_custom_labels_header: str = "x-customer-labels"
+    tokenizer_metrics_allowed_customer_labels: Optional[List[str]] = None
     bucket_time_to_first_token: Optional[List[float]] = None
     bucket_inter_token_latency: Optional[List[float]] = None
     bucket_e2e_request_latency: Optional[List[float]] = None
@@ -911,6 +913,14 @@ class ServerArgs:
                 "and cannot be used at the same time. Please use only one of them."
             )
 
+        if (
+            not self.tokenizer_metrics_custom_labels_header
+            and self.tokenizer_metrics_allowed_customer_labels
+        ):
+            raise ValueError(
+                "Please set --tokenizer-metrics-custom-labels-header when setting --tokenizer-metrics-allowed-customer-labels."
+            )
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and tokenizer
@@ -1324,6 +1334,21 @@ class ServerArgs:
             "to record request metrics separately. This is especially useful when dp_attention is enabled, as "
             "otherwise all metrics appear to come from TP 0.",
         )
+        parser.add_argument(
+            "--tokenizer-metrics-custom-labels-header",
+            type=str,
+            default=ServerArgs.tokenizer_metrics_custom_labels_header,
+            help="Specify the HTTP header used to pass customer labels for tokenizer metrics.",
+        )
+        parser.add_argument(
+            "--tokenizer-metrics-allowed-customer-labels",
+            type=str,
+            nargs="+",
+            default=ServerArgs.tokenizer_metrics_allowed_customer_labels,
+            help="The customer labels allowed for tokenizer metrics. The labels are passed as a JSON dict in the "
+            "HTTP header specified by '--tokenizer-metrics-custom-labels-header', e.g., {'label1': 'value1', "
+            "'label2': 'value2'} is allowed if '--tokenizer-metrics-allowed-customer-labels label1 label2' is set.",
+        )
         parser.add_argument(
             "--bucket-time-to-first-token",
             type=float,