# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
from typing import TypeAlias

from pydantic import Field

from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
    ChatRequestMixin,
    ClassifyRequestMixin,
    CompletionRequestMixin,
    PoolingBasicRequestMixin,
)
from vllm.logger import init_logger
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid

logger = init_logger(__name__)


def _build_classify_tok_params(
    request: "ClassificationRequest", model_config: ModelConfig
) -> TokenizeParams:
    """Build the tokenization parameters shared by all classification requests.

    Both the completion-style and chat-style request classes delegate here so
    their tokenization settings cannot drift apart.

    Args:
        request: The classification request; its ``truncate_prompt_tokens``
            and ``add_special_tokens`` fields are forwarded unchanged.
        model_config: Supplies ``max_model_len`` and the optional
            ``encoder_config`` mapping.

    Returns:
        A ``TokenizeParams`` whose total-token limit is the model's
        ``max_model_len``.
    """
    # encoder_config may be None; fall back to an empty dict so .get() works.
    encoder_config = model_config.encoder_config or {}
    return TokenizeParams(
        max_total_tokens=model_config.max_model_len,
        # Presumably 0 because classification does not generate output tokens.
        max_output_tokens=0,
        truncate_prompt_tokens=request.truncate_prompt_tokens,
        # Lower-casing is opt-in via the encoder config; defaults to False.
        do_lower_case=encoder_config.get("do_lower_case", False),
        add_special_tokens=request.add_special_tokens,
        # NOTE(review): appears to name the config parameter behind the
        # total-token limit (e.g. for error messages) — confirm in TokenizeParams.
        max_total_tokens_param="max_model_len",
    )


def _build_classify_pooling_params(
    request: "ClassificationRequest",
) -> PoolingParams:
    """Build the ``PoolingParams`` shared by all classification requests.

    Args:
        request: The classification request; its ``truncate_prompt_tokens``
            and ``use_activation`` fields are forwarded unchanged.

    Returns:
        ``PoolingParams`` configured for the ``"classify"`` task.
    """
    return PoolingParams(
        task="classify",
        truncate_prompt_tokens=request.truncate_prompt_tokens,
        use_activation=request.use_activation,
    )


# Classification request whose input is supplied completion-style
# (the input fields themselves come from CompletionRequestMixin).
class ClassificationCompletionRequest(
    PoolingBasicRequestMixin, CompletionRequestMixin, ClassifyRequestMixin
):
    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Return tokenization parameters for this request and ``model_config``."""
        return _build_classify_tok_params(self, model_config)

    def to_pooling_params(self) -> PoolingParams:
        """Return the ``"classify"`` pooling parameters for this request."""
        return _build_classify_pooling_params(self)


# Classification request whose input is supplied chat-style
# (the input fields themselves come from ChatRequestMixin).
class ClassificationChatRequest(
    PoolingBasicRequestMixin, ChatRequestMixin, ClassifyRequestMixin
):
    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Return tokenization parameters for this request and ``model_config``."""
        return _build_classify_tok_params(self, model_config)

    def to_pooling_params(self) -> PoolingParams:
        """Return the ``"classify"`` pooling parameters for this request."""
        return _build_classify_pooling_params(self)


# Either flavour of classification request.
ClassificationRequest: TypeAlias = (
    ClassificationCompletionRequest | ClassificationChatRequest
)


# One classification result within a ClassificationResponse.
class ClassificationData(OpenAIBaseModel):
    # Position of this result within the response's data list
    # (presumably mirrors the input order — confirm with the serving code).
    index: int
    # Class label; may be None (presumably when no label name is available).
    label: str | None
    # Presumably one probability/score per class.
    probs: list[float]
    # Number of classes the model distinguishes.
    num_classes: int


# Top-level classification response envelope.
class ClassificationResponse(OpenAIBaseModel):
    # Unique response id of the form "classify-<uuid>".
    id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
    object: str = "list"
    # Creation time as an integer Unix timestamp (seconds).
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ClassificationData]
    usage: UsageInfo