From cfb2fb5afc6f9361040f8206a7180e41824aeb1d Mon Sep 17 00:00:00 2001
From: woodx <124784234+woodx9@users.noreply.github.com>
Date: Sat, 21 Jun 2025 05:51:10 +0800
Subject: [PATCH] [OAI refactor] Add rerank and score serving (#7399)

Co-authored-by: Chang Su
---
 .../sglang/srt/entrypoints/openai/protocol.py | 18 +++-
 .../srt/entrypoints/openai/serving_rerank.py  | 98 +++++++++++++++++++
 .../srt/entrypoints/openai/serving_score.py   | 58 +++++++++++
 3 files changed, 173 insertions(+), 1 deletion(-)
 create mode 100644 python/sglang/srt/entrypoints/openai/serving_rerank.py
 create mode 100644 python/sglang/srt/entrypoints/openai/serving_score.py

diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py
index c7423ed1b..017097e11 100644
--- a/python/sglang/srt/entrypoints/openai/protocol.py
+++ b/python/sglang/srt/entrypoints/openai/protocol.py
@@ -534,6 +534,22 @@ class ScoringResponse(BaseModel):
     object: str = "scoring"
 
 
+class V1RerankReqInput(BaseModel):
+    query: str
+    documents: List[str]
+
+
+class RerankResponse(BaseModel):
+    score: float
+    document: str
+    index: int
+    meta_info: Optional[dict] = None
+
+
 OpenAIServingRequest = Union[
-    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, ScoringRequest
+    ChatCompletionRequest,
+    CompletionRequest,
+    EmbeddingRequest,
+    ScoringRequest,
+    V1RerankReqInput,
 ]
diff --git a/python/sglang/srt/entrypoints/openai/serving_rerank.py b/python/sglang/srt/entrypoints/openai/serving_rerank.py
new file mode 100644
index 000000000..50be5c3cc
--- /dev/null
+++ b/python/sglang/srt/entrypoints/openai/serving_rerank.py
@@ -0,0 +1,98 @@
+import logging
+from typing import Any, Dict, List, Optional, Union
+
+from fastapi import Request
+
+from sglang.srt.entrypoints.openai.protocol import (
+    ErrorResponse,
+    RerankResponse,
+    V1RerankReqInput,
+)
+from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
+from sglang.srt.managers.io_struct import EmbeddingReqInput
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIServingRerank(OpenAIServingBase):
+    """Handler for rerank requests"""
+
+    def _request_id_prefix(self) -> str:
+        return "rerank-"
+
+    def _validate_request(self, request: V1RerankReqInput) -> Optional[str]:
+        """Validate rerank request format and content"""
+        if not request.query:
+            return "Query cannot be empty"
+
+        if isinstance(request.query, str):
+            if not request.query.strip():
+                return "Query cannot be empty or whitespace only"
+
+        if not request.documents:
+            return "Documents cannot be empty"
+
+        for doc in request.documents:
+            if not doc:
+                return "Each document must be a non-empty string"
+            if isinstance(doc, str) and not doc.strip():
+                return "Each document cannot be empty or whitespace only"
+
+        return None
+
+    def _convert_to_internal_request(
+        self, request: V1RerankReqInput
+    ) -> tuple[EmbeddingReqInput, V1RerankReqInput]:
+        """Convert OpenAI rerank request to internal embedding format"""
+        # Create pairs of [query, document] for each document
+        pairs = []
+        for doc in request.documents:
+            pairs.append([request.query, doc])
+
+        adapted_request = EmbeddingReqInput(
+            text=pairs,
+            is_cross_encoder_request=True,
+        )
+
+        return adapted_request, request
+
+    async def _handle_non_streaming_request(
+        self,
+        adapted_request: EmbeddingReqInput,
+        request: V1RerankReqInput,
+        raw_request: Request,
+    ) -> Union[RerankResponse, ErrorResponse]:
+        """Handle the rerank request"""
+        try:
+            ret = await self.tokenizer_manager.generate_request(
+                adapted_request, raw_request
+            ).__anext__()
+
+        except ValueError as e:
+            return self.create_error_response(str(e))
+
+        if not isinstance(ret, list):
+            ret = [ret]
+
+        response = self._build_rerank_response(ret, request)
+        return response
+
+    def _build_rerank_response(
+        self, ret: List[Dict[str, Any]], request: V1RerankReqInput
+    ) -> List[RerankResponse]:
+        """Build the rerank response from generation results"""
+        response = []
+        for idx, ret_item in enumerate(ret):
+            response.append(
+                RerankResponse(
+                    score=ret_item["embedding"],
+                    document=request.documents[idx],
+                    index=idx,
+                    meta_info=ret_item["meta_info"],
+                )
+            )
+
+        # Sort by score in descending order (highest relevance first)
+        response.sort(key=lambda x: x.score, reverse=True)
+
+        return response
diff --git a/python/sglang/srt/entrypoints/openai/serving_score.py b/python/sglang/srt/entrypoints/openai/serving_score.py
new file mode 100644
index 000000000..af73a866a
--- /dev/null
+++ b/python/sglang/srt/entrypoints/openai/serving_score.py
@@ -0,0 +1,58 @@
+import logging
+from typing import Any, Dict, List, Optional, Union
+
+from fastapi import Request
+
+from sglang.srt.entrypoints.openai.protocol import (
+    ErrorResponse,
+    ScoringRequest,
+    ScoringResponse,
+)
+from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIServingScore(OpenAIServingBase):
+    """Handler for scoring requests"""
+
+    def _request_id_prefix(self) -> str:
+        return "score-"
+
+    def _convert_to_internal_request(
+        self,
+        request: ScoringRequest,
+    ) -> tuple[ScoringRequest, ScoringRequest]:
+        """Convert OpenAI scoring request to internal format"""
+        # For scoring, we pass the request directly as the tokenizer_manager
+        # has a specialized score_request method that doesn't use GenerateReqInput
+
+        return request, request
+
+    async def _handle_non_streaming_request(
+        self,
+        adapted_request: ScoringRequest,
+        request: ScoringRequest,
+        raw_request: Request,
+    ) -> Union[ScoringResponse, ErrorResponse]:
+        """Handle the scoring request"""
+        try:
+            # Use tokenizer_manager's score_request method directly
+            scores = await self.tokenizer_manager.score_request(
+                query=request.query,
+                items=request.items,
+                label_token_ids=request.label_token_ids,
+                apply_softmax=request.apply_softmax,
+                item_first=request.item_first,
+                request=raw_request,
+            )
+
+            # Create response with just the scores, without usage info
+            response = ScoringResponse(
+                scores=scores,
+                model=request.model,
+            )
+            return response
+
+        except ValueError as e:
+            return self.create_error_response(str(e))