[OAI refactor] Add rerank and score serving (#7399)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
@@ -534,6 +534,22 @@ class ScoringResponse(BaseModel):
|
||||
object: str = "scoring"
|
||||
|
||||
|
||||
class V1RerankReqInput(BaseModel):
    """Request payload for the /v1/rerank endpoint."""

    # The query string each document is scored against.
    query: str
    # Candidate documents to be scored and reordered by relevance.
    documents: List[str]
|
||||
|
||||
|
||||
class RerankResponse(BaseModel):
    """A single reranked document together with its relevance score."""

    # Relevance score of `document` with respect to the query.
    score: float
    # The original document text.
    document: str
    # Position of the document in the request's `documents` list.
    index: int
    # Optional backend metadata for this item — presumably per-request
    # info from the inference engine; verify against the serving layer.
    meta_info: Optional[dict] = None
|
||||
|
||||
|
||||
# Union of every request type the OpenAI-compatible serving layer accepts.
OpenAIServingRequest = Union[
    ChatCompletionRequest,
    CompletionRequest,
    EmbeddingRequest,
    ScoringRequest,
    V1RerankReqInput,
]
|
||||
|
||||
98
python/sglang/srt/entrypoints/openai/serving_rerank.py
Normal file
98
python/sglang/srt/entrypoints/openai/serving_rerank.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
RerankResponse,
|
||||
V1RerankReqInput,
|
||||
)
|
||||
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
||||
from sglang.srt.managers.io_struct import EmbeddingReqInput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingRerank(OpenAIServingBase):
    """Handler for rerank requests.

    Scores each (query, document) pair via the tokenizer manager's
    cross-encoder embedding path and returns the documents sorted by
    relevance, highest score first.
    """

    def _request_id_prefix(self) -> str:
        """Prefix used when generating request IDs for rerank calls."""
        return "rerank-"

    def _validate_request(self, request: V1RerankReqInput) -> Optional[str]:
        """Validate rerank request format and content.

        Returns:
            An error message string if the request is invalid, else None.
        """
        if not request.query:
            return "Query cannot be empty"

        # A whitespace-only query is truthy but still unusable.
        if isinstance(request.query, str):
            if not request.query.strip():
                return "Query cannot be empty or whitespace only"

        if not request.documents:
            return "Documents cannot be empty"

        for doc in request.documents:
            if not doc:
                return "Each document must be a non-empty string"
            if isinstance(doc, str) and not doc.strip():
                return "Each document cannot be empty or whitespace only"

        return None

    def _convert_to_internal_request(
        self, request: V1RerankReqInput
    ) -> tuple[EmbeddingReqInput, V1RerankReqInput]:
        """Convert an OpenAI rerank request to internal embedding format.

        Each document is paired with the query so the cross-encoder can
        score the (query, document) pair.
        """
        pairs = [[request.query, doc] for doc in request.documents]

        adapted_request = EmbeddingReqInput(
            text=pairs,
            is_cross_encoder_request=True,
        )

        return adapted_request, request

    async def _handle_non_streaming_request(
        self,
        adapted_request: EmbeddingReqInput,
        request: V1RerankReqInput,
        raw_request: Request,
    ) -> Union[List[RerankResponse], ErrorResponse]:
        """Handle the rerank request.

        Returns:
            The list of reranked documents, or an ErrorResponse if the
            underlying generation request is rejected with a ValueError.

        Note: the return annotation is List[RerankResponse] (not a bare
        RerankResponse) — `_build_rerank_response` produces one item per
        input document.
        """
        try:
            # Rerank is non-streaming: consume the single result the
            # generator yields.
            ret = await self.tokenizer_manager.generate_request(
                adapted_request, raw_request
            ).__anext__()
        except ValueError as e:
            return self.create_error_response(str(e))

        # Normalize a single-item result to a list for uniform handling.
        if not isinstance(ret, list):
            ret = [ret]

        return self._build_rerank_response(ret, request)

    def _build_rerank_response(
        self, ret: List[Dict[str, Any]], request: V1RerankReqInput
    ) -> List[RerankResponse]:
        """Build the rerank response from generation results.

        `index` always refers to the document's position in the original
        request, so it stays meaningful after sorting.
        """
        response = [
            RerankResponse(
                # The cross-encoder relevance score is delivered under the
                # "embedding" key of each result item.
                score=ret_item["embedding"],
                document=request.documents[idx],
                index=idx,
                meta_info=ret_item["meta_info"],
            )
            for idx, ret_item in enumerate(ret)
        ]

        # Sort by score in descending order (highest relevance first)
        response.sort(key=lambda x: x.score, reverse=True)

        return response
|
||||
58
python/sglang/srt/entrypoints/openai/serving_score.py
Normal file
58
python/sglang/srt/entrypoints/openai/serving_score.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
ScoringRequest,
|
||||
ScoringResponse,
|
||||
)
|
||||
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingScore(OpenAIServingBase):
    """Handler for scoring requests."""

    def _request_id_prefix(self) -> str:
        """Prefix used when generating scoring request IDs."""
        return "score-"

    def _convert_to_internal_request(
        self,
        request: ScoringRequest,
    ) -> tuple[ScoringRequest, ScoringRequest]:
        """Pass the scoring request through unchanged.

        The tokenizer_manager exposes a specialized score_request method,
        so no GenerateReqInput adaptation is required here.
        """
        return request, request

    async def _handle_non_streaming_request(
        self,
        adapted_request: ScoringRequest,
        request: ScoringRequest,
        raw_request: Request,
    ) -> Union[ScoringResponse, ErrorResponse]:
        """Handle the scoring request."""
        try:
            # Delegate straight to the tokenizer_manager's scoring path,
            # then package the raw scores (no usage info is reported).
            score_values = await self.tokenizer_manager.score_request(
                query=request.query,
                items=request.items,
                label_token_ids=request.label_token_ids,
                apply_softmax=request.apply_softmax,
                item_first=request.item_first,
                request=raw_request,
            )
            return ScoringResponse(scores=score_values, model=request.model)
        except ValueError as e:
            return self.create_error_response(str(e))
|
||||
Reference in New Issue
Block a user