[1/2][feature] support openai like classification api (#11618)
docs/supported_models/classify_models.md (new file, 162 lines)
@@ -0,0 +1,162 @@
# Classification API

This document describes the `/v1/classify` API endpoint implementation in SGLang, which is compatible with vLLM's classification API format.

## Overview

The classification API allows you to classify text inputs using classification models. The request and response formats follow vLLM's classification API (as of vLLM 0.7.0).

## API Endpoint

```
POST /v1/classify
```

## Request Format

```json
{
  "model": "model_name",
  "input": "text to classify"
}
```

### Parameters

- `model` (string, required): The name of the classification model to use
- `input` (string or array, required): The text to classify; a single string, a list of strings, or a list of token IDs (see the batched example after this list)
- `user` (string, optional): User identifier for tracking
- `rid` (string or list of strings, optional): Request ID for tracking
- `priority` (integer, optional): Request priority
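Because `input` also accepts a list of strings, several texts can be classified in one call. A minimal sketch in Python, assuming a local server on port 8000 as in the examples below (the input texts are made up); each entry of the response's `data` array carries the `index` of the matching input:

```python
import requests

# Classify a batch of inputs in a single request.
resp = requests.post(
    "http://127.0.0.1:8000/v1/classify",
    json={
        "model": "jason9693/Qwen2.5-1.5B-apeach",
        "input": ["Loved the new café.", "Worst service I have ever had."],
    },
)
for item in resp.json()["data"]:
    print(item["index"], item["label"], item["probs"])
```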
## Response Format

```json
{
  "id": "classify-9bf17f2847b046c7b2d5495f4b4f9682",
  "object": "list",
  "created": 1745383213,
  "model": "jason9693/Qwen2.5-1.5B-apeach",
  "data": [
    {
      "index": 0,
      "label": "Default",
      "probs": [0.565970778465271, 0.4340292513370514],
      "num_classes": 2
    }
  ],
  "usage": {
    "prompt_tokens": 10,
    "total_tokens": 10,
    "completion_tokens": 0,
    "prompt_tokens_details": null
  }
}
```
### Response Fields

- `id`: Unique identifier for the classification request
- `object`: Always "list"
- `created`: Unix timestamp when the request was created
- `model`: The model used for classification
- `data`: Array of classification results
  - `index`: Index of the result
  - `label`: Predicted class label
  - `probs`: Array of probabilities for each class
  - `num_classes`: Total number of classes
- `usage`: Token usage information
  - `prompt_tokens`: Number of input tokens
  - `total_tokens`: Total number of tokens
  - `completion_tokens`: Number of completion tokens (always 0 for classification)
  - `prompt_tokens_details`: Additional token details (optional)
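The predicted label's confidence is simply the largest entry of `probs`. A small sketch using the field names above, with `result` standing for an already-parsed response dict:

```python
# `result` is a parsed /v1/classify response (see Response Format above).
top = result["data"][0]
confidence = max(top["probs"])  # probability assigned to the predicted label
print(f"{top['label']} ({confidence:.2%} over {top['num_classes']} classes)")
```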
## Example Usage

### Using curl

```bash
curl -v "http://127.0.0.1:8000/v1/classify" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "jason9693/Qwen2.5-1.5B-apeach",
    "input": "Loved the new café—coffee was great."
  }'
```

### Using Python

```python
import requests
import json

# Make classification request
response = requests.post(
    "http://127.0.0.1:8000/v1/classify",
    headers={"Content-Type": "application/json"},
    json={
        "model": "jason9693/Qwen2.5-1.5B-apeach",
        "input": "Loved the new café—coffee was great."
    }
)

# Parse response
result = response.json()
print(json.dumps(result, indent=2))
```
## Supported Models

The classification API works with any classification model supported by SGLang, including:

### Classification Models (Multi-class)

- `LlamaForSequenceClassification` - Multi-class classification
- `Qwen2ForSequenceClassification` - Multi-class classification
- `Qwen3ForSequenceClassification` - Multi-class classification
- `BertForSequenceClassification` - Multi-class classification
- `Gemma2ForSequenceClassification` - Multi-class classification

**Label Mapping**: The API automatically uses the `id2label` mapping from the model's `config.json` file to provide meaningful label names instead of generic class names. If `id2label` is not available, it falls back to `LABEL_0`, `LABEL_1`, etc., or `Class_0`, `Class_1` as a last resort.
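For illustration, the mapping lives in the checkpoint's Hugging Face config. A sketch with `transformers` showing where the labels come from (the printed mapping depends on the checkpoint):

```python
from transformers import AutoConfig

# id2label is read from the checkpoint's config.json when present;
# transformers itself falls back to {0: "LABEL_0", 1: "LABEL_1", ...} otherwise.
config = AutoConfig.from_pretrained("jason9693/Qwen2.5-1.5B-apeach")
print(config.id2label)
```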
### Reward Models (Single score)

- `InternLM2ForRewardModel` - Single reward score
- `Qwen2ForRewardModel` - Single reward score
- `LlamaForSequenceClassificationWithNormal_Weights` - Special reward model

**Note**: The `/classify` endpoint in SGLang was originally designed for reward models but now supports all non-generative models. Our `/v1/classify` endpoint provides a standardized vLLM-compatible interface for classification tasks.
## Error Handling

The API returns appropriate HTTP status codes and error messages:

- `400 Bad Request`: Invalid request format or missing required fields
- `500 Internal Server Error`: Server-side processing error

Error response format:

```json
{
  "error": "Error message",
  "type": "error_type",
  "code": 400
}
```
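On the client side, a non-200 status can be mapped back to these fields. A sketch reusing the `requests` pattern from the examples above (an empty `input` should be rejected by the server's validation):

```python
import requests

# Surface the error fields documented above on failure.
resp = requests.post(
    "http://127.0.0.1:8000/v1/classify",
    json={"model": "jason9693/Qwen2.5-1.5B-apeach", "input": ""},
)
if resp.status_code != 200:
    err = resp.json()
    print(f"Request failed ({err.get('code')}): {err.get('error')}")
```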
## Implementation Details

The classification API is implemented using:

1. **Rust Router**: Handles routing and the request/response models in `sgl-router/src/protocols/spec.rs`
2. **Python HTTP Server**: Implements the actual endpoint in `python/sglang/srt/entrypoints/http_server.py`
3. **Classification Service**: Handles the classification logic in `python/sglang/srt/entrypoints/openai/serving_classify.py`
## Testing

Use the provided test script to verify the implementation:

```bash
python test_classify_api.py
```
## Compatibility

This implementation is compatible with vLLM's classification API format, allowing seamless migration from vLLM to SGLang for classification tasks.
python/sglang/srt/entrypoints/http_server.py

@@ -50,6 +50,7 @@ from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationM
 from sglang.srt.entrypoints.engine import _launch_subprocesses
 from sglang.srt.entrypoints.openai.protocol import (
     ChatCompletionRequest,
+    ClassifyRequest,
     CompletionRequest,
     DetokenizeRequest,
     EmbeddingRequest,
@@ -62,6 +63,7 @@ from sglang.srt.entrypoints.openai.protocol import (
     V1RerankReqInput,
 )
 from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
+from sglang.srt.entrypoints.openai.serving_classify import OpenAIServingClassify
 from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
 from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank
@@ -228,6 +230,9 @@ async def lifespan(fast_api_app: FastAPI):
     fast_api_app.state.openai_serving_embedding = OpenAIServingEmbedding(
         _global_state.tokenizer_manager, _global_state.template_manager
     )
+    fast_api_app.state.openai_serving_classify = OpenAIServingClassify(
+        _global_state.tokenizer_manager, _global_state.template_manager
+    )
     fast_api_app.state.openai_serving_score = OpenAIServingScore(
         _global_state.tokenizer_manager
     )
@@ -1082,6 +1087,18 @@ async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request):
     )
+
+
+@app.post(
+    "/v1/classify",
+    response_class=ORJSONResponse,
+    dependencies=[Depends(validate_json_request)],
+)
+async def openai_v1_classify(request: ClassifyRequest, raw_request: Request):
+    """OpenAI-compatible classification endpoint."""
+    return await raw_request.app.state.openai_serving_classify.handle_request(
+        request, raw_request
+    )


 @app.post(
     "/v1/tokenize",
     response_class=ORJSONResponse,
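For orientation, the route above delegates to `handle_request` on the serving base class, which drives the hooks that `serving_classify.py` (shown later in this commit) overrides. A hypothetical sketch of that flow, inferred from the overridden method names, not the literal `OpenAIServingBase` implementation:

```python
# Hypothetical dispatch flow (not the actual base-class code):
async def handle_request(self, request, raw_request):
    if (err := self._validate_request(request)) is not None:
        return self.create_error_response(err)
    adapted, original = self._convert_to_internal_request(request, raw_request)
    return await self._handle_non_streaming_request(adapted, original, raw_request)
```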
python/sglang/srt/entrypoints/openai/protocol.py

@@ -761,6 +761,37 @@ class EmbeddingObject(BaseModel):
     object: str = "embedding"


+ClassifyInput = Union[str, List[str], List[int]]
+
+
+class ClassifyRequest(BaseModel):
+    # OpenAI-compatible classification request
+    model: str = DEFAULT_MODEL_NAME
+    input: ClassifyInput
+    user: Optional[str] = None
+
+    # The request id.
+    rid: Optional[Union[List[str], str]] = None
+    # Priority for the request
+    priority: Optional[int] = None
+
+
+class ClassifyData(BaseModel):
+    index: int
+    label: str
+    probs: List[float]
+    num_classes: int
+
+
+class ClassifyResponse(BaseModel):
+    id: str
+    object: str = "list"
+    created: int
+    model: str
+    data: List[ClassifyData]
+    usage: UsageInfo
+
+
 class EmbeddingResponse(BaseModel):
     data: List[EmbeddingObject]
     model: str
@@ -844,6 +875,7 @@ OpenAIServingRequest = Union[
     ChatCompletionRequest,
     CompletionRequest,
     EmbeddingRequest,
+    ClassifyRequest,
     ScoringRequest,
     V1RerankReqInput,
     TokenizeRequest,
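A quick way to sanity-check the new request model (a sketch; assumes pydantic v2's `model_dump` and that the module is importable in your environment):

```python
from sglang.srt.entrypoints.openai.protocol import ClassifyRequest

# Batched input validates against ClassifyInput = Union[str, List[str], List[int]].
req = ClassifyRequest(model="jason9693/Qwen2.5-1.5B-apeach", input=["good", "bad"])
print(req.model_dump())
```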
python/sglang/srt/entrypoints/openai/serving_classify.py (new file, 204 lines)
@@ -0,0 +1,204 @@
from __future__ import annotations

import logging
import time
import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from fastapi import Request
from fastapi.responses import ORJSONResponse

from sglang.srt.entrypoints.openai.protocol import (
    ClassifyRequest,
    ClassifyResponse,
    ErrorResponse,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.managers.io_struct import EmbeddingReqInput

if TYPE_CHECKING:
    from sglang.srt.managers.template_manager import TemplateManager
    from sglang.srt.managers.tokenizer_manager import TokenizerManager

logger = logging.getLogger(__name__)


class OpenAIServingClassify(OpenAIServingBase):
    """Handler for /v1/classify requests."""

    def __init__(
        self,
        tokenizer_manager: TokenizerManager,
        template_manager: TemplateManager,
    ):
        super().__init__(tokenizer_manager)
        self.template_manager = template_manager
        self.id2label = self._get_id2label_mapping()
        self.model_name = (
            self.tokenizer_manager.served_model_name
            if self.tokenizer_manager.served_model_name
            else self.tokenizer_manager.server_args.model_path
        )
        if not self.id2label:
            raise ValueError("id2label mapping is missing")

    def _request_id_prefix(self) -> str:
        return "classify-"

    def _convert_to_internal_request(
        self,
        request: ClassifyRequest,
        raw_request: Request = None,
    ) -> tuple[EmbeddingReqInput, ClassifyRequest]:
        """Convert an OpenAI classification request to the internal embedding format."""
        prompt = request.input

        if isinstance(prompt, str):
            # Single string input
            prompt_kwargs = {"text": prompt}
        elif isinstance(prompt, list):
            if len(prompt) > 0 and isinstance(prompt[0], str):
                prompt_kwargs = {"text": prompt}
            else:
                # List of integers (token IDs) or empty list
                prompt_kwargs = {"input_ids": prompt}
        else:
            # Other types (should not happen but handle gracefully)
            prompt_kwargs = {"input_ids": prompt}

        adapted_request = EmbeddingReqInput(
            **prompt_kwargs,
            rid=request.rid,
            priority=request.priority,
        )

        return adapted_request, request

    def _validate_request(self, request: ClassifyRequest) -> Optional[str]:
        """Validate that the input is not empty or whitespace only."""
        if not (input := request.input):
            return "Input cannot be empty"

        # Handle single string
        if isinstance(input, str):
            if not input.strip():
                return "Input cannot be empty or whitespace only"
            return None

        # Handle list inputs
        if isinstance(input, list):
            # Check the first element to determine the list type
            first_item = input[0]

            if isinstance(first_item, str):
                # List of strings
                for i, item in enumerate(input):
                    if not isinstance(item, str):
                        return "All items in input list must be strings"
                    if not item.strip():
                        return f"Input at index {i} cannot be empty or whitespace only"
            elif isinstance(first_item, int):
                # List of integers (token IDs)
                for i, item in enumerate(input):
                    if not isinstance(item, int):
                        return "All items in input list must be integers"
                    if item < 0:
                        return f"Token ID at index {i} must be non-negative"
            return None

    def _get_id2label_mapping(self) -> Optional[Dict[int, str]]:
        """Get the id2label mapping from the model config."""
        try:
            hf_config = self.tokenizer_manager.model_config.hf_config
            # Check for id2label in hf_config
            if hf_config.id2label:
                return hf_config.id2label
            # Check for num_labels and create a default mapping if needed
            if hasattr(hf_config, "num_labels") and hf_config.num_labels:
                num_labels = hf_config.num_labels
                # Create default mapping: {0: "LABEL_0", 1: "LABEL_1", ...}
                return {i: f"LABEL_{i}" for i in range(num_labels)}
        except Exception as e:
            logger.warning(f"Failed to get id2label mapping: {e}")

        return None

    async def _handle_non_streaming_request(
        self,
        adapted_request: EmbeddingReqInput,
        request: ClassifyRequest,
        raw_request: Request,
    ) -> Union[ClassifyResponse, ErrorResponse, ORJSONResponse]:
        """Handle a non-streaming classification request."""
        try:
            ret = await self.tokenizer_manager.generate_request(
                adapted_request, raw_request
            ).__anext__()
        except ValueError as e:
            return self.create_error_response(str(e))

        if not isinstance(ret, list):
            ret = [ret]

        response = self._build_classify_response(ret)
        return response

    def _build_classify_response(self, ret: List[Dict[str, Any]]) -> ClassifyResponse:
        request_id = f"{self._request_id_prefix()}{uuid.uuid4().hex}"
        created_time = int(time.time())
        classify_objects = []
        prompt_tokens = 0
        total_latency = 0.0

        for i, item in enumerate(ret):
            embedding = item.get("embedding", [])
            meta_info = item.get("meta_info", {})

            prompt_tokens += meta_info.get("prompt_tokens", 0)
            total_latency += meta_info.get("e2e_latency", 0.0)

            if embedding:
                try:
                    # The pooled logits come back through the embedding path;
                    # softmax turns them into class probabilities.
                    embedding_tensor = torch.tensor(embedding, dtype=torch.float32)
                    probs = F.softmax(embedding_tensor, dim=0).tolist()

                    predicted_class = torch.argmax(embedding_tensor).item()

                    label = self.id2label[predicted_class]

                except Exception as e:
                    logger.error(f"Error processing embedding for item {i}: {e}")
                    probs = [1.0]
                    label = "Default"
            else:
                probs = [1.0]
                label = "Default"

            classify_obj = {
                "index": i,
                "label": label,
                "probs": probs,
                "num_classes": len(probs),
            }
            classify_objects.append(classify_obj)

        response = {
            "id": request_id,
            "object": "list",
            "created": created_time,
            "model": self.model_name,
            "data": classify_objects,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "total_tokens": prompt_tokens,
                "completion_tokens": 0,
                "prompt_tokens_details": None,
            },
        }

        return ClassifyResponse(**response)
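The label decision in `_build_classify_response` is just a softmax plus argmax over the pooled logits. The same computation in isolation, with hypothetical two-class logits:

```python
import torch
import torch.nn.functional as F

# Hypothetical pooled logits for a two-class model.
logits = torch.tensor([1.0, -1.0])
probs = F.softmax(logits, dim=0).tolist()      # ~[0.881, 0.119]
predicted_class = torch.argmax(logits).item()  # 0, i.e. id2label[0]
print(probs, predicted_class)
```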