from __future__ import annotations import json import logging import uuid from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Optional, Union import orjson from fastapi import HTTPException, Request from fastapi.responses import ORJSONResponse, StreamingResponse from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.server_args import ServerArgs if TYPE_CHECKING: from sglang.srt.managers.tokenizer_manager import TokenizerManager logger = logging.getLogger(__name__) # Base class for specific endpoint handlers class OpenAIServingBase(ABC): """Abstract base class for OpenAI endpoint handlers""" def __init__(self, tokenizer_manager: TokenizerManager): self.tokenizer_manager = tokenizer_manager self.allowed_custom_labels = ( set( self.tokenizer_manager.server_args.tokenizer_metrics_allowed_custom_labels ) if isinstance(self.tokenizer_manager.server_args, ServerArgs) and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_custom_labels else None ) async def handle_request( self, request: OpenAIServingRequest, raw_request: Request ) -> Union[Any, StreamingResponse, ErrorResponse]: """Handle the specific request type with common pattern""" try: # Validate request error_msg = self._validate_request(request) if error_msg: return self.create_error_response(error_msg) # Convert to internal format adapted_request, processed_request = self._convert_to_internal_request( request, raw_request ) # Note(Xinyuan): raw_request below is only used for detecting the connection of the client if hasattr(request, "stream") and request.stream: return await self._handle_streaming_request( adapted_request, processed_request, raw_request ) else: return await self._handle_non_streaming_request( adapted_request, processed_request, raw_request ) except HTTPException as e: return self.create_error_response( message=e.detail, err_type=str(e.status_code), status_code=e.status_code ) except ValueError as e: return self.create_error_response( message=str(e), err_type="BadRequest", status_code=400, ) except Exception as e: logger.exception(f"Error in request: {e}") return self.create_error_response( message=f"Internal server error: {str(e)}", err_type="InternalServerError", status_code=500, ) @abstractmethod def _request_id_prefix(self) -> str: """Generate request ID based on request type""" pass def _generate_request_id_base(self, request: OpenAIServingRequest) -> Optional[str]: """Generate request ID based on request type""" return None # TODO(chang): the rid is used in io_strcut check and often violates `The rid should be a list` AssertionError # Temporarily return None in this function until the rid logic is clear. if rid := getattr(request, "rid", None): return rid return f"{self._request_id_prefix()}{uuid.uuid4().hex}" def _compute_extra_key(self, request: OpenAIServingRequest) -> Optional[str]: """Compute the final extra_key by concatenating cache_salt and extra_key if both are provided.""" parts = [] for key in ["cache_salt", "extra_key"]: value = getattr(request, key, None) if value: if not isinstance(value, str): raise TypeError( f"Value of {key} must be a string, but got {type(value).__name__}" ) parts.append(value) return "".join(parts) if parts else None @abstractmethod def _convert_to_internal_request( self, request: OpenAIServingRequest, raw_request: Request = None, ) -> tuple[GenerateReqInput, OpenAIServingRequest]: """Convert OpenAI request to internal format""" pass async def _handle_streaming_request( self, adapted_request: GenerateReqInput, request: OpenAIServingRequest, raw_request: Request, ) -> Union[StreamingResponse, ErrorResponse, ORJSONResponse]: """Handle streaming request Override this method in child classes that support streaming requests. """ return self.create_error_response( message=f"{self.__class__.__name__} does not support streaming requests", err_type="NotImplementedError", status_code=501, ) async def _handle_non_streaming_request( self, adapted_request: GenerateReqInput, request: OpenAIServingRequest, raw_request: Request, ) -> Union[Any, ErrorResponse, ORJSONResponse]: """Handle non-streaming request Override this method in child classes that support non-streaming requests. """ return self.create_error_response( message=f"{self.__class__.__name__} does not support non-streaming requests", err_type="NotImplementedError", status_code=501, ) def _validate_request(self, _: OpenAIServingRequest) -> Optional[str]: """Validate request""" pass def create_error_response( self, message: str, err_type: str = "BadRequestError", status_code: int = 400, param: Optional[str] = None, ) -> ORJSONResponse: """Create an error response""" # TODO: remove fastapi dependency in openai and move response handling to the entrypoint error = ErrorResponse( object="error", message=message, type=err_type, param=param, code=status_code, ) return ORJSONResponse(content=error.model_dump(), status_code=status_code) def create_streaming_error_response( self, message: str, err_type: str = "BadRequestError", status_code: int = 400, ) -> str: """Create a streaming error response""" error = ErrorResponse( object="error", message=message, type=err_type, param=None, code=status_code, ) return json.dumps({"error": error.model_dump()}) def extract_custom_labels(self, raw_request): if ( not self.allowed_custom_labels or not self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header ): return None custom_labels = None header = ( self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header ) try: raw_labels = ( orjson.loads(raw_request.headers.get(header)) if raw_request and raw_request.headers.get(header) else None ) except json.JSONDecodeError as e: logger.exception(f"Error in request: {e}") raw_labels = None if isinstance(raw_labels, dict): custom_labels = { label: value for label, value in raw_labels.items() if label in self.allowed_custom_labels } return custom_labels