Sync from v0.13
This commit is contained in:
181
vllm/entrypoints/pooling/classify/protocol.py
Normal file
181
vllm/entrypoints/pooling/classify/protocol.py
Normal file
@@ -0,0 +1,181 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
from typing import Annotated, Any, TypeAlias
|
||||
|
||||
from pydantic import (
|
||||
Field,
|
||||
)
|
||||
|
||||
from vllm import PoolingParams
|
||||
from vllm.config.pooler import get_use_activation
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
from vllm.entrypoints.openai.protocol import OpenAIBaseModel, UsageInfo
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
|
||||
class ClassificationCompletionRequest(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
input: list[str] | str
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
user: str | None = None
|
||||
|
||||
# --8<-- [start:classification-extra-params]
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
add_special_tokens: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"If true (the default), special tokens (e.g. BOS) will be added to "
|
||||
"the prompt."
|
||||
),
|
||||
)
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
softmax: bool | None = Field(
|
||||
default=None,
|
||||
description="softmax will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
|
||||
activation: bool | None = Field(
|
||||
default=None,
|
||||
description="activation will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
|
||||
use_activation: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to use activation for classification outputs. "
|
||||
"Default is True.",
|
||||
)
|
||||
# --8<-- [end:classification-extra-params]
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
use_activation=get_use_activation(self),
|
||||
)
|
||||
|
||||
|
||||
class ClassificationChatRequest(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
messages: list[ChatCompletionMessageParam]
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
user: str | None = None
|
||||
|
||||
# --8<-- [start:chat-classification-extra-params]
|
||||
add_generation_prompt: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, the generation prompt will be added to the chat template. "
|
||||
"This is a parameter used by chat template in tokenizer config of the "
|
||||
"model."
|
||||
),
|
||||
)
|
||||
|
||||
add_special_tokens: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, special tokens (e.g. BOS) will be added to the prompt "
|
||||
"on top of what is added by the chat template. "
|
||||
"For most models, the chat template takes care of adding the "
|
||||
"special tokens so this should be set to false (as is the "
|
||||
"default)."
|
||||
),
|
||||
)
|
||||
|
||||
chat_template: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"A Jinja template to use for this conversion. "
|
||||
"As of transformers v4.44, default chat template is no longer "
|
||||
"allowed, so you must provide a chat template if the tokenizer "
|
||||
"does not define one."
|
||||
),
|
||||
)
|
||||
|
||||
chat_template_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Additional keyword args to pass to the template renderer. "
|
||||
"Will be accessible by the chat template."
|
||||
),
|
||||
)
|
||||
|
||||
mm_processor_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
)
|
||||
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
softmax: bool | None = Field(
|
||||
default=None,
|
||||
description="softmax will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
|
||||
activation: bool | None = Field(
|
||||
default=None,
|
||||
description="activation will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
|
||||
use_activation: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to use activation for classification outputs. "
|
||||
"Default is True.",
|
||||
)
|
||||
# --8<-- [end:chat-classification-extra-params]
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
use_activation=get_use_activation(self),
|
||||
)
|
||||
|
||||
|
||||
ClassificationRequest: TypeAlias = (
|
||||
ClassificationCompletionRequest | ClassificationChatRequest
|
||||
)
|
||||
|
||||
|
||||
class ClassificationData(OpenAIBaseModel):
|
||||
index: int
|
||||
label: str | None
|
||||
probs: list[float]
|
||||
num_classes: int
|
||||
|
||||
|
||||
class ClassificationResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
|
||||
object: str = "list"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
data: list[ClassificationData]
|
||||
usage: UsageInfo
|
||||
Reference in New Issue
Block a user