Files
bi_150-vllm/vllm/entrypoints/pooling/embed/protocol.py

118 lines
3.6 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import TypeAlias
from pydantic import Field
from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
ChatRequestMixin,
CompletionRequestMixin,
EmbedRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid
def _get_max_total_output_tokens(
model_config: ModelConfig,
) -> tuple[int | None, int]:
max_total_tokens = model_config.max_model_len
pooler_config = model_config.pooler_config
if pooler_config is None:
return max_total_tokens, 0
if pooler_config.enable_chunked_processing:
return None, 0
max_embed_len = pooler_config.max_embed_len or max_total_tokens
max_output_tokens = max_total_tokens - max_embed_len
return max_total_tokens, max_output_tokens
def _build_embed_tok_params(request, model_config: ModelConfig) -> TokenizeParams:
    """Build the ``TokenizeParams`` shared by both embedding request flavors.

    ``request`` must expose ``truncate_prompt_tokens`` and
    ``add_special_tokens`` (both request classes below do, via their mixins).
    Extracted so the completion- and chat-style requests cannot drift apart.
    """
    encoder_config = model_config.encoder_config or {}
    max_total_tokens, max_output_tokens = _get_max_total_output_tokens(model_config)
    return TokenizeParams(
        max_total_tokens=max_total_tokens,
        max_output_tokens=max_output_tokens,
        truncate_prompt_tokens=request.truncate_prompt_tokens,
        do_lower_case=encoder_config.get("do_lower_case", False),
        add_special_tokens=request.add_special_tokens,
        max_total_tokens_param="max_model_len",
        max_output_tokens_param="max_model_len - max_embed_len",
    )


def _embed_pooling_params(request) -> PoolingParams:
    """Build embed-task ``PoolingParams`` from the request's knobs.

    ``request`` must expose ``dimensions`` and ``use_activation``.
    """
    return PoolingParams(
        task="embed",
        dimensions=request.dimensions,
        use_activation=request.use_activation,
    )


class EmbeddingCompletionRequest(
    PoolingBasicRequestMixin, CompletionRequestMixin, EmbedRequestMixin
):
    """Embedding request carrying completion-style input (see CompletionRequestMixin)."""

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        return _build_embed_tok_params(self, model_config)

    def to_pooling_params(self):
        return _embed_pooling_params(self)


class EmbeddingChatRequest(
    PoolingBasicRequestMixin, ChatRequestMixin, EmbedRequestMixin
):
    """Embedding request carrying chat-style input (see ChatRequestMixin)."""

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        return _build_embed_tok_params(self, model_config)

    def to_pooling_params(self):
        return _embed_pooling_params(self)


# Union accepted by the embedding endpoint: either input flavor.
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
class EmbeddingResponseData(OpenAIBaseModel):
    """One embedding result within an ``EmbeddingResponse``."""

    # Position of this result in the request's input list.
    index: int
    object: str = "embedding"
    # Embedding vector, or a string form — presumably base64 depending on
    # the requested encoding format; confirm against the serving handler.
    embedding: list[float] | str
class EmbeddingResponse(OpenAIBaseModel):
    """OpenAI-compatible response envelope for embedding requests."""

    # Unique response id of the form "embd-<uuid>".
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    # Unix timestamp (whole seconds) captured when the model is instantiated.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    # One entry per input, in request order (see EmbeddingResponseData.index).
    data: list[EmbeddingResponseData]
    usage: UsageInfo
class EmbeddingBytesResponse(OpenAIBaseModel):
    """Raw binary embedding payload plus the HTTP metadata needed to emit it."""

    content: list[bytes]
    # Extra HTTP headers to attach to the response, if any.
    headers: dict[str, str] | None = None
    media_type: str = "application/octet-stream"