[OAI refactor] Add rerank and score serving (#7399)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
@@ -534,6 +534,22 @@ class ScoringResponse(BaseModel):
|
|||||||
object: str = "scoring"
|
object: str = "scoring"
|
||||||
|
|
||||||
|
|
||||||
|
class V1RerankReqInput(BaseModel):
    # Request body for reranking: one query plus the candidate documents.
    query: str  # text every document is compared against
    documents: List[str]  # candidate documents, in caller-supplied order
|
||||||
|
|
||||||
|
|
||||||
|
class RerankResponse(BaseModel):
    # A single scored document produced for a rerank request.
    score: float  # relevance score for this document
    document: str  # the document text that was scored
    index: int  # position of the document in the request's `documents` list
    meta_info: Optional[dict] = None  # optional metadata attached to the result
|
||||||
|
|
||||||
|
|
||||||
# Type alias covering the chat, completion, embedding, scoring and rerank
# request models.
OpenAIServingRequest = Union[
    ChatCompletionRequest,
    CompletionRequest,
    EmbeddingRequest,
    ScoringRequest,
    V1RerankReqInput,
]
|
||||||
|
|||||||
98
python/sglang/srt/entrypoints/openai/serving_rerank.py
Normal file
98
python/sglang/srt/entrypoints/openai/serving_rerank.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.openai.protocol import (
|
||||||
|
ErrorResponse,
|
||||||
|
RerankResponse,
|
||||||
|
V1RerankReqInput,
|
||||||
|
)
|
||||||
|
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
||||||
|
from sglang.srt.managers.io_struct import EmbeddingReqInput
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIServingRerank(OpenAIServingBase):
    """Handler for rerank requests"""

    def _request_id_prefix(self) -> str:
        """Return the request-id prefix for this handler."""
        return "rerank-"

    def _validate_request(self, request: V1RerankReqInput) -> Optional[str]:
        """Validate rerank request format and content.

        Args:
            request: The incoming rerank request.

        Returns:
            A message describing the first problem found, or ``None`` when the
            query and every document are non-empty and not whitespace-only.
        """
        if not request.query:
            return "Query cannot be empty"

        if isinstance(request.query, str) and not request.query.strip():
            return "Query cannot be empty or whitespace only"

        if not request.documents:
            return "Documents cannot be empty"

        for doc in request.documents:
            if not doc:
                return "Each document must be a non-empty string"
            if isinstance(doc, str) and not doc.strip():
                return "Each document cannot be empty or whitespace only"

        return None

    def _convert_to_internal_request(
        self, request: V1RerankReqInput
    ) -> tuple[EmbeddingReqInput, V1RerankReqInput]:
        """Convert OpenAI rerank request to internal embedding format.

        Args:
            request: The validated rerank request.

        Returns:
            A tuple of the cross-encoder ``EmbeddingReqInput`` built from the
            request and the original request itself.
        """
        # One [query, document] pair per document, kept in request order:
        # _build_rerank_response maps results back to documents by position.
        pairs = [[request.query, doc] for doc in request.documents]

        adapted_request = EmbeddingReqInput(
            text=pairs,
            is_cross_encoder_request=True,
        )

        return adapted_request, request

    async def _handle_non_streaming_request(
        self,
        adapted_request: EmbeddingReqInput,
        request: V1RerankReqInput,
        raw_request: Request,
    ) -> Union[List[RerankResponse], ErrorResponse]:
        """Handle the rerank request.

        Args:
            adapted_request: Internal request produced by
                ``_convert_to_internal_request``.
            request: The original rerank request (supplies the document texts).
            raw_request: The incoming FastAPI request, forwarded to the
                tokenizer manager.

        Returns:
            The list of ``RerankResponse`` entries sorted by descending score,
            or an error response when the tokenizer manager raises
            ``ValueError``.
        """
        try:
            # Only the first item yielded by the generator is consumed.
            ret = await self.tokenizer_manager.generate_request(
                adapted_request, raw_request
            ).__anext__()
        except ValueError as e:
            return self.create_error_response(str(e))

        # Normalize a single result to a one-element list.
        if not isinstance(ret, list):
            ret = [ret]

        return self._build_rerank_response(ret, request)

    def _build_rerank_response(
        self, ret: List[Dict[str, Any]], request: V1RerankReqInput
    ) -> List[RerankResponse]:
        """Build the rerank response from generation results.

        Args:
            ret: Result dicts; each must provide ``"embedding"`` (used as the
                score) and ``"meta_info"``. The i-th result is matched with
                ``request.documents[i]``.
            request: The original rerank request.

        Returns:
            One ``RerankResponse`` per result, sorted by score in descending
            order. ``index`` keeps the document's original position.
        """
        response = [
            RerankResponse(
                score=ret_item["embedding"],
                document=request.documents[idx],
                index=idx,
                meta_info=ret_item["meta_info"],
            )
            for idx, ret_item in enumerate(ret)
        ]

        # Sort by score in descending order (highest relevance first)
        response.sort(key=lambda x: x.score, reverse=True)

        return response
|
||||||
58
python/sglang/srt/entrypoints/openai/serving_score.py
Normal file
58
python/sglang/srt/entrypoints/openai/serving_score.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.openai.protocol import (
|
||||||
|
ErrorResponse,
|
||||||
|
ScoringRequest,
|
||||||
|
ScoringResponse,
|
||||||
|
)
|
||||||
|
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIServingScore(OpenAIServingBase):
    """Handler for scoring requests"""

    def _request_id_prefix(self) -> str:
        """Return the request-id prefix for this handler."""
        return "score-"

    def _convert_to_internal_request(
        self,
        request: ScoringRequest,
    ) -> tuple[ScoringRequest, ScoringRequest]:
        """Return the scoring request unchanged in both tuple positions.

        No adaptation is performed: the request is later handed to the
        tokenizer manager's ``score_request`` method field by field, so the
        same object serves as both the adapted and the original request.
        """
        return request, request

    async def _handle_non_streaming_request(
        self,
        adapted_request: ScoringRequest,
        request: ScoringRequest,
        raw_request: Request,
    ) -> Union[ScoringResponse, ErrorResponse]:
        """Score the request and wrap the result in a ``ScoringResponse``.

        Any ``ValueError`` raised while scoring or while building the response
        is converted into an error response.
        """
        try:
            computed_scores = await self.tokenizer_manager.score_request(
                query=request.query,
                items=request.items,
                label_token_ids=request.label_token_ids,
                apply_softmax=request.apply_softmax,
                item_first=request.item_first,
                request=raw_request,
            )
            # The response carries only the scores and model name; no usage info.
            return ScoringResponse(scores=computed_scores, model=request.model)
        except ValueError as err:
            return self.create_error_response(str(err))
|
||||||
Reference in New Issue
Block a user