From b5e14b2b78b201dd258400c7a57fd19e25fdc890 Mon Sep 17 00:00:00 2001
From: ybyang <10629930+whybeyoung@users.noreply.github.com>
Date: Sun, 19 Oct 2025 10:32:48 +0800
Subject: [PATCH] [1/2][feature] support openai like classification api
 (#11618)

---
 docs/supported_models/classify_models.md         | 162 ++++++++++++++
 python/sglang/srt/entrypoints/http_server.py     |  17 ++
 .../sglang/srt/entrypoints/openai/protocol.py    |  32 +++
 .../entrypoints/openai/serving_classify.py       | 204 ++++++++++++++++++
 4 files changed, 415 insertions(+)
 create mode 100644 docs/supported_models/classify_models.md
 create mode 100644 python/sglang/srt/entrypoints/openai/serving_classify.py

diff --git a/docs/supported_models/classify_models.md b/docs/supported_models/classify_models.md
new file mode 100644
index 000000000..c6d18f9a9
--- /dev/null
+++ b/docs/supported_models/classify_models.md
@@ -0,0 +1,162 @@
# Classification API

This document describes the `/v1/classify` API endpoint implementation in SGLang, which is compatible with vLLM's classification API format.

## Overview

The classification API allows you to classify text inputs using classification models. This implementation follows the same format as the classification API in vLLM 0.7.0.

## API Endpoint

```
POST /v1/classify
```

## Request Format

```json
{
  "model": "model_name",
  "input": "text to classify"
}
```

### Parameters

- `model` (string, required): The name of the classification model to use
- `input` (string, array of strings, or array of token IDs, required): The input to classify
- `user` (string, optional): User identifier for tracking
- `rid` (string or array of strings, optional): Request ID for tracking
- `priority` (integer, optional): Request priority

## Response Format

```json
{
  "id": "classify-9bf17f2847b046c7b2d5495f4b4f9682",
  "object": "list",
  "created": 1745383213,
  "model": "jason9693/Qwen2.5-1.5B-apeach",
  "data": [
    {
      "index": 0,
      "label": "Default",
      "probs": [0.565970778465271, 0.4340292513370514],
      "num_classes": 2
    }
  ],
  "usage": {
    "prompt_tokens": 10,
    "total_tokens": 10,
    "completion_tokens": 0,
    "prompt_tokens_details": null
  }
}
```

### Response Fields

- `id`: Unique identifier for the classification request
- `object`: Always "list"
- `created`: Unix timestamp when the request was created
- `model`: The model used for classification
- `data`: Array of classification results
  - `index`: Index of the result
  - `label`: Predicted class label
  - `probs`: Array of probabilities for each class
  - `num_classes`: Total number of classes
- `usage`: Token usage information
  - `prompt_tokens`: Number of input tokens
  - `total_tokens`: Total number of tokens
  - `completion_tokens`: Number of completion tokens (always 0 for classification)
  - `prompt_tokens_details`: Additional token details (optional)

## Example Usage

### Using curl

```bash
curl -v "http://127.0.0.1:8000/v1/classify" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "jason9693/Qwen2.5-1.5B-apeach",
    "input": "Loved the new café—coffee was great."
  }'
```

### Using Python

```python
import requests
import json

# Make classification request
response = requests.post(
    "http://127.0.0.1:8000/v1/classify",
    headers={"Content-Type": "application/json"},
    json={
        "model": "jason9693/Qwen2.5-1.5B-apeach",
        "input": "Loved the new café—coffee was great."
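        # "input" may also be a list of strings (batch classification)
        # or a list of token IDs.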
    }
)

# Parse response
result = response.json()
print(json.dumps(result, indent=2))
```

## Supported Models

The classification API works with any classification model supported by SGLang, including:

### Classification Models (Multi-class)
- `LlamaForSequenceClassification` - Multi-class classification
- `Qwen2ForSequenceClassification` - Multi-class classification
- `Qwen3ForSequenceClassification` - Multi-class classification
- `BertForSequenceClassification` - Multi-class classification
- `Gemma2ForSequenceClassification` - Multi-class classification

**Label Mapping**: The API automatically uses the `id2label` mapping from the model's `config.json` file to provide meaningful label names instead of generic class names. If `id2label` is not available, the server falls back to generic `LABEL_0`, `LABEL_1`, ... names derived from the model's `num_labels` (see the inspection snippet at the end of this document).

### Reward Models (Single score)
- `InternLM2ForRewardModel` - Single reward score
- `Qwen2ForRewardModel` - Single reward score
- `LlamaForSequenceClassificationWithNormal_Weights` - Special reward model

**Note**: The `/classify` endpoint in SGLang was originally designed for reward models but now supports all non-generative models. Our `/v1/classify` endpoint provides a standardized, vLLM-compatible interface for classification tasks.

## Error Handling

The API returns appropriate HTTP status codes and error messages:

- `400 Bad Request`: Invalid request format or missing required fields
- `500 Internal Server Error`: Server-side processing error

Error response format:
```json
{
  "error": "Error message",
  "type": "error_type",
  "code": 400
}
```

## Implementation Details

The classification API is implemented using:

1. **Rust Router**: Handles routing and request/response models in `sgl-router/src/protocols/spec.rs`
2. **Python HTTP Server**: Implements the actual endpoint in `python/sglang/srt/entrypoints/http_server.py`
3. **Classification Service**: Handles the classification logic in `python/sglang/srt/entrypoints/openai/serving_classify.py`

## Testing

Use the provided test script to verify the implementation:

```bash
python test_classify_api.py
```

## Compatibility

This implementation is compatible with vLLM's classification API format, allowing seamless migration from vLLM to SGLang for classification tasks.
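## Inspecting Model Labels

The label names returned by `/v1/classify` come from the `id2label` mapping in the model's `config.json`. A minimal sketch for checking that mapping before serving a model, using the Hugging Face `transformers` package (assuming it is installed; substitute your own model name):

```python
from transformers import AutoConfig

# Read only the model's config.json; no weights are loaded.
config = AutoConfig.from_pretrained("jason9693/Qwen2.5-1.5B-apeach")

# num_labels is the number of classes; id2label maps class index to label name.
# If the model author did not set id2label, transformers fills in generic
# names such as {0: "LABEL_0", 1: "LABEL_1"}.
print(config.num_labels)
print(config.id2label)
```

If this prints only generic `LABEL_*` names, the API responses will use those same names.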
diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index 00fe4ca17..982e6467e 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -50,6 +50,7 @@ from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationM
 from sglang.srt.entrypoints.engine import _launch_subprocesses
 from sglang.srt.entrypoints.openai.protocol import (
     ChatCompletionRequest,
+    ClassifyRequest,
     CompletionRequest,
     DetokenizeRequest,
     EmbeddingRequest,
@@ -62,6 +63,7 @@ from sglang.srt.entrypoints.openai.protocol import (
     V1RerankReqInput,
 )
 from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
+from sglang.srt.entrypoints.openai.serving_classify import OpenAIServingClassify
 from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
 from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank
@@ -228,6 +230,9 @@ async def lifespan(fast_api_app: FastAPI):
     fast_api_app.state.openai_serving_embedding = OpenAIServingEmbedding(
         _global_state.tokenizer_manager, _global_state.template_manager
     )
+    fast_api_app.state.openai_serving_classify = OpenAIServingClassify(
+        _global_state.tokenizer_manager, _global_state.template_manager
+    )
     fast_api_app.state.openai_serving_score = OpenAIServingScore(
         _global_state.tokenizer_manager
     )
@@ -1082,6 +1087,18 @@ async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request):
     )
 
 
+@app.post(
+    "/v1/classify",
+    response_class=ORJSONResponse,
+    dependencies=[Depends(validate_json_request)],
+)
+async def openai_v1_classify(request: ClassifyRequest, raw_request: Request):
+    """OpenAI-compatible classification endpoint."""
+    return await raw_request.app.state.openai_serving_classify.handle_request(
+        request, raw_request
+    )
+
+
 @app.post(
     "/v1/tokenize",
     response_class=ORJSONResponse,
diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py
index 871dcfd06..46ebc6687 100644
--- a/python/sglang/srt/entrypoints/openai/protocol.py
+++ b/python/sglang/srt/entrypoints/openai/protocol.py
@@ -761,6 +761,37 @@ class EmbeddingObject(BaseModel):
     object: str = "embedding"
 
 
+ClassifyInput = Union[str, List[str], List[int]]
+
+
+class ClassifyRequest(BaseModel):
+    # OpenAI-compatible classification request
+    model: str = DEFAULT_MODEL_NAME
+    input: ClassifyInput
+    user: Optional[str] = None
+
+    # The request id.
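+    # Can be a single id, or a list with one id per batch item.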
+    rid: Optional[Union[List[str], str]] = None
+    # Priority for the request
+    priority: Optional[int] = None
+
+
+class ClassifyData(BaseModel):
+    index: int
+    label: str
+    probs: List[float]
+    num_classes: int
+
+
+class ClassifyResponse(BaseModel):
+    id: str
+    object: str = "list"
+    created: int
+    model: str
+    data: List[ClassifyData]
+    usage: UsageInfo
+
+
 class EmbeddingResponse(BaseModel):
     data: List[EmbeddingObject]
     model: str
@@ -844,6 +875,7 @@ OpenAIServingRequest = Union[
     ChatCompletionRequest,
     CompletionRequest,
     EmbeddingRequest,
+    ClassifyRequest,
     ScoringRequest,
     V1RerankReqInput,
     TokenizeRequest,
diff --git a/python/sglang/srt/entrypoints/openai/serving_classify.py b/python/sglang/srt/entrypoints/openai/serving_classify.py
new file mode 100644
index 000000000..6b2a64abb
--- /dev/null
+++ b/python/sglang/srt/entrypoints/openai/serving_classify.py
@@ -0,0 +1,204 @@
from __future__ import annotations

import logging
import time
import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from fastapi import Request
from fastapi.responses import ORJSONResponse

from sglang.srt.entrypoints.openai.protocol import (
    ClassifyRequest,
    ClassifyResponse,
    ErrorResponse,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.managers.io_struct import EmbeddingReqInput

if TYPE_CHECKING:
    from sglang.srt.managers.template_manager import TemplateManager
    from sglang.srt.managers.tokenizer_manager import TokenizerManager

logger = logging.getLogger(__name__)


class OpenAIServingClassify(OpenAIServingBase):
    """Handler for v1/classify requests"""

    def __init__(
        self,
        tokenizer_manager: TokenizerManager,
        template_manager: TemplateManager,
    ):
        super().__init__(tokenizer_manager)
        self.template_manager = template_manager
        self.id2label = self._get_id2label_mapping()
        if not self.id2label:
            raise ValueError(
                "No id2label mapping found in the model config; "
                "/v1/classify requires a classification model"
            )
        self.model_name = (
            self.tokenizer_manager.served_model_name
            if self.tokenizer_manager.served_model_name
            else self.tokenizer_manager.server_args.model_path
        )

    def _request_id_prefix(self) -> str:
        return "classify-"

    def _convert_to_internal_request(
        self,
        request: ClassifyRequest,
        raw_request: Request = None,
    ) -> tuple[EmbeddingReqInput, ClassifyRequest]:
        """Convert an OpenAI classification request to the internal format"""
        prompt = request.input

        if isinstance(prompt, str):
            # Single string input
            prompt_kwargs = {"text": prompt}
        elif isinstance(prompt, list):
            if len(prompt) > 0 and isinstance(prompt[0], str):
                prompt_kwargs = {"text": prompt}
            else:
                # List of integers (token IDs) or empty list
                prompt_kwargs = {"input_ids": prompt}
        else:
            # Other types (should not happen but handle gracefully)
            prompt_kwargs = {"input_ids": prompt}

        adapted_request = EmbeddingReqInput(
            **prompt_kwargs,
            rid=request.rid,
            priority=request.priority,
        )

        return adapted_request, request

    def _validate_request(self, request: ClassifyRequest) -> Optional[str]:
        """Validate that the input is not empty or whitespace only."""
        if not (input := request.input):
            return "Input cannot be empty"

        # Handle single string
        if isinstance(input, str):
            if not input.strip():
                return "Input cannot be empty or whitespace only"
            return None

        # Handle list inputs
        if isinstance(input, list):
            # Check first element to determine type
            first_item = input[0]

            if isinstance(first_item, str):
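                # A list of strings is treated as a batch: each string is
                # classified independently and yields its own result object.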
                # List of strings
                for i, item in enumerate(input):
                    if not isinstance(item, str):
                        return "All items in input list must be strings"
                    if not item.strip():
                        return f"Input at index {i} cannot be empty or whitespace only"
            elif isinstance(first_item, int):
                # List of integers (token IDs)
                for i, item in enumerate(input):
                    if not isinstance(item, int):
                        return "All items in input list must be integers"
                    if item < 0:
                        return f"Token ID at index {i} must be non-negative"
            else:
                return "Input list items must be strings or integers"

        return None

    def _get_id2label_mapping(self) -> Optional[Dict[int, str]]:
        """Get id2label mapping from model config."""
        try:
            hf_config = self.tokenizer_manager.model_config.hf_config
            # Check for id2label in hf_config
            if getattr(hf_config, "id2label", None):
                return hf_config.id2label
            # Check for num_labels and create default mapping if needed
            if getattr(hf_config, "num_labels", None):
                num_labels = hf_config.num_labels
                # Create default mapping: {0: "LABEL_0", 1: "LABEL_1", ...}
                return {i: f"LABEL_{i}" for i in range(num_labels)}

        except Exception as e:
            logger.warning(f"Failed to get id2label mapping: {e}")

        return None

    async def _handle_non_streaming_request(
        self,
        adapted_request: EmbeddingReqInput,
        request: ClassifyRequest,
        raw_request: Request,
    ) -> Union[ClassifyResponse, ErrorResponse, ORJSONResponse]:
        """Handle non-streaming classification request."""
        try:
            ret = await self.tokenizer_manager.generate_request(
                adapted_request, raw_request
            ).__anext__()
        except ValueError as e:
            return self.create_error_response(str(e))

        if not isinstance(ret, list):
            ret = [ret]

        return self._build_classify_response(ret)

    def _build_classify_response(self, ret: List[Dict[str, Any]]) -> ClassifyResponse:
        request_id = f"{self._request_id_prefix()}{uuid.uuid4().hex}"
        created_time = int(time.time())
        classify_objects = []
        prompt_tokens = 0

        for i, item in enumerate(ret):
            embedding = item.get("embedding", [])
            meta_info = item.get("meta_info", {})

            prompt_tokens += meta_info.get("prompt_tokens", 0)

            if embedding:
                try:
                    # For classification models the returned "embedding" holds
                    # the raw class logits; softmax converts them to probabilities.
                    embedding_tensor = torch.tensor(embedding, dtype=torch.float32)
                    probs = F.softmax(embedding_tensor, dim=0).tolist()

                    predicted_class = torch.argmax(embedding_tensor).item()

                    label = self.id2label[predicted_class]

                except Exception as e:
                    logger.error(f"Error processing embedding for item {i}: {e}")
                    probs = [1.0]
                    label = "Default"
            else:
                probs = [1.0]
                label = "Default"

            classify_obj = {
                "index": i,
                "label": label,
                "probs": probs,
                "num_classes": len(probs),
            }
            classify_objects.append(classify_obj)

        response = {
            "id": request_id,
            "object": "list",
            "created": created_time,
            "model": self.model_name,
            "data": classify_objects,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "total_tokens": prompt_tokens,
                "completion_tokens": 0,
                "prompt_tokens_details": None,
            },
        }

        return ClassifyResponse(**response)