137 lines
4.3 KiB
Python
137 lines
4.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import time
|
|
from typing import TypeAlias
|
|
|
|
from pydantic import Field
|
|
|
|
from vllm import PoolingParams
|
|
from vllm.config import ModelConfig
|
|
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
|
|
from vllm.entrypoints.pooling.base.protocol import (
|
|
ChatRequestMixin,
|
|
CompletionRequestMixin,
|
|
EmbedRequestMixin,
|
|
PoolingBasicRequestMixin,
|
|
)
|
|
from vllm.logger import init_logger
|
|
from vllm.renderers import TokenizeParams
|
|
from vllm.utils import random_uuid
|
|
|
|
# Module-level logger; this module uses it for deprecation warnings
# (see `to_pooling_params`).
logger = init_logger(__name__)
|
|
|
|
|
|
def _get_max_total_output_tokens(
|
|
model_config: ModelConfig,
|
|
) -> tuple[int | None, int]:
|
|
max_total_tokens = model_config.max_model_len
|
|
pooler_config = model_config.pooler_config
|
|
|
|
if pooler_config is None:
|
|
return max_total_tokens, 0
|
|
|
|
if pooler_config.enable_chunked_processing:
|
|
return None, 0
|
|
|
|
max_embed_len = pooler_config.max_embed_len or max_total_tokens
|
|
max_output_tokens = max_total_tokens - max_embed_len
|
|
return max_total_tokens, max_output_tokens
|
|
|
|
|
|
def _build_embed_tok_params(
    request: "EmbeddingRequest",
    model_config: ModelConfig,
) -> TokenizeParams:
    """Build the ``TokenizeParams`` shared by every embedding request type.

    ``EmbeddingCompletionRequest`` and ``EmbeddingChatRequest`` apply exactly
    the same limits, so the logic lives here once instead of being duplicated
    in each class.

    Args:
        request: The embedding request; only ``truncate_prompt_tokens`` and
            ``add_special_tokens`` are read from it.
        model_config: Supplies ``encoder_config`` and the token limits computed
            by ``_get_max_total_output_tokens``.

    Returns:
        The tokenization parameters for this request.
    """
    # encoder_config may be None/empty; treat that as "no encoder options".
    encoder_config = model_config.encoder_config or {}
    max_total_tokens, max_output_tokens = _get_max_total_output_tokens(model_config)

    return TokenizeParams(
        max_total_tokens=max_total_tokens,
        max_output_tokens=max_output_tokens,
        truncate_prompt_tokens=request.truncate_prompt_tokens,
        do_lower_case=encoder_config.get("do_lower_case", False),
        add_special_tokens=request.add_special_tokens,
        # Human-readable names for the two limits above (presumably used in
        # length-validation error messages - confirm in TokenizeParams).
        max_total_tokens_param="max_model_len",
        max_output_tokens_param="max_model_len - max_embed_len",
    )


def _embed_to_pooling_params(request: "EmbeddingRequest") -> PoolingParams:
    """Convert an embedding request into ``PoolingParams`` for the ``embed`` task.

    Honors the deprecated ``normalize`` field: when it is set, a deprecation
    warning is logged via ``logger.warning_once`` and its value replaces
    ``use_activation``.

    Args:
        request: The embedding request. ``request.use_activation`` is
            overwritten in place when ``normalize`` is not None.

    Returns:
        ``PoolingParams`` carrying the request's ``dimensions``,
        ``use_activation`` and ``truncate_prompt_tokens``.
    """
    if request.normalize is not None:
        logger.warning_once(
            "`normalize` is deprecated and will be removed in v0.17. "
            "Please pass `use_activation` instead."
        )
        # Deliberately mutates the request (previous behavior), so anything
        # reading request.use_activation afterwards sees the migrated value.
        request.use_activation = request.normalize

    return PoolingParams(
        task="embed",
        dimensions=request.dimensions,
        use_activation=request.use_activation,
        truncate_prompt_tokens=request.truncate_prompt_tokens,
    )


class EmbeddingCompletionRequest(
    PoolingBasicRequestMixin, CompletionRequestMixin, EmbedRequestMixin
):
    # Embedding request whose input fields come from CompletionRequestMixin.
    # Tokenization/pooling settings are shared with EmbeddingChatRequest via
    # _build_embed_tok_params and _embed_to_pooling_params.

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Return the tokenization limits for this request."""
        return _build_embed_tok_params(self, model_config)

    def to_pooling_params(self) -> PoolingParams:
        """Return ``PoolingParams`` for the ``embed`` task.

        May overwrite ``self.use_activation`` from the deprecated ``normalize``.
        """
        return _embed_to_pooling_params(self)


class EmbeddingChatRequest(
    PoolingBasicRequestMixin, ChatRequestMixin, EmbedRequestMixin
):
    # Embedding request whose input fields come from ChatRequestMixin.
    # Tokenization/pooling settings are shared with EmbeddingCompletionRequest
    # via _build_embed_tok_params and _embed_to_pooling_params.

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Return the tokenization limits for this request."""
        return _build_embed_tok_params(self, model_config)

    def to_pooling_params(self) -> PoolingParams:
        """Return ``PoolingParams`` for the ``embed`` task.

        May overwrite ``self.use_activation`` from the deprecated ``normalize``.
        """
        return _embed_to_pooling_params(self)
|
|
|
|
|
|
# Either accepted embedding request body: the completion-style or the
# chat-style variant (per their CompletionRequestMixin / ChatRequestMixin bases).
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
|
|
|
|
|
|
class EmbeddingResponseData(OpenAIBaseModel):
    # One embedding result; EmbeddingResponse.data holds a list of these.

    # Position of this result - presumably the index of the corresponding
    # input; confirm in the serving code.
    index: int
    # Object type tag; defaults to "embedding".
    object: str = "embedding"
    # The embedding as a list of floats, or as a string.
    # NOTE(review): the string form is presumably an encoded (e.g. base64)
    # representation - confirm against the serving code.
    embedding: list[float] | str
|
|
|
|
|
|
class EmbeddingResponse(OpenAIBaseModel):
    # Response body holding a list of EmbeddingResponseData items plus usage.

    # Unique id, freshly generated per instance as "embd-" + random_uuid().
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    # Object type tag; defaults to "list".
    object: str = "list"
    # Creation time in whole seconds since the Unix epoch, taken when the
    # instance is built.
    created: int = Field(default_factory=lambda: int(time.time()))
    # Model name - presumably the model that served the request.
    model: str
    # The individual embedding results.
    data: list[EmbeddingResponseData]
    # Token usage accounting for the request.
    usage: UsageInfo
|
|
|
|
|
|
class EmbeddingBytesResponse(OpenAIBaseModel):
    # Raw-bytes form of an embedding response.
    # NOTE(review): presumably used for binary encoding formats - confirm
    # with the serving code.

    # Payload as a list of byte chunks.
    content: list[bytes]
    # Optional extra headers; None when not provided.
    headers: dict[str, str] | None = None
    # MIME type of the payload; defaults to generic binary.
    media_type: str = "application/octet-stream"
|