metrics: support customer labels specified in request header (#10143)
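Note: this change lets clients attach per-request metric labels by sending a JSON object in a configurable HTTP header (default "x-customer-labels"). Only labels on the server-side allowlist (--tokenizer-metrics-allowed-customer-labels) are kept; they are merged into the Prometheus labels emitted by the tokenizer metrics collector. A minimal client-side sketch (endpoint, port, and model name are illustrative, not part of this diff):

    import json
    import requests

    resp = requests.post(
        "http://localhost:30000/v1/chat/completions",  # assumed local sglang server
        headers={"x-customer-labels": json.dumps({"customer": "acme"})},
        json={
            "model": "my-model",  # illustrative model name
            "messages": [{"role": "user", "content": "Hello"}],
        },
    )
    print(resp.status_code)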
@@ -229,6 +229,9 @@ class CompletionRequest(BaseModel):
     # For request id
     rid: Optional[Union[List[str], str]] = None
 
+    # For customer metric labels
+    customer_labels: Optional[Dict[str, str]] = None
+
     @field_validator("max_tokens")
     @classmethod
     def validate_max_tokens_positive(cls, v):
@@ -11,6 +11,7 @@ from fastapi.responses import ORJSONResponse, StreamingResponse
 
 from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
 from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.server_args import ServerArgs
 
 if TYPE_CHECKING:
     from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -24,6 +25,14 @@ class OpenAIServingBase(ABC):
 
     def __init__(self, tokenizer_manager: TokenizerManager):
         self.tokenizer_manager = tokenizer_manager
+        self.allowed_custom_labels = (
+            set(
+                self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
+            )
+            if isinstance(self.tokenizer_manager.server_args, ServerArgs)
+            and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
+            else None
+        )
 
     async def handle_request(
         self, request: OpenAIServingRequest, raw_request: Request
@@ -37,7 +46,7 @@ class OpenAIServingBase(ABC):
 
         # Convert to internal format
         adapted_request, processed_request = self._convert_to_internal_request(
-            request
+            request, raw_request
         )
 
         # Note(Xinyuan): raw_request below is only used for detecting the connection of the client
@@ -81,6 +90,7 @@ class OpenAIServingBase(ABC):
     def _convert_to_internal_request(
         self,
         request: OpenAIServingRequest,
+        raw_request: Request = None,
     ) -> tuple[GenerateReqInput, OpenAIServingRequest]:
         """Convert OpenAI request to internal format"""
         pass
@@ -154,3 +164,32 @@ class OpenAIServingBase(ABC):
             code=status_code,
         )
         return json.dumps({"error": error.model_dump()})
+
+    def extract_customer_labels(self, raw_request):
+        if (
+            not self.allowed_custom_labels
+            or not self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
+        ):
+            return None
+
+        customer_labels = None
+        header = (
+            self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
+        )
+        try:
+            raw_labels = (
+                json.loads(raw_request.headers.get(header))
+                if raw_request and raw_request.headers.get(header)
+                else None
+            )
+        except json.JSONDecodeError as e:
+            logger.exception(f"Error in request: {e}")
+            raw_labels = None
+
+        if isinstance(raw_labels, dict):
+            customer_labels = {
+                label: value
+                for label, value in raw_labels.items()
+                if label in self.allowed_custom_labels
+            }
+        return customer_labels
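Note: extract_customer_labels drops any label not in the allowlist, and malformed JSON in the header is logged and ignored rather than failing the request. A self-contained sketch of the filtering, with assumed values:

    # Assumed allowlist and header payload, mirroring the method above.
    allowed = {"customer", "team"}
    raw_labels = {"customer": "acme", "debug": "1"}  # json.loads of the header value
    customer_labels = {k: v for k, v in raw_labels.items() if k in allowed}
    assert customer_labels == {"customer": "acme"}  # "debug" is silently dropped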
@@ -96,6 +96,7 @@ class OpenAIServingChat(OpenAIServingBase):
     def _convert_to_internal_request(
         self,
         request: ChatCompletionRequest,
+        raw_request: Request = None,
     ) -> tuple[GenerateReqInput, ChatCompletionRequest]:
         reasoning_effort = (
             request.chat_template_kwargs.pop("reasoning_effort", None)
@@ -127,6 +128,9 @@ class OpenAIServingChat(OpenAIServingBase):
         else:
             prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
 
+        # Extract customer labels from raw request headers
+        customer_labels = self.extract_customer_labels(raw_request)
+
         adapted_request = GenerateReqInput(
             **prompt_kwargs,
             image_data=processed_messages.image_data,
@@ -145,6 +149,7 @@ class OpenAIServingChat(OpenAIServingBase):
             bootstrap_room=request.bootstrap_room,
             return_hidden_states=request.return_hidden_states,
             rid=request.rid,
+            customer_labels=customer_labels,
         )
 
         return adapted_request, request
@@ -59,6 +59,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
     def _convert_to_internal_request(
         self,
         request: CompletionRequest,
+        raw_request: Request = None,
     ) -> tuple[GenerateReqInput, CompletionRequest]:
         """Convert OpenAI completion request to internal format"""
         # NOTE: with openai API, the prompt's logprobs are always not computed
@@ -89,6 +90,9 @@ class OpenAIServingCompletion(OpenAIServingBase):
         else:
             prompt_kwargs = {"input_ids": prompt}
 
+        # Extract customer labels from raw request headers
+        customer_labels = self.extract_customer_labels(raw_request)
+
         adapted_request = GenerateReqInput(
             **prompt_kwargs,
             sampling_params=sampling_params,
@@ -103,6 +107,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
             bootstrap_room=request.bootstrap_room,
             return_hidden_states=request.return_hidden_states,
             rid=request.rid,
+            customer_labels=customer_labels,
         )
 
         return adapted_request, request
@@ -74,6 +74,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
     def _convert_to_internal_request(
         self,
         request: EmbeddingRequest,
+        raw_request: Request = None,
     ) -> tuple[EmbeddingReqInput, EmbeddingRequest]:
         """Convert OpenAI embedding request to internal format"""
         prompt = request.input
@@ -45,7 +45,9 @@ class OpenAIServingRerank(OpenAIServingBase):
         return None
 
     def _convert_to_internal_request(
-        self, request: V1RerankReqInput
+        self,
+        request: V1RerankReqInput,
+        raw_request: Request = None,
     ) -> tuple[EmbeddingReqInput, V1RerankReqInput]:
         """Convert OpenAI rerank request to internal embedding format"""
         # Create pairs of [query, document] for each document
@@ -25,6 +25,7 @@ class OpenAIServingScore(OpenAIServingBase):
     def _convert_to_internal_request(
         self,
         request: ScoringRequest,
+        raw_request: Request = None,
     ) -> tuple[ScoringRequest, ScoringRequest]:
         """Convert OpenAI scoring request to internal format"""
         # For scoring, we pass the request directly as the tokenizer_manager
@@ -141,6 +141,9 @@ class GenerateReqInput:
     # Image gen grpc migration
     return_bytes: bool = False
 
+    # For customer metric labels
+    customer_labels: Optional[Dict[str, str]] = None
+
     def contains_mm_input(self) -> bool:
         return (
             has_valid_data(self.image_data)
@@ -306,12 +306,16 @@ class TokenizerManager(TokenizerCommunicatorMixin):
 
         # Metrics
         if self.enable_metrics:
+            labels = {
+                "model_name": self.server_args.served_model_name,
+                # TODO: Add lora name/path in the future,
+            }
+            if server_args.tokenizer_metrics_allowed_customer_labels:
+                for label in server_args.tokenizer_metrics_allowed_customer_labels:
+                    labels[label] = ""
             self.metrics_collector = TokenizerMetricsCollector(
                 server_args=server_args,
-                labels={
-                    "model_name": self.server_args.served_model_name,
-                    # TODO: Add lora name/path in the future,
-                },
+                labels=labels,
                 bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
                 bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
                 bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
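Note: every allowed label is pre-registered with an empty-string value because prometheus_client fixes a metric's label names when the metric is created; requests can later override the values but not the set of names. A sketch of that constraint (metric name is hypothetical):

    from prometheus_client import Counter

    labels = {"model_name": "my-model", "customer": ""}  # "" until a request fills it
    requests_total = Counter(
        "example_requests_total",  # hypothetical metric name
        "Finished requests",
        labelnames=list(labels.keys()),
    )
    requests_total.labels(model_name="my-model", customer="acme").inc()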
@@ -1036,7 +1040,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             return
         req = AbortReq(rid, abort_all)
         self.send_to_scheduler.send_pyobj(req)
-
         if self.enable_metrics:
             self.metrics_collector.observe_one_aborted_request()
 
@@ -1616,6 +1619,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 else 0
             )
 
+            customer_labels = getattr(state.obj, "customer_labels", None)
+            labels = (
+                {**self.metrics_collector.labels, **customer_labels}
+                if customer_labels
+                else self.metrics_collector.labels
+            )
             if (
                 state.first_token_time == 0.0
                 and self.disaggregation_mode != DisaggregationMode.PREFILL
@@ -1623,7 +1632,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 state.first_token_time = state.last_time = time.time()
                 state.last_completion_tokens = completion_tokens
                 self.metrics_collector.observe_time_to_first_token(
-                    state.first_token_time - state.created_time
+                    labels, state.first_token_time - state.created_time
                 )
             else:
                 num_new_tokens = completion_tokens - state.last_completion_tokens
@@ -1631,6 +1640,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     new_time = time.time()
                     interval = new_time - state.last_time
                     self.metrics_collector.observe_inter_token_latency(
+                        labels,
                         interval,
                         num_new_tokens,
                     )
@@ -1645,6 +1655,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 or state.obj.sampling_params.get("structural_tag", None)
             )
             self.metrics_collector.observe_one_finished_request(
+                labels,
                 recv_obj.prompt_tokens[i],
                 completion_tokens,
                 recv_obj.cached_tokens[i],
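Note: per-request labels are merged over the collector's base labels, so a request-supplied value overrides the empty-string default while requests without the header keep the base label set unchanged:

    base_labels = {"model_name": "my-model", "customer": ""}  # collector defaults
    customer_labels = {"customer": "acme"}                    # from the request header
    labels = {**base_labels, **customer_labels} if customer_labels else base_labels
    assert labels == {"model_name": "my-model", "customer": "acme"}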
@@ -12,7 +12,6 @@
 # limitations under the License.
 # ==============================================================================
 """Utilities for Prometheus Metrics Collection."""
 
-import time
 from dataclasses import dataclass, field
 from enum import Enum
@@ -812,36 +811,38 @@ class TokenizerMetricsCollector:
             buckets=bucket_time_to_first_token,
         )
 
-    def _log_histogram(self, histogram, data: Union[int, float]) -> None:
-        histogram.labels(**self.labels).observe(data)
-
     def observe_one_finished_request(
         self,
+        labels: Dict[str, str],
         prompt_tokens: int,
         generation_tokens: int,
         cached_tokens: int,
         e2e_latency: float,
         has_grammar: bool,
     ):
-        self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
-        self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
+        self.prompt_tokens_total.labels(**labels).inc(prompt_tokens)
+        self.generation_tokens_total.labels(**labels).inc(generation_tokens)
         if cached_tokens > 0:
-            self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
-        self.num_requests_total.labels(**self.labels).inc(1)
+            self.cached_tokens_total.labels(**labels).inc(cached_tokens)
+        self.num_requests_total.labels(**labels).inc(1)
        if has_grammar:
-            self.num_so_requests_total.labels(**self.labels).inc(1)
-        self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
+            self.num_so_requests_total.labels(**labels).inc(1)
+        self.histogram_e2e_request_latency.labels(**labels).observe(float(e2e_latency))
         if self.collect_tokens_histogram:
-            self._log_histogram(self.prompt_tokens_histogram, prompt_tokens)
-            self._log_histogram(self.generation_tokens_histogram, generation_tokens)
+            self.prompt_tokens_histogram.labels(**labels).observe(float(prompt_tokens))
+            self.generation_tokens_histogram.labels(**labels).observe(
+                float(generation_tokens)
+            )
 
-    def observe_time_to_first_token(self, value: float, label: str = ""):
-        if label == "batch":
-            self.histogram_time_to_first_token_offline_batch.labels(
-                **self.labels
-            ).observe(value)
+    def observe_time_to_first_token(
+        self, labels: Dict[str, str], value: float, type: str = ""
+    ):
+        if type == "batch":
+            self.histogram_time_to_first_token_offline_batch.labels(**labels).observe(
+                value
+            )
         else:
-            self.histogram_time_to_first_token.labels(**self.labels).observe(value)
+            self.histogram_time_to_first_token.labels(**labels).observe(value)
 
     def check_time_to_first_token_straggler(self, value: float) -> bool:
         his = self.histogram_time_to_first_token.labels(**self.labels)
@@ -856,12 +857,14 @@ class TokenizerMetricsCollector:
                 return value >= his._upper_bounds[i]
         return False
 
-    def observe_inter_token_latency(self, internval: float, num_new_tokens: int):
+    def observe_inter_token_latency(
+        self, labels: Dict[str, str], internval: float, num_new_tokens: int
+    ):
         adjusted_interval = internval / num_new_tokens
 
         # A faster version of the Histogram::observe which observes multiple values at the same time.
         # reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639
-        his = self.histogram_inter_token_latency_seconds.labels(**self.labels)
+        his = self.histogram_inter_token_latency_seconds.labels(**labels)
         his._sum.inc(internval)
 
         for i, bound in enumerate(his._upper_bounds):
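Note: observe_inter_token_latency amortizes one measured interval across num_new_tokens samples by incrementing the matching histogram bucket once, instead of calling observe() per token. A self-contained simulation of that single-pass update (bucket bounds are illustrative; _sum/_buckets/_upper_bounds are private prometheus_client internals, as referenced above):

    interval, num_new_tokens = 0.12, 4
    adjusted_interval = interval / num_new_tokens    # 0.03s per new token
    upper_bounds = [0.01, 0.05, 0.25, float("inf")]  # stand-in for his._upper_bounds
    buckets = [0] * len(upper_bounds)                # stand-in for his._buckets
    for i, bound in enumerate(upper_bounds):
        if adjusted_interval <= bound:
            buckets[i] += num_new_tokens             # one inc instead of four observes
            break
    assert buckets == [0, 4, 0, 0]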
@@ -205,6 +205,8 @@ class ServerArgs:
     show_time_cost: bool = False
     enable_metrics: bool = False
     enable_metrics_for_all_schedulers: bool = False
+    tokenizer_metrics_custom_labels_header: str = "x-customer-labels"
+    tokenizer_metrics_allowed_customer_labels: Optional[List[str]] = None
     bucket_time_to_first_token: Optional[List[float]] = None
     bucket_inter_token_latency: Optional[List[float]] = None
     bucket_e2e_request_latency: Optional[List[float]] = None
@@ -911,6 +913,14 @@ class ServerArgs:
                 "and cannot be used at the same time. Please use only one of them."
             )
 
+        if (
+            not self.tokenizer_metrics_custom_labels_header
+            and self.tokenizer_metrics_allowed_customer_labels
+        ):
+            raise ValueError(
+                "Please set --tokenizer-metrics-custom-labels-header when setting --tokenizer-metrics-allowed-customer-labels."
+            )
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and tokenizer
@@ -1324,6 +1334,21 @@ class ServerArgs:
             "to record request metrics separately. This is especially useful when dp_attention is enabled, as "
             "otherwise all metrics appear to come from TP 0.",
         )
+        parser.add_argument(
+            "--tokenizer-metrics-custom-labels-header",
+            type=str,
+            default=ServerArgs.tokenizer_metrics_custom_labels_header,
+            help="Specify the HTTP header for passing customer labels for tokenizer metrics.",
+        )
+        parser.add_argument(
+            "--tokenizer-metrics-allowed-customer-labels",
+            type=str,
+            nargs="+",
+            default=ServerArgs.tokenizer_metrics_allowed_customer_labels,
+            help="The customer labels allowed for tokenizer metrics. The labels are specified via a dict in the "
+            "'--tokenizer-metrics-custom-labels-header' field in HTTP requests, e.g., {'label1': 'value1', 'label2': "
+            "'value2'} is allowed if '--tokenizer-metrics-allowed-customer-labels label1 label2' is set.",
+        )
         parser.add_argument(
             "--bucket-time-to-first-token",
             type=float,
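Note: a minimal reproduction of how the two new flags parse, with defaults taken from this diff:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tokenizer-metrics-custom-labels-header", type=str,
        default="x-customer-labels",
    )
    parser.add_argument(
        "--tokenizer-metrics-allowed-customer-labels", type=str,
        nargs="+", default=None,
    )
    args = parser.parse_args(
        ["--tokenizer-metrics-allowed-customer-labels", "customer", "team"]
    )
    assert args.tokenizer_metrics_custom_labels_header == "x-customer-labels"
    assert args.tokenizer_metrics_allowed_customer_labels == ["customer", "team"]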