metrics: support customer labels specified in request header (#10143)
This commit is contained in:
@@ -229,6 +229,9 @@ class CompletionRequest(BaseModel):
|
||||
# For request id
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
|
||||
# For customer metric labels
|
||||
customer_labels: Optional[Dict[str, str]] = None
|
||||
|
||||
@field_validator("max_tokens")
|
||||
@classmethod
|
||||
def validate_max_tokens_positive(cls, v):
|
||||
|
||||
@@ -11,6 +11,7 @@ from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
@@ -24,6 +25,14 @@ class OpenAIServingBase(ABC):
|
||||
|
||||
def __init__(self, tokenizer_manager: TokenizerManager):
|
||||
self.tokenizer_manager = tokenizer_manager
|
||||
self.allowed_custom_labels = (
|
||||
set(
|
||||
self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
|
||||
)
|
||||
if isinstance(self.tokenizer_manager.server_args, ServerArgs)
|
||||
and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
|
||||
else None
|
||||
)
|
||||
|
||||
async def handle_request(
|
||||
self, request: OpenAIServingRequest, raw_request: Request
|
||||
@@ -37,7 +46,7 @@ class OpenAIServingBase(ABC):
|
||||
|
||||
# Convert to internal format
|
||||
adapted_request, processed_request = self._convert_to_internal_request(
|
||||
request
|
||||
request, raw_request
|
||||
)
|
||||
|
||||
# Note(Xinyuan): raw_request below is only used for detecting the connection of the client
|
||||
@@ -81,6 +90,7 @@ class OpenAIServingBase(ABC):
|
||||
def _convert_to_internal_request(
|
||||
self,
|
||||
request: OpenAIServingRequest,
|
||||
raw_request: Request = None,
|
||||
) -> tuple[GenerateReqInput, OpenAIServingRequest]:
|
||||
"""Convert OpenAI request to internal format"""
|
||||
pass
|
||||
@@ -154,3 +164,32 @@ class OpenAIServingBase(ABC):
|
||||
code=status_code,
|
||||
)
|
||||
return json.dumps({"error": error.model_dump()})
|
||||
|
||||
def extract_customer_labels(self, raw_request):
|
||||
if (
|
||||
not self.allowed_custom_labels
|
||||
or not self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
|
||||
):
|
||||
return None
|
||||
|
||||
customer_labels = None
|
||||
header = (
|
||||
self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
|
||||
)
|
||||
try:
|
||||
raw_labels = (
|
||||
json.loads(raw_request.headers.get(header))
|
||||
if raw_request and raw_request.headers.get(header)
|
||||
else None
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.exception(f"Error in request: {e}")
|
||||
raw_labels = None
|
||||
|
||||
if isinstance(raw_labels, dict):
|
||||
customer_labels = {
|
||||
label: value
|
||||
for label, value in raw_labels.items()
|
||||
if label in self.allowed_custom_labels
|
||||
}
|
||||
return customer_labels
|
||||
|
||||
@@ -96,6 +96,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
def _convert_to_internal_request(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
raw_request: Request = None,
|
||||
) -> tuple[GenerateReqInput, ChatCompletionRequest]:
|
||||
reasoning_effort = (
|
||||
request.chat_template_kwargs.pop("reasoning_effort", None)
|
||||
@@ -127,6 +128,9 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
else:
|
||||
prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
|
||||
|
||||
# Extract customer labels from raw request headers
|
||||
customer_labels = self.extract_customer_labels(raw_request)
|
||||
|
||||
adapted_request = GenerateReqInput(
|
||||
**prompt_kwargs,
|
||||
image_data=processed_messages.image_data,
|
||||
@@ -145,6 +149,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
bootstrap_room=request.bootstrap_room,
|
||||
return_hidden_states=request.return_hidden_states,
|
||||
rid=request.rid,
|
||||
customer_labels=customer_labels,
|
||||
)
|
||||
|
||||
return adapted_request, request
|
||||
|
||||
@@ -59,6 +59,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
def _convert_to_internal_request(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
raw_request: Request = None,
|
||||
) -> tuple[GenerateReqInput, CompletionRequest]:
|
||||
"""Convert OpenAI completion request to internal format"""
|
||||
# NOTE: with openai API, the prompt's logprobs are always not computed
|
||||
@@ -89,6 +90,9 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
else:
|
||||
prompt_kwargs = {"input_ids": prompt}
|
||||
|
||||
# Extract customer labels from raw request headers
|
||||
customer_labels = self.extract_customer_labels(raw_request)
|
||||
|
||||
adapted_request = GenerateReqInput(
|
||||
**prompt_kwargs,
|
||||
sampling_params=sampling_params,
|
||||
@@ -103,6 +107,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
bootstrap_room=request.bootstrap_room,
|
||||
return_hidden_states=request.return_hidden_states,
|
||||
rid=request.rid,
|
||||
customer_labels=customer_labels,
|
||||
)
|
||||
|
||||
return adapted_request, request
|
||||
|
||||
@@ -74,6 +74,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
|
||||
def _convert_to_internal_request(
|
||||
self,
|
||||
request: EmbeddingRequest,
|
||||
raw_request: Request = None,
|
||||
) -> tuple[EmbeddingReqInput, EmbeddingRequest]:
|
||||
"""Convert OpenAI embedding request to internal format"""
|
||||
prompt = request.input
|
||||
|
||||
@@ -45,7 +45,9 @@ class OpenAIServingRerank(OpenAIServingBase):
|
||||
return None
|
||||
|
||||
def _convert_to_internal_request(
|
||||
self, request: V1RerankReqInput
|
||||
self,
|
||||
request: V1RerankReqInput,
|
||||
raw_request: Request = None,
|
||||
) -> tuple[EmbeddingReqInput, V1RerankReqInput]:
|
||||
"""Convert OpenAI rerank request to internal embedding format"""
|
||||
# Create pairs of [query, document] for each document
|
||||
|
||||
@@ -25,6 +25,7 @@ class OpenAIServingScore(OpenAIServingBase):
|
||||
def _convert_to_internal_request(
|
||||
self,
|
||||
request: ScoringRequest,
|
||||
raw_request: Request = None,
|
||||
) -> tuple[ScoringRequest, ScoringRequest]:
|
||||
"""Convert OpenAI scoring request to internal format"""
|
||||
# For scoring, we pass the request directly as the tokenizer_manager
|
||||
|
||||
Reference in New Issue
Block a user