This commit is contained in:
root
2026-04-09 11:23:47 +08:00
parent 8082d5f4b2
commit 72387e4fa8
1885 changed files with 611521 additions and 1 deletions

View File

@@ -0,0 +1,136 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import TypeAlias
from pydantic import Field
from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
ChatRequestMixin,
CompletionRequestMixin,
EmbedRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.logger import init_logger
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid
logger = init_logger(__name__)
def _get_max_total_output_tokens(
model_config: ModelConfig,
) -> tuple[int | None, int]:
max_total_tokens = model_config.max_model_len
pooler_config = model_config.pooler_config
if pooler_config is None:
return max_total_tokens, 0
if pooler_config.enable_chunked_processing:
return None, 0
max_embed_len = pooler_config.max_embed_len or max_total_tokens
max_output_tokens = max_total_tokens - max_embed_len
return max_total_tokens, max_output_tokens
class EmbeddingCompletionRequest(
    PoolingBasicRequestMixin, CompletionRequestMixin, EmbedRequestMixin
):
    """Embedding request in the completion (raw prompt) shape."""

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Assemble tokenization limits for this request from *model_config*."""
        total_limit, output_limit = _get_max_total_output_tokens(model_config)
        enc_cfg = model_config.encoder_config or {}
        return TokenizeParams(
            max_total_tokens=total_limit,
            max_output_tokens=output_limit,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            do_lower_case=enc_cfg.get("do_lower_case", False),
            add_special_tokens=self.add_special_tokens,
            max_total_tokens_param="max_model_len",
            max_output_tokens_param="max_model_len - max_embed_len",
        )

    def to_pooling_params(self):
        """Translate this request into ``PoolingParams`` for the embed task."""
        if self.normalize is not None:
            # Legacy alias: fold the deprecated `normalize` flag into
            # `use_activation` and warn once per process.
            logger.warning_once(
                "`normalize` is deprecated and will be removed in v0.17. "
                "Please pass `use_activation` instead."
            )
            self.use_activation = self.normalize
        return PoolingParams(
            task="embed",
            dimensions=self.dimensions,
            use_activation=self.use_activation,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
        )
class EmbeddingChatRequest(
    PoolingBasicRequestMixin, ChatRequestMixin, EmbedRequestMixin
):
    """Embedding request in the chat (messages) shape."""

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Assemble tokenization limits for this request from *model_config*."""
        total_limit, output_limit = _get_max_total_output_tokens(model_config)
        enc_cfg = model_config.encoder_config or {}
        return TokenizeParams(
            max_total_tokens=total_limit,
            max_output_tokens=output_limit,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            do_lower_case=enc_cfg.get("do_lower_case", False),
            add_special_tokens=self.add_special_tokens,
            max_total_tokens_param="max_model_len",
            max_output_tokens_param="max_model_len - max_embed_len",
        )

    def to_pooling_params(self):
        """Translate this request into ``PoolingParams`` for the embed task."""
        if self.normalize is not None:
            # Legacy alias: fold the deprecated `normalize` flag into
            # `use_activation` and warn once per process.
            logger.warning_once(
                "`normalize` is deprecated and will be removed in v0.17. "
                "Please pass `use_activation` instead."
            )
            self.use_activation = self.normalize
        return PoolingParams(
            task="embed",
            dimensions=self.dimensions,
            use_activation=self.use_activation,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
        )
# Either accepted request shape: completion-style (raw prompt) or chat-style.
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
class EmbeddingResponseData(OpenAIBaseModel):
    """A single embedding result within an :class:`EmbeddingResponse`."""

    # Position of this result in the request's input list.
    index: int
    # OpenAI object tag; always "embedding" for this payload.
    object: str = "embedding"
    # The embedding vector, or a string form — presumably base64-encoded
    # floats when an encoding format is requested; confirm against callers.
    embedding: list[float] | str
class EmbeddingResponse(OpenAIBaseModel):
    """OpenAI-compatible response body for an embeddings request."""

    # Unique response id, e.g. "embd-<uuid>".
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    # OpenAI object tag; always "list" for embedding responses.
    object: str = "list"
    # Unix timestamp (seconds) at response-object creation time.
    created: int = Field(default_factory=lambda: int(time.time()))
    # Name of the model that produced the embeddings.
    model: str
    # One entry per input, in request order.
    data: list[EmbeddingResponseData]
    # Token accounting for the request.
    usage: UsageInfo
class EmbeddingBytesResponse(OpenAIBaseModel):
    """Raw-bytes embedding response (non-JSON transport)."""

    # Binary payload chunks to be written to the response body.
    content: list[bytes]
    # Optional extra HTTP headers to attach to the response.
    headers: dict[str, str] | None = None
    # Content-Type of the response; defaults to a generic binary stream.
    media_type: str = "application/octet-stream"