Upgrade to vLLM 0.17.0 (corex v4.1 overlay)

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -115,6 +115,7 @@ def init_pooling_state(
request_logger=request_logger,
score_template=resolved_chat_template,
log_error_stack=args.log_error_stack,
use_gpu_for_pooling_score=getattr(args, "use_gpu_for_pooling_score", False),
)
if any(t in supported_tasks for t in ("embed", "score", "token_embed"))
else None
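The new use_gpu_for_pooling_score keyword is read with getattr so argument namespaces created before the flag existed still fall back to False. Below is a minimal sketch of how such a flag could be declared on the CLI parser; the flag name is assumed from the keyword and is not necessarily the exact vLLM argument definition.

import argparse

# Hypothetical CLI wiring for the flag consumed above; the real vLLM
# argument parser may define it differently.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--use-gpu-for-pooling-score",
    action="store_true",
    default=False,
    help="Compute ColBERT MaxSim pooling scores on the accelerator "
    "instead of the CPU.",
)
args = parser.parse_args([])
# getattr keeps init_pooling_state compatible with namespaces that predate the flag.
use_gpu = getattr(args, "use_gpu_for_pooling_score", False)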

View File

@@ -0,0 +1,189 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Final
from vllm import PoolingRequestOutput, PromptType
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateConfig,
ChatTemplateContentFormatOption,
ConversationMessage,
)
from vllm.entrypoints.openai.engine.serving import RendererChatRequest, RendererRequest
from vllm.inputs import ProcessorInputs, SingletonPrompt
from vllm.renderers import BaseRenderer, merge_kwargs
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
from vllm.utils.mistral import is_mistral_tokenizer
class PoolingIOProcessor:
def __init__(
self,
model_config: ModelConfig,
renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig,
):
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
self.model_config = model_config
self.renderer = renderer
self.chat_template = chat_template_config.chat_template
self.chat_template_content_format: Final = (
chat_template_config.chat_template_content_format
)
self.trust_request_chat_template = (
chat_template_config.trust_request_chat_template
)
def pre_process_online(self, *args, **kwargs):
raise NotImplementedError
async def pre_process_online_async(self, *args, **kwargs):
return self.pre_process_online(*args, **kwargs)
def pre_process_offline(self, *args, **kwargs):
raise NotImplementedError
async def pre_process_offline_async(self, *args, **kwargs):
return self.pre_process_offline(*args, **kwargs)
def post_process(
self, outputs: list[PoolingRequestOutput]
) -> list[PoolingRequestOutput]:
return outputs
async def post_process_async(
self, outputs: list[PoolingRequestOutput]
) -> list[PoolingRequestOutput]:
return self.post_process(outputs)
def create_pooling_params(self, request):
return request.to_pooling_params()
def _preprocess_completion_online(
self,
request: RendererRequest,
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None,
) -> list[TokPrompt]:
renderer = self.renderer
model_config = self.model_config
prompts = list[SingletonPrompt | bytes]()
if prompt_embeds is not None: # embeds take higher priority
prompts.extend(prompt_to_seq(prompt_embeds))
if prompt_input is not None:
prompts.extend(prompt_to_seq(prompt_input))
parsed_prompts = [
(
prompt
if isinstance(prompt, bytes)
else parse_model_prompt(model_config, prompt)
)
for prompt in prompts
]
tok_params = request.build_tok_params(model_config)
return renderer.render_cmpl(
parsed_prompts,
tok_params,
prompt_extras={
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
},
)
def _preprocess_chat_online(
self,
request: RendererChatRequest,
messages: list[ChatCompletionMessageParam],
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[TokPrompt]]:
renderer = self.renderer
default_template_kwargs = merge_kwargs(
default_template_kwargs,
dict(
tools=tool_dicts,
tokenize=is_mistral_tokenizer(renderer.tokenizer),
),
)
tok_params = request.build_tok_params(self.model_config)
chat_params = request.build_chat_params(
default_template, default_template_content_format
).with_defaults(default_template_kwargs)
(conversation,), (engine_prompt,) = renderer.render_chat(
[messages],
chat_params,
tok_params,
prompt_extras={
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
},
)
return conversation, [engine_prompt]
def _preprocess_completion_offline(
self,
prompts: PromptType | Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None,
) -> Sequence[ProcessorInputs]:
renderer = self.renderer
model_config = self.model_config
prompts = prompt_to_seq(prompts)
parsed_prompts = [
(
prompt
if isinstance(prompt, bytes)
else parse_model_prompt(model_config, prompt)
)
for prompt in prompts
]
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
return renderer.render_cmpl(
parsed_prompts,
tok_params,
)
def _validate_chat_template(
self,
request_chat_template: str | None,
chat_template_kwargs: dict[str, Any] | None,
trust_request_chat_template: bool,
):
if not trust_request_chat_template and (
request_chat_template is not None
or (
chat_template_kwargs
and chat_template_kwargs.get("chat_template") is not None
)
):
raise ValueError(
"Chat template is passed with request, but "
"--trust-request-chat-template is not set. "
"Refused request with untrusted chat template."
)
return None
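PoolingIOProcessor acts as a template-method base: subclasses override pre_process_online and pre_process_offline, while post_process and create_pooling_params default to pass-throughs. A minimal sketch of a subclass under those assumptions; the request type and its input field are hypothetical here.

# Hypothetical subclass showing the override points; only the methods it calls
# come from the PoolingIOProcessor class above.
class MyPoolingIOProcessor(PoolingIOProcessor):
    def pre_process_online(self, request):
        # Render the raw request input into engine prompts via the renderer.
        return self._preprocess_completion_online(
            request,
            prompt_input=request.input,
            prompt_embeds=None,
        )

    def pre_process_offline(self, prompts, tokenization_kwargs=None):
        return self._preprocess_completion_offline(
            prompts=prompts, tokenization_kwargs=tokenization_kwargs
        )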

View File

@@ -190,10 +190,6 @@ class EmbedRequestMixin(EncodingRequestMixin):
description="Whether to use activation for the pooler outputs. "
"`None` uses the pooler's default, which is `True` in most cases.",
)
normalize: bool | None = Field(
default=None,
description="Deprecated; please pass `use_activation` instead",
)
# --8<-- [end:embed-extra-params]

View File

@@ -0,0 +1,378 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections.abc import AsyncGenerator, Mapping
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import ClassVar, Generic, TypeVar
from fastapi import Request
from pydantic import ConfigDict
from starlette.datastructures import Headers
from starlette.responses import JSONResponse
from vllm import (
PoolingParams,
PoolingRequestOutput,
PromptType,
SamplingParams,
envs,
)
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
ChatTemplateConfig,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.typing import AnyPoolingRequest, AnyPoolingResponse
from vllm.inputs import ProcessorInputs
from vllm.lora.request import LoRARequest
from vllm.renderers import BaseRenderer
from vllm.renderers.inputs.preprocess import extract_prompt_components
from vllm.sampling_params import BeamSearchParams
from vllm.tracing import (
contains_trace_headers,
extract_trace_headers,
log_tracing_disabled_warning,
)
from vllm.utils import random_uuid
from vllm.utils.async_utils import merge_async_iterators
from ...utils import create_error_response
from .io_processor import PoolingIOProcessor
PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest)
@dataclass(kw_only=True)
class PoolingServeContext(Generic[PoolingRequestT]):
request: PoolingRequestT
raw_request: Request | None = None
model_name: str
request_id: str
created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None
engine_prompts: list[ProcessorInputs] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None
)
final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
class PoolingServing:
request_id_prefix: ClassVar[str]
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
):
super().__init__()
self.engine_client = engine_client
self.models = models
self.model_config = models.model_config
self.max_model_len = self.model_config.max_model_len
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
self.log_error_stack = log_error_stack
self.chat_template_config = ChatTemplateConfig(
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
trust_request_chat_template=trust_request_chat_template,
)
self.io_processor = self.init_io_processor(
model_config=models.model_config,
renderer=models.renderer,
chat_template_config=self.chat_template_config,
)
def init_io_processor(
self,
model_config: ModelConfig,
renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig,
) -> PoolingIOProcessor:
raise NotImplementedError
async def __call__(
self,
request: AnyPoolingRequest,
raw_request: Request,
) -> JSONResponse:
try:
model_name = self.models.model_name()
request_id = (
f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
)
await self._check_model(request)
ctx = PoolingServeContext(
request=request,
raw_request=raw_request,
model_name=model_name,
request_id=request_id,
)
self._validate_request(ctx)
self._maybe_get_adapters(ctx)
await self._preprocess(ctx)
await self._prepare_generators(ctx)
await self._collect_batch(ctx)
response = await self._build_response(ctx)
return JSONResponse(content=response.model_dump())
except Exception as e:
error_response = create_error_response(e)
return JSONResponse(
content=error_response.model_dump(),
status_code=error_response.error.code,
)
async def _preprocess(
self,
ctx: PoolingServeContext,
):
ctx.engine_prompts = await self.io_processor.pre_process_online_async(
ctx.request
)
async def _prepare_generators(
self,
ctx: PoolingServeContext,
):
if ctx.engine_prompts is None:
raise ValueError("Engine prompts not available")
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
trace_headers = (
None
if ctx.raw_request is None
else await self._get_trace_headers(ctx.raw_request.headers)
)
pooling_params = self.io_processor.create_pooling_params(ctx.request)
for i, engine_prompt in enumerate(ctx.engine_prompts):
request_id_item = f"{ctx.request_id}-{i}"
self._log_inputs(
request_id_item,
engine_prompt,
params=pooling_params,
lora_request=ctx.lora_request,
)
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=ctx.lora_request,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
)
generators.append(generator)
ctx.result_generator = merge_async_iterators(*generators)
async def _collect_batch(
self,
ctx: PoolingServeContext,
):
if ctx.engine_prompts is None:
raise ValueError("Engine prompts not available")
if ctx.result_generator is None:
raise ValueError("Result generator not available")
num_prompts = len(ctx.engine_prompts)
final_res_batch: list[PoolingRequestOutput | None]
final_res_batch = [None] * num_prompts
async for i, res in ctx.result_generator:
final_res_batch[i] = res
if None in final_res_batch:
raise ValueError("Failed to generate results for all prompts")
ctx.final_res_batch = [res for res in final_res_batch if res is not None]
async def _build_response(
self,
ctx: PoolingServeContext,
) -> AnyPoolingResponse:
raise NotImplementedError
@staticmethod
def _base_request_id(
raw_request: Request | None, default: str | None = None
) -> str | None:
"""Pulls the request id to use from a header, if provided"""
if raw_request is not None and (
(req_id := raw_request.headers.get("X-Request-Id")) is not None
):
return req_id
return random_uuid() if default is None else default
def _is_model_supported(self, model_name: str | None) -> bool:
if not model_name:
return True
return self.models.is_base_model(model_name)
async def _check_model(
self,
request: AnyPoolingRequest,
) -> ErrorResponse | None:
if self._is_model_supported(request.model):
return None
if request.model in self.models.lora_requests:
return None
if (
envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
and request.model
and (load_result := await self.models.resolve_lora(request.model))
):
if isinstance(load_result, LoRARequest):
return None
if (
isinstance(load_result, ErrorResponse)
and load_result.error.code == HTTPStatus.BAD_REQUEST.value
):
raise ValueError(load_result.error.message)
return None
def _validate_request(self, ctx: PoolingServeContext) -> None:
truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)
if (
truncate_prompt_tokens is not None
and truncate_prompt_tokens > self.max_model_len
):
raise ValueError(
"truncate_prompt_tokens value is "
"greater than max_model_len. "
"Please select a smaller truncation size."
)
return None
async def _get_trace_headers(
self,
headers: Headers,
) -> Mapping[str, str] | None:
is_tracing_enabled = await self.engine_client.is_tracing_enabled()
if is_tracing_enabled:
return extract_trace_headers(headers)
if contains_trace_headers(headers):
log_tracing_disabled_warning()
return None
def _maybe_get_adapters(
self,
ctx: PoolingServeContext,
supports_default_mm_loras: bool = False,
):
request = ctx.request
if request.model in self.models.lora_requests:
ctx.lora_request = self.models.lora_requests[request.model]
# Currently only support default modality specific loras
# if we have exactly one lora matched on the request.
if supports_default_mm_loras:
default_mm_lora = self._get_active_default_mm_loras(request)
if default_mm_lora is not None:
ctx.lora_request = default_mm_lora
if self._is_model_supported(request.model):
return None
# if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.")
def _get_active_default_mm_loras(
self, request: AnyPoolingRequest
) -> LoRARequest | None:
"""Determine if there are any active default multimodal loras."""
# TODO: Currently this is only enabled for chat completions
# to be better aligned with only being enabled for .generate
# when run offline. It would be nice to support additional
# task types in the future.
message_types = self._get_message_types(request)
default_mm_loras = set()
for lora in self.models.lora_requests.values():
# Best effort match for default multimodal lora adapters;
# There is probably a better way to do this, but currently
# this matches against the set of 'types' in any content lists
# up until '_', e.g., to match audio_url -> audio
if lora.lora_name in message_types:
default_mm_loras.add(lora)
# Currently only support default modality specific loras if
# we have exactly one lora matched on the request.
if len(default_mm_loras) == 1:
return default_mm_loras.pop()
return None
def _get_message_types(self, request: AnyPoolingRequest) -> set[str]:
"""Retrieve the set of types from message content dicts up
until `_`; we use this to match potential multimodal data
with default per-modality LoRAs.
"""
message_types: set[str] = set()
if not hasattr(request, "messages"):
return message_types
messages = request.messages
if messages is None or isinstance(messages, (str, bytes)):
return message_types
for message in messages:
if (
isinstance(message, dict)
and "content" in message
and isinstance(message["content"], list)
):
for content_dict in message["content"]:
if "type" in content_dict:
message_types.add(content_dict["type"].split("_")[0])
return message_types
def _log_inputs(
self,
request_id: str,
inputs: PromptType | ProcessorInputs,
params: SamplingParams | PoolingParams | BeamSearchParams | None,
lora_request: LoRARequest | None,
) -> None:
if self.request_logger is None:
return
components = extract_prompt_components(self.model_config, inputs)
self.request_logger.log_inputs(
request_id,
components.text,
components.token_ids,
components.embeds,
params=params,
lora_request=lora_request,
)
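PoolingServing.__call__ drives the full request lifecycle (_check_model, _validate_request, _maybe_get_adapters, _preprocess, _prepare_generators, _collect_batch, _build_response), so a concrete endpoint only has to supply init_io_processor and _build_response. A minimal sketch under those assumptions; MyIOProcessor and MyResponse are placeholders, not real vLLM classes.

# Sketch of a concrete pooling endpoint; MyIOProcessor/MyResponse are hypothetical.
class MyServing(PoolingServing):
    request_id_prefix = "my-task"

    def init_io_processor(self, model_config, renderer, chat_template_config):
        return MyIOProcessor(
            model_config=model_config,
            renderer=renderer,
            chat_template_config=chat_template_config,
        )

    async def _build_response(self, ctx):
        outputs = await self.io_processor.post_process_async(ctx.final_res_batch)
        # Map each PoolingRequestOutput into the endpoint's response schema.
        return MyResponse.from_outputs(ctx.request_id, ctx.model_name, outputs)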

View File

@@ -3,16 +3,17 @@
from fastapi import APIRouter, Depends, Request
from starlette.responses import JSONResponse
from typing_extensions import assert_never
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.utils import load_aware_call, with_cancellation
from vllm.entrypoints.utils import (
create_error_response,
load_aware_call,
with_cancellation,
)
router = APIRouter()
@@ -24,25 +25,17 @@ def classify(request: Request) -> ServingClassification | None:
@router.post("/classify", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_classify(request: ClassificationRequest, raw_request: Request):
async def create_classify(
request: ClassificationRequest, raw_request: Request
) -> JSONResponse:
handler = classify(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
return base_server.create_error_response(
error_response = create_error_response(
message="The model does not support Classification API"
)
try:
generator = await handler.create_classify(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
content=error_response.model_dump(),
status_code=error_response.error.code,
)
elif isinstance(generator, ClassificationResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
return await handler(request, raw_request)
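With the handler now invoked as await handler(request, raw_request), PoolingServing.__call__ returns a JSONResponse itself, so the per-route ErrorResponse/ClassificationResponse branching disappears. A quick client-side check of the endpoint, assuming a server on localhost:8000 with a classification model loaded (the model name is a placeholder):

import requests  # third-party HTTP client, assumed installed

resp = requests.post(
    "http://localhost:8000/classify",
    json={"model": "my-classifier", "input": ["vLLM makes pooling models easy to serve."]},
)
resp.raise_for_status()
# Each entry in `data` carries the label and per-class probabilities.
print(resp.json()["data"][0])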

View File

@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Any
from vllm import PromptType
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
)
from vllm.inputs import ProcessorInputs
from vllm.renderers.inputs import TokPrompt
class ClassifyIOProcessor(PoolingIOProcessor):
def pre_process_online(
self, request: ClassificationCompletionRequest | ClassificationChatRequest
) -> list[TokPrompt] | None:
if isinstance(request, ClassificationChatRequest):
self._validate_chat_template(
request_chat_template=request.chat_template,
chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
_, engine_prompts = self._preprocess_chat_online(
request,
request.messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
)
elif isinstance(request, ClassificationCompletionRequest):
engine_prompts = self._preprocess_completion_online(
request,
prompt_input=request.input,
prompt_embeds=None,
)
else:
raise ValueError("Invalid classification request type")
return engine_prompts
def pre_process_offline(
self,
prompts: PromptType | Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None,
) -> Sequence[ProcessorInputs]:
return self._preprocess_completion_offline(
prompts=prompts, tokenization_kwargs=tokenization_kwargs
)

View File

@@ -40,7 +40,6 @@ class ClassificationCompletionRequest(
def to_pooling_params(self):
return PoolingParams(
task="classify",
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)
@@ -63,7 +62,6 @@ class ClassificationChatRequest(
def to_pooling_params(self):
return PoolingParams(
task="classify",
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)

View File

@@ -1,116 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Final, TypeAlias
from typing import TypeAlias
import jinja2
import numpy as np
from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
from vllm import ClassificationOutput
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.serving import PoolingServeContext, PoolingServing
from vllm.logger import init_logger
from vllm.renderers import BaseRenderer
from .io_processor import ClassifyIOProcessor
from .protocol import (
ClassificationData,
ClassificationRequest,
ClassificationResponse,
)
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput
logger = init_logger(__name__)
ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest]
ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationRequest]
class ServingClassification(OpenAIServing):
class ServingClassification(PoolingServing):
request_id_prefix = "classify"
def __init__(
def init_io_processor(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
model_config: ModelConfig,
renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig,
) -> ClassifyIOProcessor:
return ClassifyIOProcessor(
model_config=model_config,
renderer=renderer,
chat_template_config=chat_template_config,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def _preprocess(
async def _build_response(
self,
ctx: ClassificationServeContext,
) -> ErrorResponse | None:
"""
Process classification inputs: tokenize text, resolve adapters,
and prepare model-specific inputs.
"""
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)
) -> ClassificationResponse:
final_res_batch_checked = await self.io_processor.post_process_async(
ctx.final_res_batch
)
if isinstance(ctx.request, ClassificationChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret:
return error_check_ret
_, ctx.engine_prompts = await self._preprocess_chat(
ctx.request,
ctx.request.messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
)
elif isinstance(ctx.request, ClassificationCompletionRequest):
ctx.engine_prompts = await self._preprocess_completion(
ctx.request,
prompt_input=ctx.request.input,
prompt_embeds=None,
)
else:
return self.create_error_response("Invalid classification request type")
return None
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
def _build_response(
self,
ctx: ClassificationServeContext,
) -> ClassificationResponse | ErrorResponse:
"""
Convert model outputs to a formatted classification response
with probabilities and labels.
"""
id2label = getattr(self.model_config.hf_config, "id2label", {})
items: list[ClassificationData] = []
num_prompt_tokens = 0
final_res_batch_checked = ctx.final_res_batch
items: list[ClassificationData] = []
for idx, final_res in enumerate(final_res_batch_checked):
classify_res = ClassificationOutput.from_base(final_res.outputs)
@@ -141,20 +82,3 @@ class ServingClassification(OpenAIServing):
data=items,
usage=usage,
)
async def create_classify(
self,
request: ClassificationRequest,
raw_request: Request,
) -> ClassificationResponse | ErrorResponse:
model_name = self.models.model_name()
request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
ctx = ClassificationServeContext(
request=request,
raw_request=raw_request,
model_name=model_name,
request_id=request_id,
)
return await self.handle(ctx) # type: ignore[return-value]

View File

@@ -14,12 +14,9 @@ from vllm.entrypoints.pooling.base.protocol import (
EmbedRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.logger import init_logger
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid
logger = init_logger(__name__)
def _get_max_total_output_tokens(
model_config: ModelConfig,
@@ -60,18 +57,10 @@ class EmbeddingCompletionRequest(
)
def to_pooling_params(self):
if self.normalize is not None:
logger.warning_once(
"`normalize` is deprecated and will be removed in v0.17. "
"Please pass `use_activation` instead."
)
self.use_activation = self.normalize
return PoolingParams(
task="embed",
dimensions=self.dimensions,
use_activation=self.use_activation,
truncate_prompt_tokens=self.truncate_prompt_tokens,
)
@@ -97,18 +86,10 @@ class EmbeddingChatRequest(
)
def to_pooling_params(self):
if self.normalize is not None:
logger.warning_once(
"`normalize` is deprecated and will be removed in v0.17. "
"Please pass `use_activation` instead."
)
self.use_activation = self.normalize
return PoolingParams(
task="embed",
dimensions=self.dimensions,
use_activation=self.use_activation,
truncate_prompt_tokens=self.truncate_prompt_tokens,
)

View File

@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.renderers import BaseRenderer
from vllm.tasks import SupportedTask
def init_pooling_io_processors(
supported_tasks: tuple[SupportedTask, ...],
model_config: ModelConfig,
renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig,
) -> dict[str, PoolingIOProcessor]:
pooling_io_processors: dict[str, PoolingIOProcessor] = {}
if "classify" in supported_tasks:
from vllm.entrypoints.pooling.classify.io_processor import (
ClassifyIOProcessor,
)
pooling_io_processors["classify"] = ClassifyIOProcessor(
model_config=model_config,
renderer=renderer,
chat_template_config=chat_template_config,
)
return pooling_io_processors
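init_pooling_io_processors only registers processors for tasks the model reports as supported, so callers can look up a task and get None when it is unavailable. A minimal usage sketch, assuming the surrounding setup objects already exist at server startup:

# supported_tasks, model_config, renderer and chat_template_config are assumed
# to come from server initialization.
io_processors = init_pooling_io_processors(
    supported_tasks=supported_tasks,
    model_config=model_config,
    renderer=renderer,
    chat_template_config=chat_template_config,
)
classify_processor = io_processors.get("classify")  # None if the model lacks the task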

View File

@@ -16,13 +16,10 @@ from vllm.entrypoints.pooling.base.protocol import (
EncodingRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.logger import init_logger
from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask
from vllm.utils import random_uuid
logger = init_logger(__name__)
class PoolingCompletionRequest(
PoolingBasicRequestMixin,
@@ -45,16 +42,8 @@ class PoolingCompletionRequest(
)
def to_pooling_params(self):
if self.normalize is not None:
logger.warning_once(
"`normalize` is deprecated and will be removed in v0.17. "
"Please pass `use_activation` instead."
)
self.use_activation = self.normalize
return PoolingParams(
task=self.task,
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
dimensions=self.dimensions,
)
@@ -78,16 +67,8 @@ class PoolingChatRequest(
)
def to_pooling_params(self):
if self.normalize is not None:
logger.warning_once(
"`normalize` is deprecated and will be removed in v0.17. "
"Please pass `use_activation` instead."
)
self.use_activation = self.normalize
return PoolingParams(
task=self.task,
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
dimensions=self.dimensions,
)

View File

@@ -37,7 +37,6 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
def to_pooling_params(self, task: PoolingTask = "score"):
return PoolingParams(
task=task,
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)
@@ -113,7 +112,6 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
def to_pooling_params(self, task: PoolingTask = "score"):
return PoolingParams(
task=task,
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)

View File

@@ -31,7 +31,7 @@ from vllm.entrypoints.pooling.score.utils import (
ScoreInputs,
_cosine_similarity,
compress_token_type_ids,
compute_maxsim_score,
compute_maxsim_scores,
get_score_prompt,
parse_score_data_single,
validate_score_input,
@@ -56,6 +56,7 @@ class ServingScores(OpenAIServing):
request_logger: RequestLogger | None,
score_template: str | None = None,
log_error_stack: bool = False,
use_gpu_for_pooling_score: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
@@ -64,6 +65,7 @@ class ServingScores(OpenAIServing):
log_error_stack=log_error_stack,
)
self.score_template = score_template
self.use_gpu_for_pooling_score = use_gpu_for_pooling_score
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
@@ -311,19 +313,18 @@ class ServingScores(OpenAIServing):
# Compute MaxSim scores
from vllm.outputs import PoolingOutput
maxsim_scores = compute_maxsim_scores(
[emb.outputs.data for emb in emb_data_1],
[emb.outputs.data for emb in emb_data_2],
use_gpu_for_pooling_score=self.use_gpu_for_pooling_score,
)
scores: list[PoolingRequestOutput] = []
padding: list[int] = []
if (pad_token_id := tokenizer.pad_token_id) is not None:
padding = [pad_token_id]
for emb_1, emb_2 in zip(emb_data_1, emb_data_2):
# emb_1.outputs.data: [query_len, dim]
# emb_2.outputs.data: [doc_len, dim]
q_emb = emb_1.outputs.data
d_emb = emb_2.outputs.data
maxsim_score = compute_maxsim_score(q_emb, d_emb)
for emb_1, emb_2, maxsim_score in zip(emb_data_1, emb_data_2, maxsim_scores):
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
scores.append(

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from collections.abc import Iterable, Sequence
from typing import Any, TypeAlias, cast
import torch
@@ -25,6 +25,7 @@ from vllm.inputs.data import PromptType, TextPrompt
from vllm.model_executor.models.interfaces import supports_score_template
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
from vllm.outputs import PoolingRequestOutput
from vllm.platforms import current_platform
from vllm.renderers.hf import safe_apply_chat_template
from vllm.tokenizers import TokenizerLike
@@ -53,6 +54,91 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens
return token_scores.amax(dim=-1).sum()
def _should_use_gpu_for_maxsim(use_gpu_for_pooling_score: bool) -> bool:
return use_gpu_for_pooling_score and not current_platform.is_cpu()
def compute_maxsim_scores(
q_embs: Sequence[torch.Tensor],
d_embs: Sequence[torch.Tensor],
max_batch_size: int = 16,
max_score_matrix_elements: int = 16_000_000,
use_gpu_for_pooling_score: bool = False,
) -> list[torch.Tensor]:
"""Compute ColBERT MaxSim scores in padded mini-batches."""
if len(q_embs) != len(d_embs):
raise ValueError("q_embs and d_embs must have the same length")
num_pairs = len(q_embs)
if num_pairs == 0:
return []
for q_emb, d_emb in zip(q_embs, d_embs):
if q_emb.ndim != 2 or d_emb.ndim != 2:
raise ValueError("Each embedding tensor must be 2-D")
if q_emb.shape[1] != d_emb.shape[1]:
raise ValueError("Query and document embeddings must have same dim")
compute_device = torch.device(
current_platform.device_type
if _should_use_gpu_for_maxsim(use_gpu_for_pooling_score)
else "cpu"
)
scores: list[torch.Tensor] = []
start = 0
while start < num_pairs:
end = min(start + max_batch_size, num_pairs)
max_q = max(int(x.shape[0]) for x in q_embs[start:end])
max_d = max(int(x.shape[0]) for x in d_embs[start:end])
# Keep the score matrix bounded to avoid oversized allocations.
while (
end - start > 1
and (end - start) * max_q * max_d > max_score_matrix_elements
):
end -= 1
max_q = max(int(x.shape[0]) for x in q_embs[start:end])
max_d = max(int(x.shape[0]) for x in d_embs[start:end])
batch_q = q_embs[start:end]
batch_d = d_embs[start:end]
batch_size = end - start
dim = int(batch_q[0].shape[1])
dtype = batch_q[0].dtype
q_batch = torch.zeros(
(batch_size, max_q, dim), dtype=dtype, device=compute_device
)
d_batch = torch.zeros(
(batch_size, max_d, dim), dtype=dtype, device=compute_device
)
q_mask = torch.zeros(
(batch_size, max_q), dtype=torch.bool, device=compute_device
)
d_mask = torch.zeros(
(batch_size, max_d), dtype=torch.bool, device=compute_device
)
# copy to padded tensors
for i, (q_emb, d_emb) in enumerate(zip(batch_q, batch_d)):
q_len = int(q_emb.shape[0])
d_len = int(d_emb.shape[0])
q_batch[i, :q_len] = q_emb.to(device=compute_device, dtype=dtype)
d_batch[i, :d_len] = d_emb.to(device=compute_device, dtype=dtype)
q_mask[i, :q_len] = True
d_mask[i, :d_len] = True
token_scores = torch.bmm(q_batch, d_batch.transpose(1, 2))
token_scores.masked_fill_(~d_mask.unsqueeze(1), float("-inf"))
max_per_query = token_scores.amax(dim=-1)
max_per_query.masked_fill_(~q_mask, 0)
batch_scores = max_per_query.sum(dim=-1).to("cpu")
scores.extend(batch_scores.unbind(0))
start = end
return [cast(torch.Tensor, score) for score in scores]
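For each (query, document) pair, compute_maxsim_scores pads a mini-batch, takes the per-query-token maximum similarity over document tokens, and sums it, which should match the unbatched single-pair MaxSim. A small CPU-only sanity check under that assumption (sizes are arbitrary and the function is assumed to be in scope from this module):

import torch

torch.manual_seed(0)
q_embs = [torch.randn(5, 8), torch.randn(3, 8)]   # [query_len, dim]
d_embs = [torch.randn(7, 8), torch.randn(11, 8)]  # [doc_len, dim]

batched = compute_maxsim_scores(q_embs, d_embs)
for q, d, score in zip(q_embs, d_embs, batched):
    # Unpadded per-pair MaxSim: max over doc tokens, summed over query tokens.
    single = (q @ d.T).amax(dim=-1).sum()
    assert torch.allclose(score, single, atol=1e-5)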
class ScoreMultiModalParam(TypedDict, total=False):
"""
A specialized parameter type for scoring multimodal content

View File

@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TypeAlias
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
RerankRequest,
ScoreRequest,
ScoreResponse,
)
PoolingCompletionLikeRequest: TypeAlias = (
EmbeddingCompletionRequest
| ClassificationCompletionRequest
| RerankRequest
| ScoreRequest
| PoolingCompletionRequest
)
PoolingChatLikeRequest: TypeAlias = (
EmbeddingChatRequest | ClassificationChatRequest | PoolingChatRequest
)
AnyPoolingRequest: TypeAlias = (
PoolingCompletionLikeRequest | PoolingChatLikeRequest | IOProcessorRequest
)
AnyPoolingResponse: TypeAlias = (
ClassificationResponse
| EmbeddingResponse
| EmbeddingBytesResponse
| PoolingResponse
| ScoreResponse
)
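These unions let shared entrypoint code branch once on request shape instead of per endpoint; isinstance accepts the union aliases directly on Python 3.10+. A minimal dispatch sketch; the handle_* functions are hypothetical:

# Hypothetical dispatcher; handle_chat/handle_completion/handle_io_processor
# are placeholders, not vLLM functions.
def route_pooling_request(request: AnyPoolingRequest):
    if isinstance(request, PoolingChatLikeRequest):
        return handle_chat(request)
    if isinstance(request, PoolingCompletionLikeRequest):
        return handle_completion(request)
    return handle_io_processor(request)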