[1/2][feature] support OpenAI-like classification API (#11618)

Author: ybyang
Date: 2025-10-19 10:32:48 +08:00
Committed by: GitHub
Parent: d513ee93ef
Commit: b5e14b2b78
4 changed files with 415 additions and 0 deletions

View File: sglang/srt/entrypoints/http_server.py

@@ -50,6 +50,7 @@ from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationM
from sglang.srt.entrypoints.engine import _launch_subprocesses
from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest,
ClassifyRequest,
CompletionRequest,
DetokenizeRequest,
EmbeddingRequest,
@@ -62,6 +63,7 @@ from sglang.srt.entrypoints.openai.protocol import (
V1RerankReqInput,
)
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
from sglang.srt.entrypoints.openai.serving_classify import OpenAIServingClassify
from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank
@@ -228,6 +230,9 @@ async def lifespan(fast_api_app: FastAPI):
fast_api_app.state.openai_serving_embedding = OpenAIServingEmbedding(
_global_state.tokenizer_manager, _global_state.template_manager
)
fast_api_app.state.openai_serving_classify = OpenAIServingClassify(
_global_state.tokenizer_manager, _global_state.template_manager
)
fast_api_app.state.openai_serving_score = OpenAIServingScore(
_global_state.tokenizer_manager
)
@@ -1082,6 +1087,18 @@ async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request):
)
@app.post(
"/v1/classify",
response_class=ORJSONResponse,
dependencies=[Depends(validate_json_request)],
)
async def openai_v1_classify(request: ClassifyRequest, raw_request: Request):
"""OpenAI-compatible classification endpoint."""
return await raw_request.app.state.openai_serving_classify.handle_request(
request, raw_request
)
@app.post(
"/v1/tokenize",
response_class=ORJSONResponse,
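
For reference, a minimal client call against the new endpoint might look like the following sketch. It assumes a server already running on localhost:30000 (SGLang's default HTTP port) with a sequence classification model loaded; the model name is illustrative:

import requests

# Hypothetical model name; use whatever the server was launched with.
payload = {"model": "my-classifier", "input": "This movie was great!"}
resp = requests.post("http://localhost:30000/v1/classify", json=payload)
print(resp.json())  # {"id": "classify-...", "object": "list", "data": [...], "usage": {...}}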

View File: sglang/srt/entrypoints/openai/protocol.py

@@ -761,6 +761,37 @@ class EmbeddingObject(BaseModel):
object: str = "embedding"
ClassifyInput = Union[str, List[str], List[int]]
class ClassifyRequest(BaseModel):
# OpenAI-compatible classification request
model: str = DEFAULT_MODEL_NAME
input: ClassifyInput
user: Optional[str] = None
# The request id.
rid: Optional[Union[List[str], str]] = None
# Priority for the request
priority: Optional[int] = None
class ClassifyData(BaseModel):
index: int
label: str
probs: List[float]
num_classes: int
class ClassifyResponse(BaseModel):
id: str
object: str = "list"
created: int
model: str
data: List[ClassifyData]
usage: UsageInfo
class EmbeddingResponse(BaseModel):
data: List[EmbeddingObject]
model: str
@@ -844,6 +875,7 @@ OpenAIServingRequest = Union[
ChatCompletionRequest,
CompletionRequest,
EmbeddingRequest,
ClassifyRequest,
ScoringRequest,
V1RerankReqInput,
TokenizeRequest,

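To make the new protocol models concrete, here is a sketch of a request and the matching response shape as validated by the Pydantic classes above (all field values are illustrative):

from sglang.srt.entrypoints.openai.protocol import ClassifyRequest

# Batched string input; rid and priority are optional SGLang extensions.
req = ClassifyRequest(model="my-classifier", input=["great movie", "terrible plot"])

# The ClassifyResponse then carries one ClassifyData entry per input item, e.g.:
# {"index": 0, "label": "LABEL_1", "probs": [0.12, 0.88], "num_classes": 2}
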
View File: sglang/srt/entrypoints/openai/serving_classify.py (new file)

@@ -0,0 +1,204 @@
from __future__ import annotations
import logging
import time
import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import torch
import torch.nn.functional as F
from fastapi import Request
from fastapi.responses import ORJSONResponse
from sglang.srt.entrypoints.openai.protocol import (
ClassifyRequest,
ClassifyResponse,
ErrorResponse,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.managers.io_struct import EmbeddingReqInput
if TYPE_CHECKING:
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
logger = logging.getLogger(__name__)
class OpenAIServingClassify(OpenAIServingBase):
"""Handler for v1/classify requests"""
def __init__(
self,
tokenizer_manager: TokenizerManager,
template_manager: TemplateManager,
):
super().__init__(tokenizer_manager)
self.template_manager = template_manager
self.id2label = self._get_id2label_mapping()
self.model_name = (
self.tokenizer_manager.served_model_name
if self.tokenizer_manager.served_model_name
else self.tokenizer_manager.server_args.model_path
)
if not self.id2label:
raise ValueError(
"No id2label mapping found in the model config; "
"/v1/classify requires a sequence classification model."
)
def _request_id_prefix(self) -> str:
return "classify-"
def _convert_to_internal_request(
self,
request: ClassifyRequest,
raw_request: Request = None,
) -> tuple[EmbeddingReqInput, ClassifyRequest]:
"""Convert OpenAI embedding request to internal format"""
prompt = request.input
if isinstance(prompt, str):
# Single string input
prompt_kwargs = {"text": prompt}
elif isinstance(prompt, list):
if len(prompt) > 0 and isinstance(prompt[0], str):
prompt_kwargs = {"text": prompt}
else:
# List of integers (token IDs) or empty list
prompt_kwargs = {"input_ids": prompt}
else:
# Other types (should not happen but handle gracefully)
prompt_kwargs = {"input_ids": prompt}
adapted_request = EmbeddingReqInput(
**prompt_kwargs,
rid=request.rid,
priority=request.priority,
)
return adapted_request, request
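# For illustration, the dispatch above maps:
#   "hello"    -> {"text": "hello"}
#   ["a", "b"] -> {"text": ["a", "b"]}
#   [101, 102] -> {"input_ids": [101, 102]}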
def _validate_request(self, request: ClassifyRequest) -> Optional[str]:
"""Validate that the input is not empty or whitespace only."""
if not (input := request.input):
return "Input cannot be empty"
# Handle single string
if isinstance(input, str):
if not input.strip():
return "Input cannot be empty or whitespace only"
return None
# Handle list inputs
if isinstance(input, list):
# Check first element to determine type
first_item = input[0]
if isinstance(first_item, str):
# List of strings
for i, item in enumerate(input):
if not isinstance(item, str):
return f"All items in input list must be strings"
if not item.strip():
return f"Input at index {i} cannot be empty or whitespace only"
elif isinstance(first_item, int):
# List of integers (token IDs)
for i, item in enumerate(input):
if not isinstance(item, int):
return f"All items in input list must be integers"
if item < 0:
return f"Token ID at index {i} must be non-negative"
return None
def _get_id2label_mapping(self) -> Optional[Dict[int, str]]:
"""Get id2label mapping from model config."""
try:
hf_config = self.tokenizer_manager.model_config.hf_config
# Check for an explicit id2label mapping in hf_config
id2label = getattr(hf_config, "id2label", None)
if id2label:
return id2label
# Check for num_labels and create default mapping if needed
if hasattr(hf_config, "num_labels") and hf_config.num_labels:
num_labels = hf_config.num_labels
# Create default mapping: {0: "LABEL_0", 1: "LABEL_1", ...}
return {i: f"LABEL_{i}" for i in range(num_labels)}
except Exception as e:
logger.warning(f"Failed to get id2label mapping: {e}")
return None
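# Illustrative (not from this commit): a sequence classification checkpoint's
# config.json typically provides the mapping read above, e.g.
#   "id2label": {"0": "NEGATIVE", "1": "POSITIVE"}, "num_labels": 2
# (Hugging Face typically converts the keys to ints at load time.) When only
# num_labels is present, the fallback synthesizes {0: "LABEL_0", 1: "LABEL_1", ...}.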
async def _handle_non_streaming_request(
self,
adapted_request: EmbeddingReqInput,
request: ClassifyRequest,
raw_request: Request,
) -> Union[ClassifyResponse, ErrorResponse, ORJSONResponse]:
"""Handle non-streaming classification request."""
# Run the embedding request and collect its (batched) result
try:
ret = await self.tokenizer_manager.generate_request(
adapted_request, raw_request
).__anext__()
except ValueError as e:
return self.create_error_response(str(e))
if not isinstance(ret, list):
ret = [ret]
response = self._build_classify_response(ret)
return response
def _build_classify_response(self, ret: List[Dict[str, Any]]) -> ClassifyResponse:
request_id = f"{self._request_id_prefix()}{uuid.uuid4().hex}"
created_time = int(time.time())
classify_objects = []
prompt_tokens = 0
total_latency = 0.0
for i, item in enumerate(ret):
embedding = item.get("embedding", [])
meta_info = item.get("meta_info", {})
prompt_tokens += meta_info.get("prompt_tokens", 0)
total_latency += meta_info.get("e2e_latency", 0.0)
if embedding:
try:
embedding_tensor = torch.tensor(embedding, dtype=torch.float32)
probs = F.softmax(embedding_tensor, dim=0).tolist()
predicted_class = torch.argmax(embedding_tensor).item()
label = self.id2label[predicted_class]
except Exception as e:
logger.error(f"Error processing embedding for item {i}: {e}")
probs = [1.0]
label = "Default"
else:
probs = [1.0]
label = "Default"
classify_obj = {
"index": i,
"label": label,
"probs": probs,
"num_classes": len(probs),
}
classify_objects.append(classify_obj)
response = {
"id": request_id,
"object": "list",
"created": created_time,
"model": self.model_name,
"data": classify_objects,
"usage": {
"prompt_tokens": prompt_tokens,
"total_tokens": prompt_tokens,
"completion_tokens": 0,
"prompt_tokens_details": None,
},
}
return ClassifyResponse(**response)
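
The response construction above reduces to a softmax over the pooled per-class logits returned through the embedding path, followed by an argmax. A self-contained sketch of that step, with made-up logits:

import torch
import torch.nn.functional as F

logits = torch.tensor([1.2, -0.3, 0.5])  # pooled per-class scores for one input
probs = F.softmax(logits, dim=0)  # normalize scores into a probability distribution
predicted_class = int(torch.argmax(logits))  # index of the highest-scoring class
print(predicted_class, [round(p, 2) for p in probs.tolist()])  # 0 [0.58, 0.13, 0.29]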