Add minimal vLLM 0.16.1 build repo for BI-V150
vllm/entrypoints/pooling/__init__.py (new file, 121 lines)
@@ -0,0 +1,121 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING

from fastapi import FastAPI

if TYPE_CHECKING:
    from argparse import Namespace

    from starlette.datastructures import State

    from vllm.engine.protocol import EngineClient
    from vllm.entrypoints.logger import RequestLogger
    from vllm.tasks import SupportedTask
else:
    RequestLogger = object
    SupportedTask = object


def register_pooling_api_routers(
    app: FastAPI, supported_tasks: tuple["SupportedTask", ...]
):
    from vllm.entrypoints.pooling.pooling.api_router import router as pooling_router

    app.include_router(pooling_router)

    if "classify" in supported_tasks:
        from vllm.entrypoints.pooling.classify.api_router import (
            router as classify_router,
        )

        app.include_router(classify_router)

    if "embed" in supported_tasks:
        from vllm.entrypoints.pooling.embed.api_router import router as embed_router

        app.include_router(embed_router)

    # Score/rerank endpoints are available for:
    # - "score" task (cross-encoder models)
    # - "embed" task (bi-encoder models)
    # - "token_embed" task (late interaction models like ColBERT)
    if any(t in supported_tasks for t in ("score", "embed", "token_embed")):
        from vllm.entrypoints.pooling.score.api_router import router as score_router

        app.include_router(score_router)


def init_pooling_state(
    engine_client: "EngineClient",
    state: "State",
    args: "Namespace",
    request_logger: RequestLogger | None,
    supported_tasks: tuple["SupportedTask", ...],
):
    from vllm.entrypoints.chat_utils import load_chat_template
    from vllm.entrypoints.pooling.classify.serving import ServingClassification
    from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
    from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
    from vllm.entrypoints.pooling.score.serving import ServingScores
    from vllm.tasks import POOLING_TASKS

    resolved_chat_template = load_chat_template(args.chat_template)

    state.openai_serving_pooling = (
        OpenAIServingPooling(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if any(t in supported_tasks for t in POOLING_TASKS)
        else None
    )
    state.openai_serving_embedding = (
        OpenAIServingEmbedding(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if "embed" in supported_tasks
        else None
    )
    state.openai_serving_classification = (
        ServingClassification(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if "classify" in supported_tasks
        else None
    )
    # ServingScores handles score/rerank for:
    # - "score" task (cross-encoder models)
    # - "embed" task (bi-encoder models)
    # - "token_embed" task (late interaction models like ColBERT)
    state.openai_serving_scores = (
        ServingScores(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            score_template=resolved_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if any(t in supported_tasks for t in ("embed", "score", "token_embed"))
        else None
    )
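For context, a minimal sketch of how these two hooks might be wired into an API server. This assumes an existing FastAPI app, an engine_client, parsed CLI args, a request_logger, and an already-populated state.openai_serving_models (all of those names are illustrative here, not part of this commit):

from fastapi import FastAPI

from vllm.entrypoints.pooling import (
    init_pooling_state,
    register_pooling_api_routers,
)


def build_app(engine_client, args, request_logger, supported_tasks):
    # Hypothetical wiring; the real server initializes state.openai_serving_models
    # and the other OpenAI-compatible handlers before the pooling handlers.
    app = FastAPI()
    init_pooling_state(
        engine_client,
        app.state,
        args,
        request_logger,
        supported_tasks,
    )
    register_pooling_api_routers(app, supported_tasks)
    return app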
vllm/entrypoints/pooling/base/__init__.py (new file, empty)
vllm/entrypoints/pooling/base/protocol.py (new file, 207 lines)
@@ -0,0 +1,207 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from typing import Annotated, Any

from pydantic import Field, model_validator

from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
from vllm.renderers import ChatParams, merge_kwargs
from vllm.utils import random_uuid
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness


class PoolingBasicRequestMixin(OpenAIBaseModel):
    # --8<-- [start:pooling-common-params]
    model: str | None = None
    user: str | None = None
    # --8<-- [end:pooling-common-params]

    # --8<-- [start:pooling-common-extra-params]
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "throughout the inference process and returned in the response."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description="Additional kwargs to pass to the HF processor.",
    )
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker from guessing prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bits)."
        ),
    )
    # --8<-- [end:pooling-common-extra-params]


class CompletionRequestMixin(OpenAIBaseModel):
    # --8<-- [start:completion-params]
    input: list[int] | list[list[int]] | str | list[str]
    # --8<-- [end:completion-params]

    # --8<-- [start:completion-extra-params]
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    # --8<-- [end:completion-extra-params]


class ChatRequestMixin(OpenAIBaseModel):
    # --8<-- [start:chat-params]
    messages: list[ChatCompletionMessageParam]
    # --8<-- [end:chat-params]

    # --8<-- [start:chat-extra-params]
    add_generation_prompt: bool = Field(
        default=False,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by the chat template in the tokenizer "
            "config of the model."
        ),
    )
    continue_final_message: bool = Field(
        default=False,
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
    )
    chat_template: str | None = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
    )
    chat_template_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )
    # --8<-- [end:chat-extra-params]

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        if data.get("continue_final_message") and data.get("add_generation_prompt"):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

    def build_chat_params(
        self,
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
    ) -> ChatParams:
        return ChatParams(
            chat_template=self.chat_template or default_template,
            chat_template_content_format=default_template_content_format,
            chat_template_kwargs=merge_kwargs(
                self.chat_template_kwargs,
                dict(
                    add_generation_prompt=self.add_generation_prompt,
                    continue_final_message=self.continue_final_message,
                ),
            ),
        )


class EncodingRequestMixin(OpenAIBaseModel):
    # --8<-- [start:encoding-params]
    encoding_format: EncodingFormat = "float"
    # --8<-- [end:encoding-params]

    # --8<-- [start:encoding-extra-params]
    embed_dtype: EmbedDType = Field(
        default="float32",
        description=(
            "What dtype to use for encoding. Defaults to float32 for base64 "
            "encoding to match the OpenAI python client behavior. "
            "This parameter affects base64 and binary responses."
        ),
    )
    endianness: Endianness = Field(
        default="native",
        description=(
            "What endianness to use for encoding. Defaults to native for "
            "base64 encoding to match the OpenAI python client behavior. "
            "This parameter affects base64 and binary responses."
        ),
    )
    # --8<-- [end:encoding-extra-params]


class EmbedRequestMixin(EncodingRequestMixin):
    # --8<-- [start:embed-params]
    dimensions: int | None = None
    # --8<-- [end:embed-params]

    # --8<-- [start:embed-extra-params]
    use_activation: bool | None = Field(
        default=None,
        description="Whether to use activation for the pooler outputs. "
        "`None` uses the pooler's default, which is `True` in most cases.",
    )
    normalize: bool | None = Field(
        default=None,
        description="Deprecated; please pass `use_activation` instead",
    )
    # --8<-- [end:embed-extra-params]


class ClassifyRequestMixin(OpenAIBaseModel):
    # --8<-- [start:classify-extra-params]
    use_activation: bool | None = Field(
        default=None,
        description="Whether to use activation for the pooler outputs. "
        "`None` uses the pooler's default, which is `True` in most cases.",
    )
    # --8<-- [end:classify-extra-params]
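These mixins are meant to be composed into concrete request models; the protocol modules later in this commit do exactly that. A minimal illustrative composition (the class name below is hypothetical, not part of the commit):

class ExampleEmbedTextRequest(
    PoolingBasicRequestMixin, CompletionRequestMixin, EmbedRequestMixin
):
    # Inherits `model`, `user`, `request_id`, `priority`, `cache_salt`, ...
    # from PoolingBasicRequestMixin, `input` / `add_special_tokens` from
    # CompletionRequestMixin, and `dimensions`, `encoding_format`,
    # `embed_dtype`, `endianness` from EmbedRequestMixin (via
    # EncodingRequestMixin).
    pass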
vllm/entrypoints/pooling/classify/__init__.py (new file, empty)
vllm/entrypoints/pooling/classify/api_router.py (new file, 48 lines)
@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from fastapi import APIRouter, Depends, Request
from starlette.responses import JSONResponse
from typing_extensions import assert_never

from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationRequest,
    ClassificationResponse,
)
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.utils import load_aware_call, with_cancellation

router = APIRouter()


def classify(request: Request) -> ServingClassification | None:
    return request.app.state.openai_serving_classification


@router.post("/classify", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_classify(request: ClassificationRequest, raw_request: Request):
    handler = classify(raw_request)
    if handler is None:
        base_server = raw_request.app.state.openai_serving_tokenization
        return base_server.create_error_response(
            message="The model does not support Classification API"
        )

    try:
        generator = await handler.create_classify(request, raw_request)
    except Exception as e:
        generator = handler.create_error_response(e)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )

    elif isinstance(generator, ClassificationResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)
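A sketch of calling the resulting /classify endpoint, assuming a server on localhost:8000 with a classification model already being served (URL and model name are placeholders):

import requests

resp = requests.post(
    "http://localhost:8000/classify",
    json={"model": "my-classifier", "input": "vLLM is great!"},
)
resp.raise_for_status()
for item in resp.json()["data"]:
    # Each item carries index, label, probs and num_classes
    # (see ClassificationData in classify/protocol.py below).
    print(item["index"], item["label"], max(item["probs"]))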
vllm/entrypoints/pooling/classify/protocol.py (new file, 89 lines)
@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
from typing import TypeAlias

from pydantic import Field

from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
    ChatRequestMixin,
    ClassifyRequestMixin,
    CompletionRequestMixin,
    PoolingBasicRequestMixin,
)
from vllm.logger import init_logger
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid

logger = init_logger(__name__)


class ClassificationCompletionRequest(
    PoolingBasicRequestMixin, CompletionRequestMixin, ClassifyRequestMixin
):
    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        encoder_config = model_config.encoder_config or {}

        return TokenizeParams(
            max_total_tokens=model_config.max_model_len,
            max_output_tokens=0,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=self.add_special_tokens,
            max_total_tokens_param="max_model_len",
        )

    def to_pooling_params(self):
        return PoolingParams(
            task="classify",
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            use_activation=self.use_activation,
        )


class ClassificationChatRequest(
    PoolingBasicRequestMixin, ChatRequestMixin, ClassifyRequestMixin
):
    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        encoder_config = model_config.encoder_config or {}

        return TokenizeParams(
            max_total_tokens=model_config.max_model_len,
            max_output_tokens=0,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=self.add_special_tokens,
            max_total_tokens_param="max_model_len",
        )

    def to_pooling_params(self):
        return PoolingParams(
            task="classify",
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            use_activation=self.use_activation,
        )


ClassificationRequest: TypeAlias = (
    ClassificationCompletionRequest | ClassificationChatRequest
)


class ClassificationData(OpenAIBaseModel):
    index: int
    label: str | None
    probs: list[float]
    num_classes: int


class ClassificationResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ClassificationData]
    usage: UsageInfo
vllm/entrypoints/pooling/classify/serving.py (new file, 160 lines)
@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Final, TypeAlias

import jinja2
import numpy as np
from fastapi import Request

from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationData,
    ClassificationRequest,
    ClassificationResponse,
)
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput

logger = init_logger(__name__)


ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest]


class ServingClassification(OpenAIServing):
    request_id_prefix = "classify"

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        trust_request_chat_template: bool = False,
        log_error_stack: bool = False,
    ) -> None:
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            log_error_stack=log_error_stack,
        )

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template

    async def _preprocess(
        self,
        ctx: ClassificationServeContext,
    ) -> ErrorResponse | None:
        """
        Process classification inputs: tokenize text, resolve adapters,
        and prepare model-specific inputs.
        """
        try:
            ctx.lora_request = self._maybe_get_adapters(ctx.request)

            if isinstance(ctx.request, ClassificationChatRequest):
                error_check_ret = self._validate_chat_template(
                    request_chat_template=ctx.request.chat_template,
                    chat_template_kwargs=ctx.request.chat_template_kwargs,
                    trust_request_chat_template=self.trust_request_chat_template,
                )
                if error_check_ret:
                    return error_check_ret

                _, ctx.engine_prompts = await self._preprocess_chat(
                    ctx.request,
                    ctx.request.messages,
                    default_template=self.chat_template,
                    default_template_content_format=self.chat_template_content_format,
                    default_template_kwargs=None,
                )
            elif isinstance(ctx.request, ClassificationCompletionRequest):
                ctx.engine_prompts = await self._preprocess_completion(
                    ctx.request,
                    prompt_input=ctx.request.input,
                    prompt_embeds=None,
                )
            else:
                return self.create_error_response("Invalid classification request type")

            return None

        except (ValueError, TypeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

    def _build_response(
        self,
        ctx: ClassificationServeContext,
    ) -> ClassificationResponse | ErrorResponse:
        """
        Convert model outputs to a formatted classification response
        with probabilities and labels.
        """
        id2label = getattr(self.model_config.hf_config, "id2label", {})

        items: list[ClassificationData] = []
        num_prompt_tokens = 0

        final_res_batch_checked = ctx.final_res_batch

        for idx, final_res in enumerate(final_res_batch_checked):
            classify_res = ClassificationOutput.from_base(final_res.outputs)

            probs = classify_res.probs
            predicted_index = int(np.argmax(probs))
            label = id2label.get(predicted_index)

            item = ClassificationData(
                index=idx,
                label=label,
                probs=probs,
                num_classes=len(probs),
            )

            items.append(item)
            prompt_token_ids = final_res.prompt_token_ids
            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return ClassificationResponse(
            id=ctx.request_id,
            created=ctx.created_time,
            model=ctx.model_name,
            data=items,
            usage=usage,
        )

    async def create_classify(
        self,
        request: ClassificationRequest,
        raw_request: Request,
    ) -> ClassificationResponse | ErrorResponse:
        model_name = self.models.model_name()
        request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"

        ctx = ClassificationServeContext(
            request=request,
            raw_request=raw_request,
            model_name=model_name,
            request_id=request_id,
        )

        return await self.handle(ctx)  # type: ignore[return-value]
vllm/entrypoints/pooling/embed/__init__.py (new file, empty)
vllm/entrypoints/pooling/embed/api_router.py (new file, 82 lines)
@@ -0,0 +1,82 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.util
from functools import lru_cache
from http import HTTPStatus

from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse
from typing_extensions import assert_never

from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.embed.protocol import (
    EmbeddingBytesResponse,
    EmbeddingRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.utils import load_aware_call, with_cancellation
from vllm.logger import init_logger

router = APIRouter()

logger = init_logger(__name__)


@lru_cache(maxsize=1)
def _get_json_response_cls():
    if importlib.util.find_spec("orjson") is not None:
        from fastapi.responses import ORJSONResponse

        return ORJSONResponse
    logger.warning_once(
        "To make the v1/embeddings API fast, please install orjson via `pip install orjson`"
    )
    return JSONResponse


def embedding(request: Request) -> OpenAIServingEmbedding | None:
    return request.app.state.openai_serving_embedding


@router.post(
    "/v1/embeddings",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_embedding(
    request: EmbeddingRequest,
    raw_request: Request,
):
    handler = embedding(raw_request)
    if handler is None:
        base_server = raw_request.app.state.openai_serving_tokenization
        return base_server.create_error_response(
            message="The model does not support Embeddings API"
        )

    try:
        generator = await handler.create_embedding(request, raw_request)
    except Exception as e:
        generator = handler.create_error_response(e)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, EmbeddingResponse):
        return _get_json_response_cls()(content=generator.model_dump())
    elif isinstance(generator, EmbeddingBytesResponse):
        return StreamingResponse(
            content=generator.content,
            headers=generator.headers,
            media_type=generator.media_type,
        )

    assert_never(generator)
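Because /v1/embeddings mimics the OpenAI Embeddings API, the official client can simply be pointed at the server; the base URL, API key, and model name below are placeholders:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
out = client.embeddings.create(
    model="my-embedding-model",
    input=["first sentence", "second sentence"],
    encoding_format="float",  # "base64" is also accepted
)
print(len(out.data), len(out.data[0].embedding))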
vllm/entrypoints/pooling/embed/protocol.py (new file, 136 lines)
@@ -0,0 +1,136 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import TypeAlias

from pydantic import Field

from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
    ChatRequestMixin,
    CompletionRequestMixin,
    EmbedRequestMixin,
    PoolingBasicRequestMixin,
)
from vllm.logger import init_logger
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid

logger = init_logger(__name__)


def _get_max_total_output_tokens(
    model_config: ModelConfig,
) -> tuple[int | None, int]:
    max_total_tokens = model_config.max_model_len
    pooler_config = model_config.pooler_config

    if pooler_config is None:
        return max_total_tokens, 0

    if pooler_config.enable_chunked_processing:
        return None, 0

    max_embed_len = pooler_config.max_embed_len or max_total_tokens
    max_output_tokens = max_total_tokens - max_embed_len
    return max_total_tokens, max_output_tokens


class EmbeddingCompletionRequest(
    PoolingBasicRequestMixin, CompletionRequestMixin, EmbedRequestMixin
):
    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        encoder_config = model_config.encoder_config or {}

        (
            max_total_tokens,
            max_output_tokens,
        ) = _get_max_total_output_tokens(model_config)

        return TokenizeParams(
            max_total_tokens=max_total_tokens,
            max_output_tokens=max_output_tokens,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=self.add_special_tokens,
            max_total_tokens_param="max_model_len",
            max_output_tokens_param="max_model_len - max_embed_len",
        )

    def to_pooling_params(self):
        if self.normalize is not None:
            logger.warning_once(
                "`normalize` is deprecated and will be removed in v0.17. "
                "Please pass `use_activation` instead."
            )
            self.use_activation = self.normalize

        return PoolingParams(
            task="embed",
            dimensions=self.dimensions,
            use_activation=self.use_activation,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
        )


class EmbeddingChatRequest(
    PoolingBasicRequestMixin, ChatRequestMixin, EmbedRequestMixin
):
    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        encoder_config = model_config.encoder_config or {}

        (
            max_total_tokens,
            max_output_tokens,
        ) = _get_max_total_output_tokens(model_config)

        return TokenizeParams(
            max_total_tokens=max_total_tokens,
            max_output_tokens=max_output_tokens,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=self.add_special_tokens,
            max_total_tokens_param="max_model_len",
            max_output_tokens_param="max_model_len - max_embed_len",
        )

    def to_pooling_params(self):
        if self.normalize is not None:
            logger.warning_once(
                "`normalize` is deprecated and will be removed in v0.17. "
                "Please pass `use_activation` instead."
            )
            self.use_activation = self.normalize

        return PoolingParams(
            task="embed",
            dimensions=self.dimensions,
            use_activation=self.use_activation,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
        )


EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest


class EmbeddingResponseData(OpenAIBaseModel):
    index: int
    object: str = "embedding"
    embedding: list[float] | str


class EmbeddingResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[EmbeddingResponseData]
    usage: UsageInfo


class EmbeddingBytesResponse(OpenAIBaseModel):
    content: list[bytes]
    headers: dict[str, str] | None = None
    media_type: str = "application/octet-stream"
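A worked illustration of the _get_max_total_output_tokens branches, using plain integers and a dict in place of the config objects (a sketch only, not the vLLM API):

def sketch(max_model_len, pooler_config):
    # Mirrors _get_max_total_output_tokens with plain values.
    if pooler_config is None:
        return max_model_len, 0
    if pooler_config.get("enable_chunked_processing"):
        return None, 0
    max_embed_len = pooler_config.get("max_embed_len") or max_model_len
    return max_model_len, max_model_len - max_embed_len

print(sketch(4096, None))                                 # (4096, 0)
print(sketch(4096, {"enable_chunked_processing": True}))  # (None, 0)
print(sketch(4096, {"max_embed_len": 8192}))              # (4096, -4096)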
vllm/entrypoints/pooling/embed/serving.py (new file, 640 lines)
@@ -0,0 +1,640 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import AsyncGenerator, Callable, Mapping
from functools import partial
from typing import Any, Final, Literal, TypeAlias, cast

import torch
from fastapi import Request
from typing_extensions import assert_never

from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.embed.protocol import (
    EmbeddingBytesResponse,
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    EmbeddingResponseData,
)
from vllm.entrypoints.pooling.utils import (
    encode_pooling_bytes,
    encode_pooling_output_base64,
    encode_pooling_output_float,
)
from vllm.inputs.data import ProcessorInputs, TokensPrompt, token_inputs
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import EmbedDType, Endianness

logger = init_logger(__name__)


EmbeddingServeContext: TypeAlias = ServeContext[EmbeddingRequest]


class OpenAIServingEmbedding(OpenAIServing):
    request_id_prefix = "embd"

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        log_error_stack: bool = False,
    ) -> None:
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            log_error_stack=log_error_stack,
        )

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template

        pooler_config = self.model_config.pooler_config
        assert pooler_config is not None
        self.pooler_config = pooler_config

    async def _preprocess(
        self,
        ctx: EmbeddingServeContext,
    ) -> ErrorResponse | None:
        try:
            ctx.lora_request = self._maybe_get_adapters(ctx.request)

            if isinstance(ctx.request, EmbeddingChatRequest):
                error_check_ret = self._validate_chat_template(
                    request_chat_template=ctx.request.chat_template,
                    chat_template_kwargs=ctx.request.chat_template_kwargs,
                    trust_request_chat_template=self.trust_request_chat_template,
                )
                if error_check_ret is not None:
                    return error_check_ret

                _, ctx.engine_prompts = await self._preprocess_chat(
                    ctx.request,
                    ctx.request.messages,
                    default_template=self.chat_template,
                    default_template_content_format=self.chat_template_content_format,
                    default_template_kwargs=None,
                )
            elif isinstance(ctx.request, EmbeddingCompletionRequest):
                ctx.engine_prompts = await self._preprocess_completion(
                    ctx.request,
                    prompt_input=ctx.request.input,
                    prompt_embeds=None,
                )
            else:
                return self.create_error_response("Invalid embedding request type")

            return None
        except (ValueError, TypeError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

    def request_output_to_embed_json_response(
        self,
        final_res_batch: list[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
        encoding_format: Literal["float", "base64"],
        embed_dtype: EmbedDType,
        endianness: Endianness,
    ) -> EmbeddingResponse:
        encode_fn = cast(
            Callable[[PoolingRequestOutput], list[float] | str],
            (
                encode_pooling_output_float
                if encoding_format == "float"
                else partial(
                    encode_pooling_output_base64,
                    embed_dtype=embed_dtype,
                    endianness=endianness,
                )
            ),
        )

        items: list[EmbeddingResponseData] = []
        num_prompt_tokens = 0

        for idx, final_res in enumerate(final_res_batch):
            item = EmbeddingResponseData(
                index=idx,
                embedding=encode_fn(final_res),
            )
            prompt_token_ids = final_res.prompt_token_ids

            items.append(item)
            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return EmbeddingResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            data=items,
            usage=usage,
        )

    def request_output_to_embed_bytes_response(
        self,
        final_res_batch: list[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
        encoding_format: Literal["bytes", "bytes_only"],
        embed_dtype: EmbedDType,
        endianness: Endianness,
    ) -> EmbeddingBytesResponse:
        content, items, usage = encode_pooling_bytes(
            pooling_outputs=final_res_batch,
            embed_dtype=embed_dtype,
            endianness=endianness,
        )

        headers = (
            None
            if encoding_format == "bytes_only"
            else {
                "metadata": json.dumps(
                    {
                        "id": request_id,
                        "created": created_time,
                        "model": model_name,
                        "data": items,
                        "usage": usage,
                    }
                )
            }
        )

        return EmbeddingBytesResponse(content=content, headers=headers)

    def _build_response(
        self,
        ctx: EmbeddingServeContext,
    ) -> EmbeddingResponse | EmbeddingBytesResponse | ErrorResponse:
        encoding_format = ctx.request.encoding_format
        embed_dtype = ctx.request.embed_dtype
        endianness = ctx.request.endianness

        if encoding_format == "float" or encoding_format == "base64":
            return self.request_output_to_embed_json_response(
                ctx.final_res_batch,
                ctx.request_id,
                ctx.created_time,
                ctx.model_name,
                encoding_format,
                embed_dtype,
                endianness,
            )

        if encoding_format == "bytes" or encoding_format == "bytes_only":
            return self.request_output_to_embed_bytes_response(
                ctx.final_res_batch,
                ctx.request_id,
                ctx.created_time,
                ctx.model_name,
                encoding_format,
                embed_dtype,
                endianness,
            )

        assert_never(encoding_format)

    def _get_max_position_embeddings(self) -> int:
        """Get the model's effective maximum sequence length for chunking."""
        return self.model_config.max_model_len

    def _should_use_chunked_processing(self, request) -> bool:
        """Check if chunked processing should be used for this request."""
        return (
            isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest))
            and self.pooler_config.enable_chunked_processing
        )

    async def _process_chunked_request(
        self,
        ctx: EmbeddingServeContext,
        token_ids: list[int],
        pooling_params: PoolingParams,
        trace_headers: Mapping[str, str] | None,
        prompt_idx: int,
    ) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
        """Process a single prompt using chunked processing."""
        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

        # Split into chunks using max_position_embeddings
        max_pos_embeddings = self._get_max_position_embeddings()
        # Process all chunks for MEAN aggregation
        for chunk_idx, chunk_tokens in enumerate(
            chunk_list(token_ids, max_pos_embeddings)
        ):
            # Create a request ID for this chunk
            chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"

            # Create engine prompt for this chunk
            chunk_engine_prompt = token_inputs(chunk_tokens)

            # Log the chunk
            self._log_inputs(
                chunk_request_id,
                chunk_engine_prompt,
                params=pooling_params,
                lora_request=ctx.lora_request,
            )

            # Create generator for this chunk and wrap it to return indices
            original_generator = self.engine_client.encode(
                chunk_engine_prompt,
                pooling_params,
                chunk_request_id,
                lora_request=ctx.lora_request,
                trace_headers=trace_headers,
                priority=ctx.request.priority,
            )

            generators.append(original_generator)

        return generators

    def _validate_input(
        self,
        request: object,
        input_ids: list[int],
        input_text: str,
    ) -> TokensPrompt:
        """Override to support chunked processing for embedding requests."""
        token_num = len(input_ids)

        # Note: EmbeddingRequest doesn't have max_tokens
        if isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest)):
            # Check if chunked processing is enabled for pooling models
            enable_chunked = self._should_use_chunked_processing(request)

            # Use max_position_embeddings for chunked processing decisions
            max_pos_embeddings = self._get_max_position_embeddings()

            # Determine the effective max length for validation
            if self.pooler_config.max_embed_len:
                # Use max_embed_len for validation instead of max_model_len
                length_type = "maximum embedding input length"
                max_length_value = self.pooler_config.max_embed_len
            else:
                # Fall back to max_model_len validation (original behavior)
                length_type = "maximum context length"
                max_length_value = self.model_config.max_model_len

            validation_error_msg = (
                "This model's {length_type} is {max_length_value} tokens. "
                "However, you requested {token_num} tokens in the input for "
                "embedding generation. Please reduce the length of the input."
            )

            chunked_processing_error_msg = (
                "This model's {length_type} is {max_length_value} tokens. "
                "However, you requested {token_num} tokens in the input for "
                "embedding generation. Please reduce the length of the input "
                "or enable chunked processing."
            )

            # Check if input exceeds max length
            if token_num > max_length_value:
                raise ValueError(
                    validation_error_msg.format(
                        length_type=length_type,
                        max_length_value=max_length_value,
                        token_num=token_num,
                    )
                )

            # Check for chunked processing
            # when exceeding max_position_embeddings
            if token_num > max_pos_embeddings:
                if enable_chunked:
                    # Allow long inputs when chunked processing is enabled
                    logger.info(
                        "Input length %s exceeds max_position_embeddings "
                        "%s, will use chunked processing",
                        token_num,
                        max_pos_embeddings,
                    )
                else:
                    raise ValueError(
                        chunked_processing_error_msg.format(
                            length_type="maximum position embeddings length",
                            max_length_value=max_pos_embeddings,
                            token_num=token_num,
                        )
                    )

            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # For other request types, use the parent's implementation
        return super()._validate_input(request, input_ids, input_text)

    async def _create_single_prompt_generator(
        self,
        ctx: EmbeddingServeContext,
        engine_prompt: ProcessorInputs,
        pooling_params: PoolingParams,
        trace_headers: Mapping[str, str] | None,
        prompt_index: int,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Create a generator for a single prompt using standard processing."""
        request_id_item = f"{ctx.request_id}-{prompt_index}"

        self._log_inputs(
            request_id_item,
            engine_prompt,
            params=pooling_params,
            lora_request=ctx.lora_request,
        )

        # Return the original generator without wrapping
        return self.engine_client.encode(
            engine_prompt,
            pooling_params,
            request_id_item,
            lora_request=ctx.lora_request,
            trace_headers=trace_headers,
            priority=ctx.request.priority,
        )

    async def _prepare_generators(
        self,
        ctx: EmbeddingServeContext,
    ) -> ErrorResponse | None:
        """Override to support chunked processing."""
        # Check if we should use chunked processing
        use_chunked = self._should_use_chunked_processing(ctx.request)

        # If no chunked processing needed, delegate to parent class
        if not use_chunked:
            return await super()._prepare_generators(ctx)

        # Custom logic for chunked processing
        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

        try:
            trace_headers = (
                None
                if ctx.raw_request is None
                else await self._get_trace_headers(ctx.raw_request.headers)
            )

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            max_pos_embeddings = self._get_max_position_embeddings()

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                # Check if this specific prompt needs chunked processing
                if "prompt_token_ids" in engine_prompt:
                    prompt_token_ids = engine_prompt["prompt_token_ids"]  # type: ignore[typeddict-item]

                    if len(prompt_token_ids) > max_pos_embeddings:
                        # Use chunked processing for this prompt
                        chunk_generators = await self._process_chunked_request(
                            ctx,
                            prompt_token_ids,
                            pooling_params,
                            trace_headers,
                            i,
                        )
                        generators.extend(chunk_generators)
                        continue

                # Normal processing for short prompts or non-token prompts
                generator = await self._create_single_prompt_generator(
                    ctx, engine_prompt, pooling_params, trace_headers, i
                )
                generators.append(generator)

            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            return self.create_error_response(e)

    async def _collect_batch(
        self,
        ctx: EmbeddingServeContext,
    ) -> ErrorResponse | None:
        """Collect and aggregate batch results
        with support for chunked processing.

        For chunked requests, performs online aggregation to
        minimize memory usage.
        For regular requests, collects results normally.
        """
        try:
            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            # Check if we used chunked processing
            use_chunked = self._should_use_chunked_processing(ctx.request)

            if not use_chunked:
                return await super()._collect_batch(ctx=ctx)

            if ctx.result_generator is None:
                return self.create_error_response("Result generator not available")

            # Online aggregation for chunked requests to
            # minimize memory usage
            # Track aggregation state for each prompt
            prompt_aggregators: dict[int, dict[str, Any]] = {}
            short_prompts_results: dict[int, PoolingRequestOutput] = {}

            async for result_idx, result in ctx.result_generator:
                if "-chunk-" in result.request_id:
                    # Extract prompt_idx from chunked request_id
                    parts = result.request_id.split("-")
                    try:
                        prompt_idx = int(parts[parts.index("prompt") + 1])
                    except (ValueError, IndexError):
                        # Fallback: extract from result_idx if parsing fails
                        prompt_idx = result_idx

                    # Initialize aggregator for this prompt if needed
                    if prompt_idx not in prompt_aggregators:
                        prompt_aggregators[prompt_idx] = {
                            "weighted_sum": None,
                            "total_weight": 0,
                            "chunk_count": 0,
                            "request_id": result.request_id.split("-chunk-")[0],
                        }

                    aggregator = prompt_aggregators[prompt_idx]

                    # MEAN pooling with online weighted averaging
                    # Ensure result is PoolingRequestOutput
                    # for embedding processing
                    if not isinstance(result, PoolingRequestOutput):
                        return self.create_error_response(
                            f"Expected PoolingRequestOutput for "
                            f"chunked embedding, got "
                            f"{type(result).__name__}"
                        )

                    # Handle both PoolingOutput and
                    # EmbeddingOutput types
                    if hasattr(result.outputs, "data"):
                        # PoolingOutput case
                        embedding_data = result.outputs.data
                    elif hasattr(result.outputs, "embedding"):
                        # EmbeddingOutput case -
                        # convert embedding list to tensor
                        embedding_data = result.outputs.embedding
                    else:
                        return self.create_error_response(
                            f"Unsupported output type: {type(result.outputs).__name__}"
                        )

                    if not isinstance(embedding_data, torch.Tensor):
                        embedding_data = torch.tensor(
                            embedding_data, dtype=torch.float32
                        )

                    if result.prompt_token_ids is None:
                        return self.create_error_response(
                            "prompt_token_ids cannot be None for chunked processing"
                        )
                    weight = len(result.prompt_token_ids)

                    weighted_embedding = embedding_data.to(dtype=torch.float32) * weight

                    if aggregator["weighted_sum"] is None:
                        # First chunk
                        aggregator["weighted_sum"] = weighted_embedding
                    else:
                        # Accumulate
                        aggregator["weighted_sum"] += weighted_embedding

                    aggregator["total_weight"] += weight
                    aggregator["chunk_count"] += 1
                else:
                    # Non-chunked result - extract prompt_idx from request_id
                    parts = result.request_id.split("-")
                    try:
                        # Last part should be prompt index
                        prompt_idx = int(parts[-1])
                    except (ValueError, IndexError):
                        prompt_idx = result_idx  # Fallback to result_idx

                    short_prompts_results[prompt_idx] = result

            # Finalize aggregated results
            final_res_batch: list[PoolingRequestOutput] = []
            num_prompts = len(ctx.engine_prompts)

            for prompt_idx in range(num_prompts):
                if prompt_idx in prompt_aggregators:
                    # Finalize MEAN aggregation for this chunked prompt
                    aggregator = prompt_aggregators[prompt_idx]

                    weighted_sum = aggregator["weighted_sum"]
                    total_weight = aggregator["total_weight"]

                    if (
                        weighted_sum is not None
                        and isinstance(weighted_sum, torch.Tensor)
                        and isinstance(total_weight, (int, float))
                        and total_weight > 0
                    ):
                        # Compute final mean embedding
                        final_embedding = weighted_sum / total_weight

                        # Create a PoolingRequestOutput
                        # for the aggregated result
                        pooling_output_data = PoolingOutput(data=final_embedding)

                        # Get original prompt token IDs for this prompt
                        original_prompt = ctx.engine_prompts[prompt_idx]
                        if "prompt_token_ids" not in original_prompt:
                            return self.create_error_response(
                                f"Chunked prompt {prompt_idx} does not contain "
                                "token IDs"
                            )

                        original_token_ids = original_prompt["prompt_token_ids"]  # type: ignore[typeddict-item]

                        pooling_request_output = PoolingRequestOutput(
                            request_id=aggregator["request_id"],
                            prompt_token_ids=original_token_ids,
                            outputs=pooling_output_data,
                            num_cached_tokens=0,
                            finished=True,
                        )

                        final_res_batch.append(pooling_request_output)
                    else:
                        return self.create_error_response(
                            f"Failed to aggregate chunks for prompt {prompt_idx}"
                        )
                elif prompt_idx in short_prompts_results:
                    final_res_batch.append(short_prompts_results[prompt_idx])
                else:
                    return self.create_error_response(
                        f"Result not found for prompt {prompt_idx}"
                    )

            ctx.final_res_batch = final_res_batch

            return None

        except Exception as e:
            return self.create_error_response(e)

    async def create_embedding(
        self,
        request: EmbeddingRequest,
        raw_request: Request | None = None,
    ) -> EmbeddingResponse | ErrorResponse:
        """
        Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.
        """
        model_name = self.models.model_name()
        request_id = (
            f"{self.request_id_prefix}-"
            f"{self._base_request_id(raw_request, request.request_id)}"
        )

        ctx = EmbeddingServeContext(
            request=request,
            raw_request=raw_request,
            model_name=model_name,
            request_id=request_id,
        )

        return await self.handle(ctx)  # type: ignore[return-value]
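The chunked path above relies on length-weighted averaging of per-chunk pooled vectors being equivalent to mean pooling over the full token sequence, assuming each chunk's pooler output is the mean of its own token embeddings. A small self-contained check of that identity (the tensor sizes are arbitrary):

import torch

torch.manual_seed(0)
tokens = torch.randn(10, 4)          # 10 token embeddings of dimension 4
chunks = [tokens[:6], tokens[6:]]    # uneven chunks, as chunk_list may produce

weighted_sum = torch.zeros(4)
total_weight = 0
for chunk in chunks:
    chunk_mean = chunk.mean(dim=0)           # what a MEAN pooler returns per chunk
    weighted_sum += chunk_mean * len(chunk)  # same accumulation as _collect_batch
    total_weight += len(chunk)

assert torch.allclose(weighted_sum / total_weight, tokens.mean(dim=0))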
vllm/entrypoints/pooling/pooling/__init__.py (new file, empty)
vllm/entrypoints/pooling/pooling/api_router.py (new file, 62 lines)
@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus

from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse
from typing_extensions import assert_never

from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorResponse,
    PoolingBytesResponse,
    PoolingRequest,
    PoolingResponse,
)
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
from vllm.entrypoints.utils import load_aware_call, with_cancellation

router = APIRouter()


def pooling(request: Request) -> OpenAIServingPooling | None:
    return request.app.state.openai_serving_pooling


@router.post(
    "/pooling",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_pooling(request: PoolingRequest, raw_request: Request):
    handler = pooling(raw_request)
    if handler is None:
        base_server = raw_request.app.state.openai_serving_tokenization
        return base_server.create_error_response(
            message="The model does not support Pooling API"
        )
    try:
        generator = await handler.create_pooling(request, raw_request)
    except Exception as e:
        generator = handler.create_error_response(e)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, (PoolingResponse, IOProcessorResponse)):
        return JSONResponse(content=generator.model_dump())
    elif isinstance(generator, PoolingBytesResponse):
        return StreamingResponse(
            content=generator.content,
            headers=generator.headers,
            media_type=generator.media_type,
        )

    assert_never(generator)
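A sketch of calling the /pooling endpoint directly, assuming a local server and a pooling-capable model (the URL and model name are placeholders); the exact response schema is given by PoolingResponse or IOProcessorResponse in the protocol module below:

import requests

resp = requests.post(
    "http://localhost:8000/pooling",
    json={"model": "my-pooling-model", "input": "pool this text"},
)
resp.raise_for_status()
print(resp.json())  # PoolingResponse (or IOProcessorResponse) as JSON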
vllm/entrypoints/pooling/pooling/protocol.py (new file, 153 lines)
@@ -0,0 +1,153 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import Generic, TypeAlias, TypeVar

from pydantic import Field

from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
    ChatRequestMixin,
    ClassifyRequestMixin,
    CompletionRequestMixin,
    EmbedRequestMixin,
    EncodingRequestMixin,
    PoolingBasicRequestMixin,
)
from vllm.logger import init_logger
from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask
from vllm.utils import random_uuid

logger = init_logger(__name__)


class PoolingCompletionRequest(
    PoolingBasicRequestMixin,
    CompletionRequestMixin,
    EmbedRequestMixin,
    ClassifyRequestMixin,
):
    task: PoolingTask | None = None

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        encoder_config = model_config.encoder_config or {}

        return TokenizeParams(
            max_total_tokens=model_config.max_model_len,
            max_output_tokens=0,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=self.add_special_tokens,
            max_total_tokens_param="max_model_len",
        )

    def to_pooling_params(self):
        if self.normalize is not None:
            logger.warning_once(
                "`normalize` is deprecated and will be removed in v0.17. "
                "Please pass `use_activation` instead."
            )
            self.use_activation = self.normalize

        return PoolingParams(
            task=self.task,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            use_activation=self.use_activation,
            dimensions=self.dimensions,
        )


class PoolingChatRequest(
    PoolingBasicRequestMixin, ChatRequestMixin, EmbedRequestMixin, ClassifyRequestMixin
):
    task: PoolingTask | None = None

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        encoder_config = model_config.encoder_config or {}

        return TokenizeParams(
            max_total_tokens=model_config.max_model_len,
            max_output_tokens=0,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=self.add_special_tokens,
            max_total_tokens_param="max_model_len",
        )

    def to_pooling_params(self):
        if self.normalize is not None:
            logger.warning_once(
                "`normalize` is deprecated and will be removed in v0.17. "
                "Please pass `use_activation` instead."
            )
            self.use_activation = self.normalize

        return PoolingParams(
            task=self.task,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            use_activation=self.use_activation,
            dimensions=self.dimensions,
        )


T = TypeVar("T")


class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic[T]):
    data: T
    task: PoolingTask = "plugin"

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        encoder_config = model_config.encoder_config or {}

        return TokenizeParams(
            max_total_tokens=model_config.max_model_len,
            max_output_tokens=0,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=not model_config.is_encoder_decoder,
            max_total_tokens_param="max_model_len",
        )


class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
    request_id: str | None = None
    """
    The request_id associated with this response
"""
|
||||
created_at: int = Field(default_factory=lambda: int(time.time()))
|
||||
|
||||
data: T
|
||||
"""
|
||||
When using IOProcessor plugins, the actual output is generated
|
||||
by the plugin itself. Hence, we use a generic type for the response data
|
||||
"""
|
||||
|
||||
|
||||
PoolingRequest: TypeAlias = (
|
||||
PoolingCompletionRequest | PoolingChatRequest | IOProcessorRequest
|
||||
)
|
||||
|
||||
|
||||
class PoolingResponseData(OpenAIBaseModel):
|
||||
index: int
|
||||
object: str = "pooling"
|
||||
data: list[list[float]] | list[float] | str
|
||||
|
||||
|
||||
class PoolingResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"pool-{random_uuid()}")
|
||||
object: str = "list"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
data: list[PoolingResponseData]
|
||||
usage: UsageInfo
|
||||
|
||||
|
||||
class PoolingBytesResponse(OpenAIBaseModel):
|
||||
content: list[bytes]
|
||||
headers: dict[str, str] | None = None
|
||||
media_type: str = "application/octet-stream"
|
||||
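For orientation, a successful float-encoded response assembled from the models above would look roughly like this (values invented for illustration):

{
    "id": "pool-<random_uuid>",
    "object": "list",
    "created": 1730000000,
    "model": "my-pooling-model",
    "data": [
        {"index": 0, "object": "pooling", "data": [0.12, -0.03, 0.88]}
    ],
    "usage": {"prompt_tokens": 9, "total_tokens": 9}
}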
356
vllm/entrypoints/pooling/pooling/serving.py
Normal file
@@ -0,0 +1,356 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Callable, Sequence
|
||||
from functools import partial
|
||||
from typing import Final, Literal, cast
|
||||
|
||||
import jinja2
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
|
||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.pooling.pooling.protocol import (
|
||||
IOProcessorRequest,
|
||||
IOProcessorResponse,
|
||||
PoolingBytesResponse,
|
||||
PoolingChatRequest,
|
||||
PoolingCompletionRequest,
|
||||
PoolingRequest,
|
||||
PoolingResponse,
|
||||
PoolingResponseData,
|
||||
)
|
||||
from vllm.entrypoints.pooling.utils import (
|
||||
encode_pooling_bytes,
|
||||
encode_pooling_output_base64,
|
||||
encode_pooling_output_float,
|
||||
)
|
||||
from vllm.inputs import ProcessorInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.renderers.inputs.preprocess import prompt_to_seq
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingPooling(OpenAIServing):
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
chat_template: str | None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
trust_request_chat_template: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack,
|
||||
)
|
||||
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
self.trust_request_chat_template = trust_request_chat_template
|
||||
|
||||
async def create_pooling(
|
||||
self,
|
||||
request: PoolingRequest,
|
||||
raw_request: Request | None = None,
|
||||
) -> PoolingResponse | IOProcessorResponse | PoolingBytesResponse | ErrorResponse:
|
||||
"""
|
||||
See https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
for the API specification. This API mimics the OpenAI Embedding API.
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
model_name = self.models.model_name()
|
||||
|
||||
request_id = f"pool-{self._base_request_id(raw_request)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
if getattr(request, "dimensions", None) is not None:
|
||||
return self.create_error_response(
|
||||
"dimensions is currently not supported"
|
||||
)
|
||||
|
||||
engine_prompts: Sequence[ProcessorInputs]
|
||||
if use_io_processor := isinstance(request, IOProcessorRequest):
|
||||
if self.io_processor is None:
|
||||
raise ValueError(
|
||||
"No IOProcessor plugin installed. Please refer "
|
||||
"to the documentation and to the "
|
||||
"'prithvi_geospatial_mae_io_processor' "
|
||||
"offline inference example for more details."
|
||||
)
|
||||
|
||||
validated_prompt = self.io_processor.parse_data(request.data)
|
||||
|
||||
raw_prompts = await self.io_processor.pre_process_async(
|
||||
prompt=validated_prompt, request_id=request_id
|
||||
)
|
||||
engine_prompts = await self._preprocess_cmpl(
|
||||
request,
|
||||
prompt_to_seq(raw_prompts),
|
||||
)
|
||||
elif isinstance(request, PoolingChatRequest):
|
||||
error_check_ret = self._validate_chat_template(
|
||||
request_chat_template=request.chat_template,
|
||||
chat_template_kwargs=request.chat_template_kwargs,
|
||||
trust_request_chat_template=self.trust_request_chat_template,
|
||||
)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
_, engine_prompts = await self._preprocess_chat(
|
||||
request,
|
||||
request.messages,
|
||||
default_template=self.chat_template,
|
||||
default_template_content_format=self.chat_template_content_format,
|
||||
default_template_kwargs=None,
|
||||
)
|
||||
elif isinstance(request, PoolingCompletionRequest):
|
||||
engine_prompts = await self._preprocess_completion(
|
||||
request,
|
||||
prompt_input=request.input,
|
||||
prompt_embeds=None,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported request of type {type(request)}")
|
||||
except (ValueError, TypeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
try:
|
||||
if use_io_processor:
|
||||
assert self.io_processor is not None
|
||||
|
||||
pooling_params = self.io_processor.merge_pooling_params()
|
||||
if pooling_params.task is None:
|
||||
pooling_params.task = "plugin"
|
||||
else:
|
||||
pooling_params = request.to_pooling_params() # type: ignore
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
trace_headers = (
|
||||
None
|
||||
if raw_request is None
|
||||
else await self._get_trace_headers(raw_request.headers)
|
||||
)
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(e)
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
if use_io_processor:
|
||||
assert self.io_processor is not None
|
||||
output = await self.io_processor.post_process_async(
|
||||
result_generator,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
if callable(
|
||||
output_to_response := getattr(
|
||||
self.io_processor, "output_to_response", None
|
||||
)
|
||||
):
|
||||
logger.warning_once(
|
||||
"`IOProcessor.output_to_response` is deprecated. To ensure "
|
||||
"consistency between offline and online APIs, "
|
||||
"`IOProcessorResponse` will become a transparent wrapper "
|
||||
"around output data from v0.19 onwards.",
|
||||
)
|
||||
|
||||
if hasattr(output, "request_id") and output.request_id is None:
|
||||
output.request_id = request_id # type: ignore
|
||||
|
||||
return output_to_response(output) # type: ignore
|
||||
|
||||
return IOProcessorResponse(request_id=request_id, data=output)
|
||||
|
||||
assert isinstance(request, (PoolingCompletionRequest, PoolingChatRequest))
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[PoolingRequestOutput | None]
|
||||
final_res_batch = [None] * num_prompts
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
assert all(final_res is not None for final_res in final_res_batch)
|
||||
|
||||
final_res_batch_checked = cast(list[PoolingRequestOutput], final_res_batch)
|
||||
|
||||
response = self.request_output_to_pooling_response(
|
||||
final_res_batch_checked,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
request.encoding_format,
|
||||
request.embed_dtype,
|
||||
request.endianness,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
return self.create_error_response(e)
|
||||
|
||||
return response
|
||||
|
||||
def request_output_to_pooling_json_response(
|
||||
self,
|
||||
final_res_batch: list[PoolingRequestOutput],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
encoding_format: Literal["float", "base64"],
|
||||
embed_dtype: EmbedDType,
|
||||
endianness: Endianness,
|
||||
) -> PoolingResponse:
|
||||
encode_fn = cast(
|
||||
Callable[[PoolingRequestOutput], list[float] | str],
|
||||
(
|
||||
encode_pooling_output_float
|
||||
if encoding_format == "float"
|
||||
else partial(
|
||||
encode_pooling_output_base64,
|
||||
embed_dtype=embed_dtype,
|
||||
endianness=endianness,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
items: list[PoolingResponseData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
item = PoolingResponseData(
|
||||
index=idx,
|
||||
data=encode_fn(final_res),
|
||||
)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
|
||||
items.append(item)
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return PoolingResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def request_output_to_pooling_bytes_response(
|
||||
self,
|
||||
final_res_batch: list[PoolingRequestOutput],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
encoding_format: Literal["bytes", "bytes_only"],
|
||||
embed_dtype: EmbedDType,
|
||||
endianness: Endianness,
|
||||
) -> PoolingBytesResponse:
|
||||
content, items, usage = encode_pooling_bytes(
|
||||
pooling_outputs=final_res_batch,
|
||||
embed_dtype=embed_dtype,
|
||||
endianness=endianness,
|
||||
)
|
||||
|
||||
headers = (
|
||||
None
|
||||
if encoding_format == "bytes_only"
|
||||
else {
|
||||
"metadata": json.dumps(
|
||||
{
|
||||
"id": request_id,
|
||||
"created": created_time,
|
||||
"model": model_name,
|
||||
"data": items,
|
||||
"usage": usage,
|
||||
}
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
return PoolingBytesResponse(content=content, headers=headers)
|
||||
|
||||
def request_output_to_pooling_response(
|
||||
self,
|
||||
final_res_batch: list[PoolingRequestOutput],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
encoding_format: EncodingFormat,
|
||||
embed_dtype: EmbedDType,
|
||||
endianness: Endianness,
|
||||
) -> PoolingResponse | PoolingBytesResponse:
|
||||
if encoding_format == "float" or encoding_format == "base64":
|
||||
return self.request_output_to_pooling_json_response(
|
||||
final_res_batch,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
encoding_format,
|
||||
embed_dtype,
|
||||
endianness,
|
||||
)
|
||||
|
||||
if encoding_format == "bytes" or encoding_format == "bytes_only":
|
||||
return self.request_output_to_pooling_bytes_response(
|
||||
final_res_batch,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
encoding_format,
|
||||
embed_dtype,
|
||||
endianness,
|
||||
)
|
||||
|
||||
assert_never(encoding_format)
|
||||
0
vllm/entrypoints/pooling/score/__init__.py
Normal file
147
vllm/entrypoints/pooling/score/api_router.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.utils import validate_json_request
|
||||
from vllm.entrypoints.pooling.score.protocol import (
|
||||
RerankRequest,
|
||||
RerankResponse,
|
||||
ScoreRequest,
|
||||
ScoreResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.serving import ServingScores
|
||||
from vllm.entrypoints.utils import load_aware_call, with_cancellation
|
||||
from vllm.logger import init_logger
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def score(request: Request) -> ServingScores | None:
|
||||
return request.app.state.openai_serving_scores
|
||||
|
||||
|
||||
def rerank(request: Request) -> ServingScores | None:
|
||||
return request.app.state.openai_serving_scores
|
||||
|
||||
|
||||
@router.post(
|
||||
"/score",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_score(request: ScoreRequest, raw_request: Request):
|
||||
handler = score(raw_request)
|
||||
if handler is None:
|
||||
base_server = raw_request.app.state.openai_serving_tokenization
|
||||
return base_server.create_error_response(
|
||||
message="The model does not support Score API"
|
||||
)
|
||||
|
||||
try:
|
||||
generator = await handler.create_score(request, raw_request)
|
||||
except Exception as e:
|
||||
generator = handler.create_error_response(e)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
elif isinstance(generator, ScoreResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/score",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_score_v1(request: ScoreRequest, raw_request: Request):
|
||||
logger.warning(
|
||||
"To indicate that Score API is not part of standard OpenAI API, we "
|
||||
"have moved it to `/score`. Please update your client accordingly."
|
||||
)
|
||||
|
||||
return await create_score(request, raw_request)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rerank",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def do_rerank(request: RerankRequest, raw_request: Request):
|
||||
handler = rerank(raw_request)
|
||||
if handler is None:
|
||||
base_server = raw_request.app.state.openai_serving_tokenization
|
||||
return base_server.create_error_response(
|
||||
message="The model does not support Rerank (Score) API"
|
||||
)
|
||||
try:
|
||||
generator = await handler.do_rerank(request, raw_request)
|
||||
except Exception as e:
|
||||
generator = handler.create_error_response(e)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
elif isinstance(generator, RerankResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/rerank",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
|
||||
logger.warning_once(
|
||||
"To indicate that the rerank API is not part of the standard OpenAI"
|
||||
" API, we have located it at `/rerank`. Please update your client "
|
||||
"accordingly. (Note: Conforms to JinaAI rerank API)"
|
||||
)
|
||||
|
||||
return await do_rerank(request, raw_request)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v2/rerank",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
|
||||
return await do_rerank(request, raw_request)
|
||||
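A brief illustrative aside (not part of this commit): the rerank routes above accept JinaAI-style payloads. The URL and model name are assumptions; the fields mirror RerankRequest (query, documents, top_n) from the protocol module that follows.

import requests  # assumed client-side dependency

resp = requests.post(
    "http://localhost:8000/rerank",
    json={
        "model": "my-reranker",  # assumption
        "query": "What is the capital of France?",
        "documents": ["Paris is the capital of France.", "Berlin is in Germany."],
        "top_n": 1,  # 0 (the default) returns a result for every document
    },
)
for result in resp.json()["results"]:
    print(result["index"], result["relevance_score"], result["document"]["text"])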
156
vllm/entrypoints/pooling/score/protocol.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
from typing import TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from vllm import PoolingParams
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
|
||||
from vllm.entrypoints.pooling.base.protocol import (
|
||||
ClassifyRequestMixin,
|
||||
PoolingBasicRequestMixin,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreContentPartParam,
|
||||
ScoreInput,
|
||||
ScoreInputs,
|
||||
)
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
|
||||
class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
encoder_config = model_config.encoder_config or {}
|
||||
|
||||
return TokenizeParams(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
max_total_tokens_param="max_model_len",
|
||||
)
|
||||
|
||||
def to_pooling_params(self, task: PoolingTask = "score"):
|
||||
return PoolingParams(
|
||||
task=task,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
use_activation=self.use_activation,
|
||||
)
|
||||
|
||||
|
||||
class ScoreDataRequest(ScoreRequestMixin):
|
||||
data_1: ScoreInputs
|
||||
data_2: ScoreInputs
|
||||
|
||||
|
||||
class ScoreQueriesDocumentsRequest(ScoreRequestMixin):
|
||||
queries: ScoreInputs
|
||||
documents: ScoreInputs
|
||||
|
||||
@property
|
||||
def data_1(self):
|
||||
return self.queries
|
||||
|
||||
@property
|
||||
def data_2(self):
|
||||
return self.documents
|
||||
|
||||
|
||||
class ScoreQueriesItemsRequest(ScoreRequestMixin):
|
||||
queries: ScoreInputs
|
||||
items: ScoreInputs
|
||||
|
||||
@property
|
||||
def data_1(self):
|
||||
return self.queries
|
||||
|
||||
@property
|
||||
def data_2(self):
|
||||
return self.items
|
||||
|
||||
|
||||
class ScoreTextRequest(ScoreRequestMixin):
|
||||
text_1: ScoreInputs
|
||||
text_2: ScoreInputs
|
||||
|
||||
@property
|
||||
def data_1(self):
|
||||
return self.text_1
|
||||
|
||||
@property
|
||||
def data_2(self):
|
||||
return self.text_2
|
||||
|
||||
|
||||
ScoreRequest: TypeAlias = (
|
||||
ScoreQueriesDocumentsRequest
|
||||
| ScoreQueriesItemsRequest
|
||||
| ScoreDataRequest
|
||||
| ScoreTextRequest
|
||||
)
|
||||
|
||||
|
||||
class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
|
||||
query: ScoreInput
|
||||
documents: ScoreInputs
|
||||
top_n: int = Field(default_factory=lambda: 0)
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
encoder_config = model_config.encoder_config or {}
|
||||
|
||||
return TokenizeParams(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
max_total_tokens_param="max_model_len",
|
||||
)
|
||||
|
||||
def to_pooling_params(self, task: PoolingTask = "score"):
|
||||
return PoolingParams(
|
||||
task=task,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
use_activation=self.use_activation,
|
||||
)
|
||||
|
||||
|
||||
class RerankDocument(BaseModel):
|
||||
text: str | None = None
|
||||
multi_modal: list[ScoreContentPartParam] | None = None
|
||||
|
||||
|
||||
class RerankResult(BaseModel):
|
||||
index: int
|
||||
document: RerankDocument
|
||||
relevance_score: float
|
||||
|
||||
|
||||
class RerankUsage(BaseModel):
|
||||
prompt_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class RerankResponse(OpenAIBaseModel):
|
||||
id: str
|
||||
model: str
|
||||
usage: RerankUsage
|
||||
results: list[RerankResult]
|
||||
|
||||
|
||||
class ScoreResponseData(OpenAIBaseModel):
|
||||
index: int
|
||||
object: str = "score"
|
||||
score: float
|
||||
|
||||
|
||||
class ScoreResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
|
||||
object: str = "list"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
data: list[ScoreResponseData]
|
||||
usage: UsageInfo
|
||||
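The ScoreRequest variants above are alternative spellings of the same pair of inputs; a minimal sketch of equivalent /score payloads follows, with the model name assumed:

# One query scored against two documents, expressed three equivalent ways
# (a fourth variant uses "queries"/"items" instead of "queries"/"documents").
body_text = {"model": "my-scorer", "text_1": "the query", "text_2": ["doc a", "doc b"]}
body_qd = {"model": "my-scorer", "queries": "the query", "documents": ["doc a", "doc b"]}
body_data = {"model": "my-scorer", "data_1": "the query", "data_2": ["doc a", "doc b"]}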
651
vllm/entrypoints/pooling/score/serving.py
Normal file
@@ -0,0 +1,651 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorResponse,
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.pooling.score.protocol import (
|
||||
RerankDocument,
|
||||
RerankRequest,
|
||||
RerankResponse,
|
||||
RerankResult,
|
||||
RerankUsage,
|
||||
ScoreRequest,
|
||||
ScoreResponse,
|
||||
ScoreResponseData,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreData,
|
||||
ScoreInputs,
|
||||
_cosine_similarity,
|
||||
compress_token_type_ids,
|
||||
compute_maxsim_score,
|
||||
get_score_prompt,
|
||||
parse_score_data_single,
|
||||
validate_score_input,
|
||||
)
|
||||
from vllm.inputs.data import ProcessorInputs, TokensPrompt, token_inputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.async_utils import make_async, merge_async_iterators
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ServingScores(OpenAIServing):
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
score_template: str | None = None,
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack,
|
||||
)
|
||||
self.score_template = score_template
|
||||
|
||||
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
self.is_cross_encoder = self.model_config.is_cross_encoder
|
||||
self.is_multimodal_model = self.model_config.is_multimodal_model
|
||||
self.architecture = self.model_config.architecture
|
||||
self.is_late_interaction = self.model_config.is_late_interaction
|
||||
|
||||
if self.is_cross_encoder:
|
||||
self._score_func = self._cross_encoding_score
|
||||
elif self.is_late_interaction:
|
||||
self._score_func = self._late_interaction_score
|
||||
else:
|
||||
self._score_func = self._embedding_score
|
||||
|
||||
async def _embedding_score(
|
||||
self,
|
||||
data_1: list[ScoreData],
|
||||
data_2: list[ScoreData],
|
||||
request: RerankRequest | ScoreRequest,
|
||||
request_id: str,
|
||||
lora_request: LoRARequest | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
) -> list[PoolingRequestOutput] | ErrorResponse:
|
||||
input_texts: list[str] = []
|
||||
for text in data_1 + data_2:
|
||||
if not isinstance(text, str):
|
||||
raise NotImplementedError(
|
||||
"Embedding scores currently do not support multimodal input."
|
||||
)
|
||||
input_texts.append(text)
|
||||
|
||||
model_config = self.model_config
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
|
||||
encode_async = make_async(
|
||||
tokenizer.encode,
|
||||
executor=self._tokenizer_executor,
|
||||
)
|
||||
|
||||
tokenization_kwargs = request.build_tok_params(model_config).get_encode_kwargs()
|
||||
tokenized_prompts = await asyncio.gather(
|
||||
*(encode_async(t, **tokenization_kwargs) for t in input_texts)
|
||||
)
|
||||
|
||||
engine_prompts: list[ProcessorInputs] = []
|
||||
for tok_result, input_text in zip(tokenized_prompts, input_texts):
|
||||
text_token_prompt = self._validate_input(request, tok_result, input_text)
|
||||
|
||||
engine_prompts.append(
|
||||
token_inputs(
|
||||
text_token_prompt["prompt_token_ids"],
|
||||
prompt=input_text,
|
||||
)
|
||||
)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
pooling_params = request.to_pooling_params("embed")
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
generators.append(
|
||||
self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
)
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[PoolingRequestOutput] = []
|
||||
|
||||
embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_prompts)
|
||||
|
||||
async for i, res in result_generator:
|
||||
embeddings[i] = res
|
||||
|
||||
emb_data_1: list[PoolingRequestOutput] = []
|
||||
emb_data_2: list[PoolingRequestOutput] = []
|
||||
|
||||
for i in range(0, len(data_1)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_data_1.append(emb)
|
||||
|
||||
for i in range(len(data_1), len(embeddings)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_data_2.append(emb)
|
||||
|
||||
if len(emb_data_1) == 1:
|
||||
emb_data_1 = emb_data_1 * len(emb_data_2)
|
||||
|
||||
final_res_batch = _cosine_similarity(
|
||||
tokenizer=tokenizer, embed_1=emb_data_1, embed_2=emb_data_2
|
||||
)
|
||||
|
||||
return final_res_batch
|
||||
|
||||
def _preprocess_late_interaction_item(
|
||||
self,
|
||||
data: ScoreData,
|
||||
role: str,
|
||||
request: RerankRequest | ScoreRequest,
|
||||
tokenizer: TokenizerLike,
|
||||
tokenization_kwargs: dict[str, Any],
|
||||
) -> tuple[str, TokensPrompt]:
|
||||
"""Parse a single ScoreData into a text + optional multimodal
|
||||
TokensPrompt for late-interaction encoding.
|
||||
|
||||
For plain strings, tokenises directly.
|
||||
For multimodal content parts, extracts text and multi_modal_data.
|
||||
"""
|
||||
model_config = self.model_config
|
||||
|
||||
if isinstance(data, str):
|
||||
text, mm_data, mm_uuids = data, None, None
|
||||
else:
|
||||
text, mm_data, mm_uuids = parse_score_data_single(data, role, model_config)
|
||||
|
||||
prompt_inputs = tokenizer(text, **tokenization_kwargs)
|
||||
self._validate_input(request, prompt_inputs["input_ids"], text)
|
||||
|
||||
engine_prompt = TokensPrompt(
|
||||
prompt_token_ids=prompt_inputs["input_ids"],
|
||||
)
|
||||
|
||||
if mm_data is not None:
|
||||
engine_prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
engine_prompt["multi_modal_uuids"] = mm_uuids
|
||||
if request.mm_processor_kwargs is not None:
|
||||
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
|
||||
|
||||
return text, engine_prompt
|
||||
|
||||
async def _late_interaction_score(
|
||||
self,
|
||||
data_1: list[ScoreData],
|
||||
data_2: list[ScoreData],
|
||||
request: RerankRequest | ScoreRequest,
|
||||
request_id: str,
|
||||
lora_request: LoRARequest | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
) -> list[PoolingRequestOutput] | ErrorResponse:
|
||||
"""
|
||||
Late interaction scoring (ColBERT MaxSim).
|
||||
|
||||
Encodes queries and documents into per-token embeddings, then computes
|
||||
MaxSim: sum over query tokens of max similarity to any document token.
|
||||
"""
|
||||
model_config = self.model_config
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
tokenization_kwargs = request.build_tok_params(model_config).get_encode_kwargs()
|
||||
|
||||
all_data = data_1 + data_2
|
||||
roles = ["query"] * len(data_1) + ["document"] * len(data_2)
|
||||
|
||||
preprocess_async = make_async(
|
||||
self._preprocess_late_interaction_item,
|
||||
executor=self._tokenizer_executor,
|
||||
)
|
||||
|
||||
preprocessed = await asyncio.gather(
|
||||
*(
|
||||
preprocess_async(
|
||||
data=d,
|
||||
role=r,
|
||||
request=request,
|
||||
tokenizer=tokenizer,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
for d, r in zip(all_data, roles)
|
||||
)
|
||||
)
|
||||
|
||||
input_texts: list[str] = []
|
||||
engine_prompts: list[TokensPrompt] = []
|
||||
for text, engine_prompt in preprocessed:
|
||||
input_texts.append(text)
|
||||
engine_prompts.append(engine_prompt)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
|
||||
pooling_params = request.to_pooling_params("token_embed")
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
generators.append(
|
||||
self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
)
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
# Collect token embeddings
|
||||
embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_prompts)
|
||||
|
||||
async for i, res in result_generator:
|
||||
embeddings[i] = res
|
||||
|
||||
# Split into query and document embeddings
|
||||
emb_data_1: list[PoolingRequestOutput] = []
|
||||
emb_data_2: list[PoolingRequestOutput] = []
|
||||
|
||||
for i in range(0, len(data_1)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_data_1.append(emb)
|
||||
|
||||
for i in range(len(data_1), len(embeddings)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_data_2.append(emb)
|
||||
|
||||
# Expand queries if 1:N scoring
|
||||
if len(emb_data_1) == 1:
|
||||
emb_data_1 = emb_data_1 * len(emb_data_2)
|
||||
|
||||
# Compute MaxSim scores
|
||||
from vllm.outputs import PoolingOutput
|
||||
|
||||
scores: list[PoolingRequestOutput] = []
|
||||
padding: list[int] = []
|
||||
if (pad_token_id := tokenizer.pad_token_id) is not None:
|
||||
padding = [pad_token_id]
|
||||
|
||||
for emb_1, emb_2 in zip(emb_data_1, emb_data_2):
|
||||
# emb_1.outputs.data: [query_len, dim]
|
||||
# emb_2.outputs.data: [doc_len, dim]
|
||||
q_emb = emb_1.outputs.data
|
||||
d_emb = emb_2.outputs.data
|
||||
|
||||
maxsim_score = compute_maxsim_score(q_emb, d_emb)
|
||||
|
||||
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
|
||||
|
||||
scores.append(
|
||||
PoolingRequestOutput(
|
||||
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
|
||||
outputs=PoolingOutput(data=maxsim_score),
|
||||
prompt_token_ids=tokens,
|
||||
num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
|
||||
finished=True,
|
||||
)
|
||||
)
|
||||
|
||||
return scores
|
||||
|
||||
async def _cross_encoding_score(
|
||||
self,
|
||||
data_1: list[ScoreData],
|
||||
data_2: list[ScoreData],
|
||||
request: RerankRequest | ScoreRequest,
|
||||
request_id: str,
|
||||
lora_request: LoRARequest | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
) -> list[PoolingRequestOutput] | ErrorResponse:
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
if is_mistral_tokenizer(tokenizer):
|
||||
raise ValueError("MistralTokenizer not supported for cross-encoding")
|
||||
|
||||
model_config = self.model_config
|
||||
|
||||
if len(data_1) == 1:
|
||||
data_1 = data_1 * len(data_2)
|
||||
|
||||
tok_kwargs = request.build_tok_params(model_config).get_encode_kwargs()
|
||||
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
|
||||
preprocess_async = make_async(
|
||||
self._preprocess_score,
|
||||
executor=self._tokenizer_executor,
|
||||
)
|
||||
preprocessed_prompts = await asyncio.gather(
|
||||
*(
|
||||
preprocess_async(
|
||||
request=request,
|
||||
tokenizer=tokenizer,
|
||||
tokenization_kwargs=tok_kwargs,
|
||||
data_1=t1,
|
||||
data_2=t2,
|
||||
)
|
||||
for t1, t2 in input_pairs
|
||||
)
|
||||
)
|
||||
|
||||
request_prompts: list[str] = []
|
||||
engine_prompts: list[TokensPrompt] = []
|
||||
for full_prompt, engine_prompt in preprocessed_prompts:
|
||||
request_prompts.append(full_prompt)
|
||||
engine_prompts.append(engine_prompt)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
|
||||
default_pooling_params = request.to_pooling_params("score")
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
request_prompts[i],
|
||||
params=default_pooling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
if token_type_ids := engine_prompt.pop("token_type_ids", None):
|
||||
pooling_params = default_pooling_params.clone()
|
||||
compressed = compress_token_type_ids(token_type_ids)
|
||||
pooling_params.extra_kwargs = {"compressed_token_type_ids": compressed}
|
||||
else:
|
||||
pooling_params = default_pooling_params
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[PoolingRequestOutput | None] = [None] * len(
|
||||
engine_prompts
|
||||
)
|
||||
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
return [out for out in final_res_batch if out is not None]
|
||||
|
||||
def _preprocess_score(
|
||||
self,
|
||||
request: RerankRequest | ScoreRequest,
|
||||
tokenizer: TokenizerLike,
|
||||
tokenization_kwargs: dict[str, Any],
|
||||
data_1: ScoreData,
|
||||
data_2: ScoreData,
|
||||
) -> tuple[str, TokensPrompt]:
|
||||
model_config = self.model_config
|
||||
full_prompt, engine_prompt = get_score_prompt(
|
||||
model_config=model_config,
|
||||
data_1=data_1,
|
||||
data_2=data_2,
|
||||
tokenizer=tokenizer,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
score_template=self.score_template,
|
||||
)
|
||||
self._validate_input(request, engine_prompt["prompt_token_ids"], full_prompt)
|
||||
if request.mm_processor_kwargs is not None:
|
||||
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
|
||||
|
||||
return full_prompt, engine_prompt
|
||||
|
||||
async def _run_scoring(
|
||||
self,
|
||||
data_1: ScoreInputs,
|
||||
data_2: ScoreInputs,
|
||||
request: ScoreRequest | RerankRequest,
|
||||
request_id: str,
|
||||
raw_request: Request | None = None,
|
||||
) -> list[PoolingRequestOutput] | ErrorResponse:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
trace_headers = (
|
||||
None
|
||||
if raw_request is None
|
||||
else await self._get_trace_headers(raw_request.headers)
|
||||
)
|
||||
|
||||
score_data_1, score_data_2 = validate_score_input(
|
||||
data_1,
|
||||
data_2,
|
||||
is_multimodal_model=self.is_multimodal_model,
|
||||
architecture=self.architecture,
|
||||
)
|
||||
|
||||
return await self._score_func(
|
||||
data_1=score_data_1,
|
||||
data_2=score_data_2,
|
||||
request=request,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
|
||||
async def create_score(
|
||||
self,
|
||||
request: ScoreRequest,
|
||||
raw_request: Request | None = None,
|
||||
) -> ScoreResponse | ErrorResponse:
|
||||
"""
|
||||
Score API similar to Sentence Transformers cross encoder
|
||||
|
||||
See https://sbert.net/docs/package_reference/cross_encoder
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"score-{self._base_request_id(raw_request)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
try:
|
||||
final_res_batch = await self._run_scoring(
|
||||
request.data_1,
|
||||
request.data_2,
|
||||
request,
|
||||
request_id,
|
||||
raw_request,
|
||||
)
|
||||
if isinstance(final_res_batch, ErrorResponse):
|
||||
return final_res_batch
|
||||
|
||||
return self.request_output_to_score_response(
|
||||
final_res_batch,
|
||||
request_id,
|
||||
created_time,
|
||||
self.models.model_name(),
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
return self.create_error_response(e)
|
||||
|
||||
async def do_rerank(
|
||||
self, request: RerankRequest, raw_request: Request | None = None
|
||||
) -> RerankResponse | ErrorResponse:
|
||||
"""
|
||||
Rerank API based on JinaAI's rerank API; implements the same
|
||||
API interface. Designed for compatibility with off-the-shelf
|
||||
tooling, since this is a common standard for reranking APIs
|
||||
|
||||
See example client implementations at
|
||||
https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
|
||||
numerous clients use this standard.
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"rerank-{self._base_request_id(raw_request)}"
|
||||
documents = request.documents
|
||||
|
||||
try:
|
||||
final_res_batch = await self._run_scoring(
|
||||
request.query,
|
||||
documents,
|
||||
request,
|
||||
request_id,
|
||||
raw_request,
|
||||
)
|
||||
if isinstance(final_res_batch, ErrorResponse):
|
||||
return final_res_batch
|
||||
|
||||
top_n = request.top_n if request.top_n > 0 else len(final_res_batch)
|
||||
|
||||
return self.request_output_to_rerank_response(
|
||||
final_res_batch,
|
||||
request_id,
|
||||
self.models.model_name(),
|
||||
documents,
|
||||
top_n,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
return self.create_error_response(e)
|
||||
|
||||
def request_output_to_score_response(
|
||||
self,
|
||||
final_res_batch: list[PoolingRequestOutput],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
) -> ScoreResponse:
|
||||
items: list[ScoreResponseData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
classify_res = ScoringRequestOutput.from_base(final_res)
|
||||
|
||||
item = ScoreResponseData(
|
||||
index=idx,
|
||||
score=classify_res.outputs.score,
|
||||
)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
|
||||
items.append(item)
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return ScoreResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def request_output_to_rerank_response(
|
||||
self,
|
||||
final_res_batch: list[PoolingRequestOutput],
|
||||
request_id: str,
|
||||
model_name: str,
|
||||
documents: ScoreInputs,
|
||||
top_n: int,
|
||||
) -> RerankResponse:
|
||||
"""
|
||||
Convert the output of do_rerank to a RerankResponse
|
||||
"""
|
||||
|
||||
if not isinstance(documents, list):
|
||||
documents = [documents]
|
||||
|
||||
results: list[RerankResult] = []
|
||||
num_prompt_tokens = 0
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
classify_res = ScoringRequestOutput.from_base(final_res)
|
||||
|
||||
document = documents[idx]
|
||||
if isinstance(document, str):
|
||||
rerank_document = RerankDocument(text=document)
|
||||
else:
|
||||
rerank_document = RerankDocument(
|
||||
multi_modal=document.get("content", [])
|
||||
)
|
||||
|
||||
result = RerankResult(
|
||||
index=idx,
|
||||
document=rerank_document,
|
||||
relevance_score=classify_res.outputs.score,
|
||||
)
|
||||
results.append(result)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
# sort by relevance, then return the top n if set
|
||||
results.sort(key=lambda x: x.relevance_score, reverse=True)
|
||||
if top_n < len(documents):
|
||||
results = results[:top_n]
|
||||
|
||||
return RerankResponse(
|
||||
id=request_id,
|
||||
model=model_name,
|
||||
results=results,
|
||||
usage=RerankUsage(
|
||||
total_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens
|
||||
),
|
||||
)
|
||||
404
vllm/entrypoints/pooling/score/utils.py
Normal file
@@ -0,0 +1,404 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, TypeAlias, cast
|
||||
|
||||
import torch
|
||||
from torch.nn import CosineSimilarity
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
BaseMultiModalItemTracker,
|
||||
ChatCompletionContentPartImageEmbedsParam,
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
ChatCompletionContentPartVideoParam,
|
||||
ChatTemplateResolutionError,
|
||||
ConversationMessage,
|
||||
MultiModalItemTracker,
|
||||
_parse_chat_message_content_parts,
|
||||
)
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.inputs.data import PromptType, TextPrompt
|
||||
from vllm.model_executor.models.interfaces import supports_score_template
|
||||
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.renderers.hf import safe_apply_chat_template
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
ScoreContentPartParam: TypeAlias = (
|
||||
ChatCompletionContentPartImageParam
|
||||
| ChatCompletionContentPartImageEmbedsParam
|
||||
| ChatCompletionContentPartTextParam
|
||||
| ChatCompletionContentPartVideoParam
|
||||
)
|
||||
|
||||
|
||||
def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Compute ColBERT MaxSim score.
|
||||
|
||||
Args:
|
||||
q_emb: Query token embeddings [query_len, dim]
|
||||
d_emb: Document token embeddings [doc_len, dim]
|
||||
|
||||
Returns:
|
||||
MaxSim score (sum over query tokens of max similarity to any doc token)
|
||||
"""
|
||||
# [query_len, doc_len]
|
||||
token_scores = torch.matmul(q_emb, d_emb.T)
|
||||
# Max over document tokens, sum over query tokens
|
||||
return token_scores.amax(dim=-1).sum()
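# Illustrative check of the MaxSim computation above (not part of this diff):
# with a 2-token query and a 3-token document in a 2-dim embedding space, each
# query token takes its best-matching document token and the maxima are summed.
#
#   q = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
#   d = torch.tensor([[0.9, 0.1], [0.0, 1.0], [0.5, 0.5]])
#   compute_maxsim_score(q, d)  # row maxima of q @ d.T are 0.9 and 1.0 -> tensor(1.9000)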
|
||||
|
||||
|
||||
class ScoreMultiModalParam(TypedDict, total=False):
|
||||
"""
|
||||
A specialized parameter type for scoring multimodal content
|
||||
|
||||
The reasons why we don't reuse `CustomChatCompletionMessageParam` directly:
|
||||
1. Score tasks don't need the 'role' field (user/assistant/system) that's required in chat completions
|
||||
2. Including chat-specific fields would confuse users about their purpose in scoring
|
||||
3. This is a more focused interface that only exposes what's needed for scoring
|
||||
""" # noqa: E501
|
||||
|
||||
content: Required[list[ScoreContentPartParam]]
|
||||
"""The multimodal contents"""
|
||||
|
||||
|
||||
# Raw input data with content key in ScoreMultiModalParam.
|
||||
ScoreInput = str | ScoreMultiModalParam
|
||||
ScoreInputs = ScoreInput | list[ScoreInput]
|
||||
# Score data without content key.
|
||||
ScoreData = str | list[ScoreContentPartParam]
|
||||
|
||||
|
||||
def _cosine_similarity(
|
||||
tokenizer: TokenizerLike,
|
||||
embed_1: list[PoolingRequestOutput],
|
||||
embed_2: list[PoolingRequestOutput],
|
||||
) -> list[PoolingRequestOutput]:
|
||||
scorer = CosineSimilarity(0)
|
||||
scores: list[PoolingRequestOutput] = []
|
||||
|
||||
for emb_1, emb_2 in zip(embed_1, embed_2):
|
||||
pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data)
|
||||
|
||||
padding: list[int] = []
|
||||
if (pad_token_id := tokenizer.pad_token_id) is not None:
|
||||
padding = [pad_token_id]
|
||||
|
||||
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
|
||||
|
||||
scores.append(
|
||||
PoolingRequestOutput(
|
||||
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
|
||||
outputs=pair_score,
|
||||
prompt_token_ids=tokens,
|
||||
num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
|
||||
finished=True,
|
||||
)
|
||||
)
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def _validate_score_input_lens(
|
||||
data_1: list[ScoreData],
|
||||
data_2: list[ScoreData],
|
||||
):
|
||||
len_1 = len(data_1)
|
||||
len_2 = len(data_2)
|
||||
|
||||
if len_1 > 1 and len_1 != len_2:
|
||||
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
|
||||
if len_1 == 0:
|
||||
raise ValueError("At least one text element must be given")
|
||||
if len_2 == 0:
|
||||
raise ValueError("At least one text_pair element must be given")
|
||||
|
||||
|
||||
def _validate_mm_score_input(
|
||||
data: list[ScoreInput],
|
||||
is_multimodal_model: bool,
|
||||
architecture: str,
|
||||
) -> list[ScoreData]:
|
||||
out: list[ScoreData] = []
|
||||
for d in data:
|
||||
if isinstance(d, str):
|
||||
out.append(d)
|
||||
else:
|
||||
if not is_multimodal_model:
|
||||
raise ValueError(f"MultiModalParam is not supported for {architecture}")
|
||||
content = cast(list[ScoreContentPartParam], d.get("content", []))
|
||||
out.append(content)
|
||||
return out
|
||||
|
||||
|
||||
def validate_score_input(
|
||||
data_1: ScoreInputs,
|
||||
data_2: ScoreInputs,
|
||||
is_multimodal_model: bool,
|
||||
architecture: str,
|
||||
) -> tuple[list[ScoreData], list[ScoreData]]:
|
||||
if not isinstance(data_1, list):
|
||||
data_1 = [data_1]
|
||||
|
||||
if not isinstance(data_2, list):
|
||||
data_2 = [data_2]
|
||||
|
||||
score_input_1 = _validate_mm_score_input(data_1, is_multimodal_model, architecture)
|
||||
score_input_2 = _validate_mm_score_input(data_2, is_multimodal_model, architecture)
|
||||
_validate_score_input_lens(score_input_1, score_input_2)
|
||||
return score_input_1, score_input_2
|
||||
|
||||
|
||||
def _ensure_str(content: list[ConversationMessage]) -> str:
|
||||
"""Extract a single string prompt from parsed conversation content."""
|
||||
assert len(content) == 1
|
||||
prompt = content[0]["content"]
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
return cast(str, prompt)
|
||||
raise ValueError(f"Only string content is supported, but got {content}.")
|
||||
|
||||
|
||||
def parse_score_data(
|
||||
data_1: ScoreData,
|
||||
data_2: ScoreData,
|
||||
model_config: ModelConfig,
|
||||
) -> tuple[str, str, MultiModalDataDict | None, MultiModalUUIDDict | None]:
|
||||
"""Parse a query-document pair into text prompts and shared multi-modal
|
||||
data.
|
||||
|
||||
Uses a **single** :class:`MultiModalItemTracker` so that multi-modal
|
||||
items from both inputs are merged into one ``mm_data`` dict. This is
|
||||
the correct behaviour for cross-encoder scoring, where query and
|
||||
document are concatenated into a single model prompt.
|
||||
"""
|
||||
mm_tracker = MultiModalItemTracker(model_config)
|
||||
|
||||
content_1 = _parse_score_content("query", data_1, mm_tracker)
|
||||
content_2 = _parse_score_content("document", data_2, mm_tracker)
|
||||
|
||||
prompt_1 = _ensure_str(content_1)
|
||||
prompt_2 = _ensure_str(content_2)
|
||||
mm_items, mm_uuids = mm_tracker.resolve_items()
|
||||
|
||||
return prompt_1, prompt_2, mm_items, mm_uuids
|
||||
|
||||
|
||||
def parse_score_data_single(
|
||||
data: ScoreData,
|
||||
role: str,
|
||||
model_config: ModelConfig,
|
||||
) -> tuple[str, MultiModalDataDict | None, MultiModalUUIDDict | None]:
|
||||
"""Parse **one** ScoreData into a text prompt and its own multi-modal
|
||||
data.
|
||||
|
||||
Unlike :func:`parse_score_data`, each call creates an **independent**
|
||||
:class:`MultiModalItemTracker` so multi-modal items are kept separate.
|
||||
This is the correct behaviour for late-interaction scoring, where
|
||||
query and document are encoded independently.
|
||||
"""
|
||||
mm_tracker = MultiModalItemTracker(model_config)
|
||||
content = _parse_score_content(role, data, mm_tracker)
|
||||
|
||||
prompt = _ensure_str(content)
|
||||
mm_items, mm_uuids = mm_tracker.resolve_items()
|
||||
return prompt, mm_items, mm_uuids
|
||||
|
||||
|
||||
def score_data_to_prompts(
|
||||
data_list: list[ScoreData],
|
||||
role: str,
|
||||
model_config: ModelConfig,
|
||||
) -> list[PromptType]:
|
||||
"""Convert a list of ScoreData into PromptType objects.
|
||||
|
||||
For plain text inputs, returns the string directly.
|
||||
For multimodal inputs (list of content parts), parses them into
|
||||
a :class:`TextPrompt` with attached ``multi_modal_data`` /
|
||||
``multi_modal_uuids``.
|
||||
|
||||
This is used by late-interaction scoring where each query/document
|
||||
is encoded independently.
|
||||
"""
|
||||
prompts: list[PromptType] = []
|
||||
for data in data_list:
|
||||
if isinstance(data, str):
|
||||
prompts.append(data)
|
||||
else:
|
||||
text, mm_data, mm_uuids = parse_score_data_single(data, role, model_config)
|
||||
prompt: TextPrompt = TextPrompt(prompt=text)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
prompt["multi_modal_uuids"] = mm_uuids
|
||||
prompts.append(prompt)
|
||||
return prompts


def _parse_score_content(
    role: str,
    data: ScoreData,
    mm_tracker: BaseMultiModalItemTracker,
) -> list[ConversationMessage]:
    parts: Iterable[ChatCompletionContentPartParam]
    if isinstance(data, str):
        parts = [ChatCompletionContentPartTextParam(type="text", text=data)]
    else:
        parts = cast(Iterable[ChatCompletionContentPartParam], data)

    mm_parser = mm_tracker.create_parser()

    parse_res = _parse_chat_message_content_parts(
        role=role,
        parts=parts,
        mm_tracker=mm_tracker,
        wrap_dicts=False,
        interleave_strings=False,
    )

    if parse_res:
        return parse_res

    mm_placeholder_storage = mm_parser.mm_placeholder_storage()

    if (
        len(mm_placeholder_storage) != 1
        or len(next(iter(mm_placeholder_storage.values()))) != 1
    ):
        raise ValueError("Only one multi-modal item is supported")

    return next(iter(mm_placeholder_storage.values()))[0]


def _apply_model_score_template(
    model_config: ModelConfig, prompt_1: str, prompt_2: str
) -> str:
    # NOTE(Simon): lazy import to avoid bringing in all dependencies (e.g. gguf)
    from vllm.model_executor.model_loader import get_model_cls

    model = get_model_cls(model_config)
    if supports_score_template(model):
        full_prompt = model.get_score_template(prompt_1, prompt_2)
        if full_prompt is None:
            raise ValueError("Got an empty score template from the model")
        return full_prompt

    raise ValueError(f"Unsupported model architecture: {model_config.architecture}")


def post_process_tokens(
    model_config: ModelConfig,
    prompt: TokensPrompt,
) -> None:
    """
    Perform architecture-specific manipulations on the input tokens.

    Note:
        This is an in-place operation.
    """
    # NOTE(Simon): lazy import to avoid bringing in all dependencies (e.g. gguf)
    from vllm.model_executor.model_loader import get_model_cls

    model = get_model_cls(model_config)
    if supports_score_template(model):
        model.post_process_tokens(prompt)


def get_score_prompt(
    model_config: ModelConfig,
    tokenizer: TokenizerLike,
    tokenization_kwargs: dict[str, Any],
    data_1: ScoreData,
    data_2: ScoreData,
    score_template: str | None = None,
) -> tuple[str, TokensPrompt]:
    prompt_1, prompt_2, mm_data, mm_uuids = parse_score_data(
        data_1,
        data_2,
        model_config,
    )
    from vllm.model_executor.model_loader import get_model_cls

    model = get_model_cls(model_config)

    def default_tokenizer_encode():
        if supports_score_template(model):
            full_prompt = _apply_model_score_template(model_config, prompt_1, prompt_2)
            prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
        else:
            if model_config.use_sep_token:
                # Cross-encoder models default to using a separator token.
                prompt_inputs = tokenizer(
                    text=prompt_1, text_pair=prompt_2, **tokenization_kwargs
                )
                full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
            else:
                # "LLM as reranker" models default to not using a separator token.
                full_prompt = prompt_1 + prompt_2
                prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs)
        return full_prompt, prompt_inputs

    # FIXME: For now, we only apply a template when one is explicitly provided.
    # We cannot rely on the tokenizer's chat template because many models
    # inherit junk templates from their base LLM, which breaks both the models
    # and the tests that use them.
    if score_template is None:
        full_prompt, prompt_inputs = default_tokenizer_encode()
    else:
        # FIXME: Try applying a score template from the CLI arg or tokenizer_config.json.
        # If that fails because there is no such template,
        # fall back to the default implementation.
        try:
            full_prompt = safe_apply_chat_template(
                model_config,
                tokenizer,
                [
                    {"role": "query", "content": prompt_1},
                    {"role": "document", "content": prompt_2},
                ],
                chat_template=score_template,
                tools=None,
                tokenize=False,
            )
            prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
        except ChatTemplateResolutionError:
            full_prompt, prompt_inputs = default_tokenizer_encode()

    engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"])

    if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None:
        engine_prompt["token_type_ids"] = token_type_ids

    post_process_tokens(model_config, engine_prompt)

    if mm_data is not None:
        engine_prompt["multi_modal_data"] = mm_data
    if mm_uuids is not None:
        engine_prompt["multi_modal_uuids"] = mm_uuids

    return full_prompt, engine_prompt
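
The two branches of `default_tokenizer_encode` correspond to the two common scoring setups. The sketch below is illustrative only and calls a Hugging Face tokenizer directly (the model name is just an example, and `transformers` is not imported by this module); it shows what the separator-token path produces, not how vLLM wires the call internally.

# Illustrative sketch (not part of this file): the separator-token path encodes
# query and document as text/text_pair, so BERT-style cross-encoders receive
# "[CLS] query [SEP] document [SEP]" plus token_type_ids marking the segments.
def _example_pair_encoding():
    from transformers import AutoTokenizer  # example dependency only

    tok = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")
    enc = tok(text="what is vLLM?", text_pair="vLLM is an inference engine.")
    # enc["input_ids"] holds the joined sequence; enc["token_type_ids"] (when the
    # tokenizer emits it) is zeros for the query segment and ones for the document.
    return enc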


def compress_token_type_ids(token_type_ids: list[int]) -> int:
    """
    Return the position of the first 1, or the length of the list
    if no 1 is found.
    """
    first_one = len(token_type_ids)
    err_msg = (
        "Token type ids are expected to be a sequence"
        " of zeros followed by a sequence of ones"
    )
    for i, type_id in enumerate(token_type_ids):
        if type_id == 0 and first_one < i:
            raise ValueError(err_msg)
        elif type_id == 1 and first_one > i:
            first_one = i
        elif type_id > 1:
            raise ValueError(err_msg)

    return first_one
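
A few worked cases make the contract concrete; the check below is illustrative only and not part of this file.

# Illustrative check (not part of this file) of the expected behaviour:
def _example_compress_token_type_ids():
    assert compress_token_type_ids([0, 0, 0, 1, 1]) == 3  # first 1 at index 3
    assert compress_token_type_ids([0, 0, 0]) == 3        # no 1 -> list length
    try:
        compress_token_type_ids([0, 1, 0])                # a 0 after a 1 is rejected
    except ValueError:
        pass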
124
vllm/entrypoints/pooling/utils.py
Normal file
124
vllm/entrypoints/pooling/utils.py
Normal file
@@ -0,0 +1,124 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from dataclasses import dataclass
from typing import Any

import pybase64
import torch

from vllm.outputs import PoolingRequestOutput
from vllm.utils.serial_utils import (
    EMBED_DTYPES,
    EmbedDType,
    Endianness,
    binary2tensor,
    tensor2binary,
)


@dataclass
class MetadataItem:
    index: int
    embed_dtype: EmbedDType
    endianness: Endianness
    start: int
    end: int
    shape: tuple[int, ...]


def build_metadata_items(
    embed_dtype: EmbedDType,
    endianness: Endianness,
    shape: tuple[int, ...],
    n_request: int,
) -> list[MetadataItem]:
    n_bytes = EMBED_DTYPES[embed_dtype].nbytes
    size = math.prod(shape)

    return [
        MetadataItem(
            index=i,
            embed_dtype=embed_dtype,
            endianness=endianness,
            start=i * size * n_bytes,
            end=(i + 1) * size * n_bytes,
            shape=shape,
        )
        for i in range(n_request)
    ]
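
The items lay the fixed-shape embeddings out back to back, so each byte range is simply `index * prod(shape) * dtype_size`. The sketch below is illustrative only; the "float32" / "little-endian" literals are assumed valid `EmbedDType` / `Endianness` values, not confirmed from this module.

# Illustrative sketch (not part of this file); the dtype and endianness literals
# are assumptions about valid EmbedDType / Endianness values.
def _example_metadata_layout():
    items = build_metadata_items(
        embed_dtype="float32",
        endianness="little-endian",
        shape=(4,),
        n_request=2,
    )
    # Each embedding occupies 4 elements * 4 bytes = 16 bytes:
    #   items[0]: start=0,  end=16
    #   items[1]: start=16, end=32
    return items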


def encode_pooling_output_float(output: PoolingRequestOutput) -> list[float]:
    return output.outputs.data.tolist()


def encode_pooling_output_binary(
    output: PoolingRequestOutput,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> bytes:
    return tensor2binary(output.outputs.data, embed_dtype, endianness)


def encode_pooling_output_base64(
    output: PoolingRequestOutput,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> str:
    embedding_bytes = tensor2binary(output.outputs.data, embed_dtype, endianness)
    return pybase64.b64encode(embedding_bytes).decode("utf-8")
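
The base64 variant is just the binary encoding wrapped for JSON transport, so a consumer can reverse it with the same helpers. A minimal sketch, assuming the shape, dtype, and endianness are known to the caller; it is not part of this file.

# Illustrative round trip (not part of this file); assumes the caller knows the
# original shape, embed_dtype, and endianness of the embedding.
def _example_decode_base64(b64_embedding: str, shape, embed_dtype, endianness):
    raw = pybase64.b64decode(b64_embedding)
    return binary2tensor(raw, shape, embed_dtype, endianness)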


def encode_pooling_bytes(
    pooling_outputs: list[PoolingRequestOutput],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> tuple[list[bytes], list[dict[str, Any]], dict[str, Any]]:
    num_prompt_tokens = 0
    items: list[dict[str, Any]] = []
    body: list[bytes] = []
    offset = 0
    for idx, output in enumerate(pooling_outputs):
        binary = tensor2binary(
            tensor=output.outputs.data,
            embed_dtype=embed_dtype,
            endianness=endianness,
        )
        size = len(binary)

        # Dictionary form of MetadataItem
        item = dict(
            index=idx,
            embed_dtype=embed_dtype,
            endianness=endianness,
            start=offset,
            end=offset + size,
            shape=output.outputs.data.shape,
        )

        body.append(binary)
        items.append(item)
        prompt_token_ids = output.prompt_token_ids
        num_prompt_tokens += len(prompt_token_ids)
        offset += size

    # Dictionary form of UsageInfo
    usage = dict(
        prompt_tokens=num_prompt_tokens,
        total_tokens=num_prompt_tokens,
    )

    return body, items, usage


def decode_pooling_output(items: list[MetadataItem], body: bytes) -> list[torch.Tensor]:
    return [
        binary2tensor(
            body[item.start : item.end],
            item.shape,
            item.embed_dtype,
            item.endianness,
        )
        for item in sorted(items, key=lambda x: x.index)
    ]
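
`encode_pooling_bytes` and `decode_pooling_output` are designed as a pair: the per-output dicts mirror `MetadataItem`, and the body chunks concatenate into the byte buffer that the decoder slices. A hedged round-trip sketch follows (not part of this file); it assumes `outputs` is a list of `PoolingRequestOutput` and that the dtype/endianness values passed in are valid.

# Illustrative round trip (not part of this file).
def _example_pooling_bytes_roundtrip(outputs, embed_dtype, endianness):
    body_chunks, item_dicts, usage = encode_pooling_bytes(outputs, embed_dtype, endianness)
    body = b"".join(body_chunks)
    items = [MetadataItem(**d) for d in item_dicts]
    tensors = decode_pooling_output(items, body)
    # tensors[i] reconstructs outputs[i].outputs.data (up to the encoding dtype).
    return tensors, usage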