This commit is contained in:
root
2026-04-09 11:23:47 +08:00
parent 8082d5f4b2
commit 72387e4fa8
1885 changed files with 611521 additions and 1 deletions

View File

@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fastapi import APIRouter, Depends, Request
from starlette.responses import JSONResponse
from typing_extensions import assert_never
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.utils import load_aware_call, with_cancellation
router = APIRouter()
def classify(request: Request) -> ServingClassification | None:
return request.app.state.openai_serving_classification
@router.post("/classify", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_classify(request: ClassificationRequest, raw_request: Request):
handler = classify(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
return base_server.create_error_response(
message="The model does not support Classification API"
)
try:
generator = await handler.create_classify(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, ClassificationResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)

View File

@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import TypeAlias
from pydantic import Field
from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
ChatRequestMixin,
ClassifyRequestMixin,
CompletionRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.logger import init_logger
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid
logger = init_logger(__name__)
class ClassificationCompletionRequest(
PoolingBasicRequestMixin, CompletionRequestMixin, ClassifyRequestMixin
):
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
)
def to_pooling_params(self):
return PoolingParams(
task="classify",
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)
class ClassificationChatRequest(
PoolingBasicRequestMixin, ChatRequestMixin, ClassifyRequestMixin
):
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
)
def to_pooling_params(self):
return PoolingParams(
task="classify",
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)
ClassificationRequest: TypeAlias = (
ClassificationCompletionRequest | ClassificationChatRequest
)
class ClassificationData(OpenAIBaseModel):
index: int
label: str | None
probs: list[float]
num_classes: int
class ClassificationResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
object: str = "list"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
data: list[ClassificationData]
usage: UsageInfo

View File

@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Final, TypeAlias
import jinja2
import numpy as np
from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationData,
ClassificationRequest,
ClassificationResponse,
)
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput
logger = init_logger(__name__)
ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest]
class ServingClassification(OpenAIServing):
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def _preprocess(
self,
ctx: ClassificationServeContext,
) -> ErrorResponse | None:
"""
Process classification inputs: tokenize text, resolve adapters,
and prepare model-specific inputs.
"""
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)
if isinstance(ctx.request, ClassificationChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret:
return error_check_ret
_, ctx.engine_prompts = await self._preprocess_chat(
ctx.request,
ctx.request.messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
)
elif isinstance(ctx.request, ClassificationCompletionRequest):
ctx.engine_prompts = await self._preprocess_completion(
ctx.request,
prompt_input=ctx.request.input,
prompt_embeds=None,
)
else:
return self.create_error_response("Invalid classification request type")
return None
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
def _build_response(
self,
ctx: ClassificationServeContext,
) -> ClassificationResponse | ErrorResponse:
"""
Convert model outputs to a formatted classification response
with probabilities and labels.
"""
id2label = getattr(self.model_config.hf_config, "id2label", {})
items: list[ClassificationData] = []
num_prompt_tokens = 0
final_res_batch_checked = ctx.final_res_batch
for idx, final_res in enumerate(final_res_batch_checked):
classify_res = ClassificationOutput.from_base(final_res.outputs)
probs = classify_res.probs
predicted_index = int(np.argmax(probs))
label = id2label.get(predicted_index)
item = ClassificationData(
index=idx,
label=label,
probs=probs,
num_classes=len(probs),
)
items.append(item)
prompt_token_ids = final_res.prompt_token_ids
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return ClassificationResponse(
id=ctx.request_id,
created=ctx.created_time,
model=ctx.model_name,
data=items,
usage=usage,
)
async def create_classify(
self,
request: ClassificationRequest,
raw_request: Request,
) -> ClassificationResponse | ErrorResponse:
model_name = self.models.model_name()
request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
ctx = ClassificationServeContext(
request=request,
raw_request=raw_request,
model_name=model_name,
request_id=request_id,
)
return await self.handle(ctx) # type: ignore[return-value]