Upgrade to vLLM 0.17.0 corex v4.1 overlay
@@ -115,6 +115,7 @@ def init_pooling_state(
            request_logger=request_logger,
            score_template=resolved_chat_template,
            log_error_stack=args.log_error_stack,
            use_gpu_for_pooling_score=getattr(args, "use_gpu_for_pooling_score", False),
        )
        if any(t in supported_tasks for t in ("embed", "score", "token_embed"))
        else None
vllm/entrypoints/pooling/base/io_processor.py (new file, 189 lines)
@@ -0,0 +1,189 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Final

from vllm import PoolingRequestOutput, PromptType
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateConfig,
    ChatTemplateContentFormatOption,
    ConversationMessage,
)
from vllm.entrypoints.openai.engine.serving import RendererChatRequest, RendererRequest
from vllm.inputs import ProcessorInputs, SingletonPrompt
from vllm.renderers import BaseRenderer, merge_kwargs
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
from vllm.utils.mistral import is_mistral_tokenizer


class PoolingIOProcessor:
    def __init__(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ):
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

        self.model_config = model_config
        self.renderer = renderer

        self.chat_template = chat_template_config.chat_template
        self.chat_template_content_format: Final = (
            chat_template_config.chat_template_content_format
        )
        self.trust_request_chat_template = (
            chat_template_config.trust_request_chat_template
        )

    def pre_process_online(self, *args, **kwargs):
        raise NotImplementedError

    async def pre_process_online_async(self, *args, **kwargs):
        return self.pre_process_online(*args, **kwargs)

    def pre_process_offline(self, *args, **kwargs):
        raise NotImplementedError

    async def pre_process_offline_async(self, *args, **kwargs):
        return self.pre_process_offline(*args, **kwargs)

    def post_process(
        self, outputs: list[PoolingRequestOutput]
    ) -> list[PoolingRequestOutput]:
        return outputs

    async def post_process_async(
        self, outputs: list[PoolingRequestOutput]
    ) -> list[PoolingRequestOutput]:
        return self.post_process(outputs)

    def create_pooling_params(self, request):
        return request.to_pooling_params()

    def _preprocess_completion_online(
        self,
        request: RendererRequest,
        prompt_input: str | list[str] | list[int] | list[list[int]] | None,
        prompt_embeds: bytes | list[bytes] | None,
    ) -> list[TokPrompt]:
        renderer = self.renderer
        model_config = self.model_config

        prompts = list[SingletonPrompt | bytes]()
        if prompt_embeds is not None:  # embeds take higher priority
            prompts.extend(prompt_to_seq(prompt_embeds))
        if prompt_input is not None:
            prompts.extend(prompt_to_seq(prompt_input))

        parsed_prompts = [
            (
                prompt
                if isinstance(prompt, bytes)
                else parse_model_prompt(model_config, prompt)
            )
            for prompt in prompts
        ]
        tok_params = request.build_tok_params(model_config)

        return renderer.render_cmpl(
            parsed_prompts,
            tok_params,
            prompt_extras={
                k: v
                for k in ("mm_processor_kwargs", "cache_salt")
                if (v := getattr(request, k, None)) is not None
            },
        )

    def _preprocess_chat_online(
        self,
        request: RendererChatRequest,
        messages: list[ChatCompletionMessageParam],
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
        default_template_kwargs: dict[str, Any] | None,
        tool_dicts: list[dict[str, Any]] | None = None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
    ) -> tuple[list[ConversationMessage], list[TokPrompt]]:
        renderer = self.renderer

        default_template_kwargs = merge_kwargs(
            default_template_kwargs,
            dict(
                tools=tool_dicts,
                tokenize=is_mistral_tokenizer(renderer.tokenizer),
            ),
        )

        tok_params = request.build_tok_params(self.model_config)
        chat_params = request.build_chat_params(
            default_template, default_template_content_format
        ).with_defaults(default_template_kwargs)

        (conversation,), (engine_prompt,) = renderer.render_chat(
            [messages],
            chat_params,
            tok_params,
            prompt_extras={
                k: v
                for k in ("mm_processor_kwargs", "cache_salt")
                if (v := getattr(request, k, None)) is not None
            },
        )

        return conversation, [engine_prompt]

    def _preprocess_completion_offline(
        self,
        prompts: PromptType | Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        renderer = self.renderer
        model_config = self.model_config

        prompts = prompt_to_seq(prompts)

        parsed_prompts = [
            (
                prompt
                if isinstance(prompt, bytes)
                else parse_model_prompt(model_config, prompt)
            )
            for prompt in prompts
        ]
        tok_params = renderer.default_cmpl_tok_params.with_kwargs(
            **(tokenization_kwargs or {})
        )

        return renderer.render_cmpl(
            parsed_prompts,
            tok_params,
        )

    def _validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ):
        if not trust_request_chat_template and (
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
            raise ValueError(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
                "Refused request with untrusted chat template."
            )
        return None
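A minimal sketch of how a concrete processor is expected to plug into this base class. The subclass name and the `request.input` field are hypothetical here; the real example in this PR is ClassifyIOProcessor further down in the diff.

# Hypothetical sketch only; not part of this PR.
from collections.abc import Sequence
from typing import Any

from vllm import PromptType
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.inputs import ProcessorInputs


class MyPoolingIOProcessor(PoolingIOProcessor):
    def pre_process_online(self, request) -> list:
        # Completion-style requests expose an `input` field elsewhere in
        # this PR; assumed here for illustration.
        return self._preprocess_completion_online(
            request,
            prompt_input=request.input,
            prompt_embeds=None,
        )

    def pre_process_offline(
        self,
        prompts: PromptType | Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        return self._preprocess_completion_offline(
            prompts=prompts, tokenization_kwargs=tokenization_kwargs
        )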
@@ -190,10 +190,6 @@ class EmbedRequestMixin(EncodingRequestMixin):
        description="Whether to use activation for the pooler outputs. "
        "`None` uses the pooler's default, which is `True` in most cases.",
    )
    normalize: bool | None = Field(
        default=None,
        description="Deprecated; please pass `use_activation` instead",
    )
    # --8<-- [end:embed-extra-params]

vllm/entrypoints/pooling/base/serving.py (new file, 378 lines)
@@ -0,0 +1,378 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections.abc import AsyncGenerator, Mapping
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import ClassVar, Generic, TypeVar

from fastapi import Request
from pydantic import ConfigDict
from starlette.datastructures import Headers
from starlette.responses import JSONResponse

from vllm import (
    PoolingParams,
    PoolingRequestOutput,
    PromptType,
    SamplingParams,
    envs,
)
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
    ChatTemplateConfig,
    ChatTemplateContentFormatOption,
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.typing import AnyPoolingRequest, AnyPoolingResponse
from vllm.inputs import ProcessorInputs
from vllm.lora.request import LoRARequest
from vllm.renderers import BaseRenderer
from vllm.renderers.inputs.preprocess import extract_prompt_components
from vllm.sampling_params import BeamSearchParams
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
from vllm.utils import random_uuid
from vllm.utils.async_utils import merge_async_iterators

from ...utils import create_error_response
from .io_processor import PoolingIOProcessor

PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest)


@dataclass(kw_only=True)
class PoolingServeContext(Generic[PoolingRequestT]):
    request: PoolingRequestT
    raw_request: Request | None = None
    model_name: str
    request_id: str
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None
    engine_prompts: list[ProcessorInputs] | None = None

    result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
        None
    )
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    model_config = ConfigDict(arbitrary_types_allowed=True)


class PoolingServing:
    request_id_prefix: ClassVar[str]

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        trust_request_chat_template: bool = False,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        super().__init__()
        self.engine_client = engine_client
        self.models = models
        self.model_config = models.model_config
        self.max_model_len = self.model_config.max_model_len
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.log_error_stack = log_error_stack
        self.chat_template_config = ChatTemplateConfig(
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            trust_request_chat_template=trust_request_chat_template,
        )
        self.io_processor = self.init_io_processor(
            model_config=models.model_config,
            renderer=models.renderer,
            chat_template_config=self.chat_template_config,
        )

    def init_io_processor(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ) -> PoolingIOProcessor:
        raise NotImplementedError

    async def __call__(
        self,
        request: AnyPoolingRequest,
        raw_request: Request,
    ) -> JSONResponse:
        try:
            model_name = self.models.model_name()
            request_id = (
                f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
            )

            await self._check_model(request)

            ctx = PoolingServeContext(
                request=request,
                raw_request=raw_request,
                model_name=model_name,
                request_id=request_id,
            )

            self._validate_request(ctx)
            self._maybe_get_adapters(ctx)
            await self._preprocess(ctx)
            await self._prepare_generators(ctx)
            await self._collect_batch(ctx)
            response = await self._build_response(ctx)
            return JSONResponse(content=response.model_dump())
        except Exception as e:
            error_response = create_error_response(e)
            return JSONResponse(
                content=error_response.model_dump(),
                status_code=error_response.error.code,
            )

    async def _preprocess(
        self,
        ctx: PoolingServeContext,
    ):
        ctx.engine_prompts = await self.io_processor.pre_process_online_async(
            ctx.request
        )

    async def _prepare_generators(
        self,
        ctx: PoolingServeContext,
    ):
        if ctx.engine_prompts is None:
            raise ValueError("Engine prompts not available")

        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

        trace_headers = (
            None
            if ctx.raw_request is None
            else await self._get_trace_headers(ctx.raw_request.headers)
        )

        pooling_params = self.io_processor.create_pooling_params(ctx.request)

        for i, engine_prompt in enumerate(ctx.engine_prompts):
            request_id_item = f"{ctx.request_id}-{i}"

            self._log_inputs(
                request_id_item,
                engine_prompt,
                params=pooling_params,
                lora_request=ctx.lora_request,
            )

            generator = self.engine_client.encode(
                engine_prompt,
                pooling_params,
                request_id_item,
                lora_request=ctx.lora_request,
                trace_headers=trace_headers,
                priority=getattr(ctx.request, "priority", 0),
            )

            generators.append(generator)

        ctx.result_generator = merge_async_iterators(*generators)

    async def _collect_batch(
        self,
        ctx: PoolingServeContext,
    ):
        if ctx.engine_prompts is None:
            raise ValueError("Engine prompts not available")

        if ctx.result_generator is None:
            raise ValueError("Result generator not available")

        num_prompts = len(ctx.engine_prompts)
        final_res_batch: list[PoolingRequestOutput | None]
        final_res_batch = [None] * num_prompts

        async for i, res in ctx.result_generator:
            final_res_batch[i] = res

        if None in final_res_batch:
            raise ValueError("Failed to generate results for all prompts")

        ctx.final_res_batch = [res for res in final_res_batch if res is not None]

    async def _build_response(
        self,
        ctx: PoolingServeContext,
    ) -> AnyPoolingResponse:
        raise NotImplementedError

    @staticmethod
    def _base_request_id(
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
        """Pulls the request id to use from a header, if provided"""
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id

        return random_uuid() if default is None else default

    def _is_model_supported(self, model_name: str | None) -> bool:
        if not model_name:
            return True
        return self.models.is_base_model(model_name)

    async def _check_model(
        self,
        request: AnyPoolingRequest,
    ) -> ErrorResponse | None:
        if self._is_model_supported(request.model):
            return None
        if request.model in self.models.lora_requests:
            return None
        if (
            envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
            and request.model
            and (load_result := await self.models.resolve_lora(request.model))
        ):
            if isinstance(load_result, LoRARequest):
                return None
            if (
                isinstance(load_result, ErrorResponse)
                and load_result.error.code == HTTPStatus.BAD_REQUEST.value
            ):
                raise ValueError(load_result.error.message)
        return None

    def _validate_request(self, ctx: PoolingServeContext) -> None:
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)

        if (
            truncate_prompt_tokens is not None
            and truncate_prompt_tokens > self.max_model_len
        ):
            raise ValueError(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
                " Please, select a smaller truncation size."
            )
        return None

    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        is_tracing_enabled = await self.engine_client.is_tracing_enabled()

        if is_tracing_enabled:
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

    def _maybe_get_adapters(
        self,
        ctx: PoolingServeContext,
        supports_default_mm_loras: bool = False,
    ):
        request = ctx.request
        if request.model in self.models.lora_requests:
            ctx.lora_request = self.models.lora_requests[request.model]

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
                ctx.lora_request = default_mm_lora

        if self._is_model_supported(request.model):
            return None

        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _get_active_default_mm_loras(
        self, request: AnyPoolingRequest
    ) -> LoRARequest | None:
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

    def _get_message_types(self, request: AnyPoolingRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
            if (
                isinstance(message, dict)
                and "content" in message
                and isinstance(message["content"], list)
            ):
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

    def _log_inputs(
        self,
        request_id: str,
        inputs: PromptType | ProcessorInputs,
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
    ) -> None:
        if self.request_logger is None:
            return

        components = extract_prompt_components(self.model_config, inputs)

        self.request_logger.log_inputs(
            request_id,
            components.text,
            components.token_ids,
            components.embeds,
            params=params,
            lora_request=lora_request,
        )
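A rough sketch of the subclass contract PoolingServing expects: a concrete serving class only provides the task-specific IO processor and the response builder, while the base class drives validation, preprocessing, generation, and batching in __call__. The class name and prefix below are hypothetical; the concrete example in this PR is ServingClassification later in the diff.

# Hypothetical sketch only; not part of this PR.
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.pooling.base.serving import PoolingServeContext, PoolingServing
from vllm.renderers import BaseRenderer


class ServingExample(PoolingServing):
    request_id_prefix = "example"  # hypothetical prefix

    def init_io_processor(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ):
        # Return the task-specific PoolingIOProcessor subclass here.
        ...

    async def _build_response(self, ctx: PoolingServeContext):
        # Turn ctx.final_res_batch (PoolingRequestOutput objects) into the
        # task-specific pydantic response model.
        ...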
@@ -3,16 +3,17 @@

from fastapi import APIRouter, Depends, Request
from starlette.responses import JSONResponse
from typing_extensions import assert_never

from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationRequest,
    ClassificationResponse,
)
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.utils import load_aware_call, with_cancellation
from vllm.entrypoints.utils import (
    create_error_response,
    load_aware_call,
    with_cancellation,
)

router = APIRouter()

@@ -24,25 +25,17 @@ def classify(request: Request) -> ServingClassification | None:
@router.post("/classify", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_classify(request: ClassificationRequest, raw_request: Request):
async def create_classify(
    request: ClassificationRequest, raw_request: Request
) -> JSONResponse:
    handler = classify(raw_request)
    if handler is None:
        base_server = raw_request.app.state.openai_serving_tokenization
        return base_server.create_error_response(
        error_response = create_error_response(
            message="The model does not support Classification API"
        )

    try:
        generator = await handler.create_classify(request, raw_request)
    except Exception as e:
        generator = handler.create_error_response(e)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
            content=error_response.model_dump(),
            status_code=error_response.error.code,
        )

    elif isinstance(generator, ClassificationResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)
    return await handler(request, raw_request)

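A hedged client-side sketch of hitting the /classify route registered above. The URL and model name are assumptions; the "model" and "input" payload fields follow ClassificationCompletionRequest as used elsewhere in this diff, and "data" is the field populated by _build_response.

# Hypothetical request against a running server with a classification model.
import requests

resp = requests.post(
    "http://localhost:8000/classify",  # assumed server address
    json={
        "model": "my-classifier",      # assumed model name
        "input": ["vLLM is a fast inference engine."],
    },
)
resp.raise_for_status()
print(resp.json()["data"])  # list of classification results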
vllm/entrypoints/pooling/classify/io_processor.py (new file, 50 lines)
@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Any

from vllm import PromptType
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
)
from vllm.inputs import ProcessorInputs
from vllm.renderers.inputs import TokPrompt


class ClassifyIOProcessor(PoolingIOProcessor):
    def pre_process_online(
        self, request: ClassificationCompletionRequest | ClassificationChatRequest
    ) -> list[TokPrompt] | None:
        if isinstance(request, ClassificationChatRequest):
            self._validate_chat_template(
                request_chat_template=request.chat_template,
                chat_template_kwargs=request.chat_template_kwargs,
                trust_request_chat_template=self.trust_request_chat_template,
            )
            _, engine_prompts = self._preprocess_chat_online(
                request,
                request.messages,
                default_template=self.chat_template,
                default_template_content_format=self.chat_template_content_format,
                default_template_kwargs=None,
            )
        elif isinstance(request, ClassificationCompletionRequest):
            engine_prompts = self._preprocess_completion_online(
                request,
                prompt_input=request.input,
                prompt_embeds=None,
            )
        else:
            raise ValueError("Invalid classification request type")
        return engine_prompts

    def pre_process_offline(
        self,
        prompts: PromptType | Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        return self._preprocess_completion_offline(
            prompts=prompts, tokenization_kwargs=tokenization_kwargs
        )
@@ -40,7 +40,6 @@ class ClassificationCompletionRequest(
    def to_pooling_params(self):
        return PoolingParams(
            task="classify",
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            use_activation=self.use_activation,
        )

@@ -63,7 +62,6 @@ class ClassificationChatRequest(
    def to_pooling_params(self):
        return PoolingParams(
            task="classify",
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            use_activation=self.use_activation,
        )

@@ -1,116 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Final, TypeAlias
from typing import TypeAlias

import jinja2
import numpy as np
from fastapi import Request

from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
from vllm import ClassificationOutput
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.serving import PoolingServeContext, PoolingServing
from vllm.logger import init_logger
from vllm.renderers import BaseRenderer

from .io_processor import ClassifyIOProcessor
from .protocol import (
    ClassificationData,
    ClassificationRequest,
    ClassificationResponse,
)
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput

logger = init_logger(__name__)


ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest]
ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationRequest]


class ServingClassification(OpenAIServing):
class ServingClassification(PoolingServing):
    request_id_prefix = "classify"

    def __init__(
    def init_io_processor(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        trust_request_chat_template: bool = False,
        log_error_stack: bool = False,
    ) -> None:
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            log_error_stack=log_error_stack,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ) -> ClassifyIOProcessor:
        return ClassifyIOProcessor(
            model_config=model_config,
            renderer=renderer,
            chat_template_config=chat_template_config,
        )

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template

    async def _preprocess(
    async def _build_response(
        self,
        ctx: ClassificationServeContext,
    ) -> ErrorResponse | None:
        """
        Process classification inputs: tokenize text, resolve adapters,
        and prepare model-specific inputs.
        """
        try:
            ctx.lora_request = self._maybe_get_adapters(ctx.request)
    ) -> ClassificationResponse:
        final_res_batch_checked = await self.io_processor.post_process_async(
            ctx.final_res_batch
        )

            if isinstance(ctx.request, ClassificationChatRequest):
                error_check_ret = self._validate_chat_template(
                    request_chat_template=ctx.request.chat_template,
                    chat_template_kwargs=ctx.request.chat_template_kwargs,
                    trust_request_chat_template=self.trust_request_chat_template,
                )
                if error_check_ret:
                    return error_check_ret

                _, ctx.engine_prompts = await self._preprocess_chat(
                    ctx.request,
                    ctx.request.messages,
                    default_template=self.chat_template,
                    default_template_content_format=self.chat_template_content_format,
                    default_template_kwargs=None,
                )
            elif isinstance(ctx.request, ClassificationCompletionRequest):
                ctx.engine_prompts = await self._preprocess_completion(
                    ctx.request,
                    prompt_input=ctx.request.input,
                    prompt_embeds=None,
                )
            else:
                return self.create_error_response("Invalid classification request type")

            return None

        except (ValueError, TypeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

    def _build_response(
        self,
        ctx: ClassificationServeContext,
    ) -> ClassificationResponse | ErrorResponse:
        """
        Convert model outputs to a formatted classification response
        with probabilities and labels.
        """
        id2label = getattr(self.model_config.hf_config, "id2label", {})

        items: list[ClassificationData] = []
        num_prompt_tokens = 0

        final_res_batch_checked = ctx.final_res_batch

        items: list[ClassificationData] = []
        for idx, final_res in enumerate(final_res_batch_checked):
            classify_res = ClassificationOutput.from_base(final_res.outputs)

@@ -141,20 +82,3 @@ class ServingClassification(OpenAIServing):
            data=items,
            usage=usage,
        )

    async def create_classify(
        self,
        request: ClassificationRequest,
        raw_request: Request,
    ) -> ClassificationResponse | ErrorResponse:
        model_name = self.models.model_name()
        request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"

        ctx = ClassificationServeContext(
            request=request,
            raw_request=raw_request,
            model_name=model_name,
            request_id=request_id,
        )

        return await self.handle(ctx)  # type: ignore[return-value]

@@ -14,12 +14,9 @@ from vllm.entrypoints.pooling.base.protocol import (
    EmbedRequestMixin,
    PoolingBasicRequestMixin,
)
from vllm.logger import init_logger
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid

logger = init_logger(__name__)


def _get_max_total_output_tokens(
    model_config: ModelConfig,
@@ -60,18 +57,10 @@ class EmbeddingCompletionRequest(
    )

    def to_pooling_params(self):
        if self.normalize is not None:
            logger.warning_once(
                "`normalize` is deprecated and will be removed in v0.17. "
                "Please pass `use_activation` instead."
            )
            self.use_activation = self.normalize

        return PoolingParams(
            task="embed",
            dimensions=self.dimensions,
            use_activation=self.use_activation,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
        )

@@ -97,18 +86,10 @@ class EmbeddingChatRequest(
    )

    def to_pooling_params(self):
        if self.normalize is not None:
            logger.warning_once(
                "`normalize` is deprecated and will be removed in v0.17. "
                "Please pass `use_activation` instead."
            )
            self.use_activation = self.normalize

        return PoolingParams(
            task="embed",
            dimensions=self.dimensions,
            use_activation=self.use_activation,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
        )

vllm/entrypoints/pooling/io_processor_factories.py (new file, 31 lines)
@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.renderers import BaseRenderer
from vllm.tasks import SupportedTask


def init_pooling_io_processors(
    supported_tasks: tuple[SupportedTask, ...],
    model_config: ModelConfig,
    renderer: BaseRenderer,
    chat_template_config: ChatTemplateConfig,
) -> dict[str, PoolingIOProcessor]:
    pooling_io_processors: dict[str, PoolingIOProcessor] = {}

    if "classify" in supported_tasks:
        from vllm.entrypoints.pooling.classify.io_processor import (
            ClassifyIOProcessor,
        )

        pooling_io_processors["classify"] = ClassifyIOProcessor(
            model_config=model_config,
            renderer=renderer,
            chat_template_config=chat_template_config,
        )

    return pooling_io_processors
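A rough usage sketch of the new factory; where `supported_tasks`, `model_config`, `renderer`, and `chat_template_config` come from is an assumption here (in the server they are derived from the engine and CLI setup), and only the "classify" key is wired up in this PR.

# Hypothetical usage only; the surrounding objects are assumed to be in scope.
from vllm.entrypoints.pooling.io_processor_factories import (
    init_pooling_io_processors,
)

io_processors = init_pooling_io_processors(
    supported_tasks=("classify",),
    model_config=model_config,                  # assumed in scope
    renderer=renderer,                          # assumed in scope
    chat_template_config=chat_template_config,  # assumed in scope
)
classify_processor = io_processors.get("classify")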
@@ -16,13 +16,10 @@ from vllm.entrypoints.pooling.base.protocol import (
    EncodingRequestMixin,
    PoolingBasicRequestMixin,
)
from vllm.logger import init_logger
from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask
from vllm.utils import random_uuid

logger = init_logger(__name__)


class PoolingCompletionRequest(
    PoolingBasicRequestMixin,
@@ -45,16 +42,8 @@ class PoolingCompletionRequest(
    )

    def to_pooling_params(self):
        if self.normalize is not None:
            logger.warning_once(
                "`normalize` is deprecated and will be removed in v0.17. "
                "Please pass `use_activation` instead."
            )
            self.use_activation = self.normalize

        return PoolingParams(
            task=self.task,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            use_activation=self.use_activation,
            dimensions=self.dimensions,
        )
@@ -78,16 +67,8 @@ class PoolingChatRequest(
    )

    def to_pooling_params(self):
        if self.normalize is not None:
            logger.warning_once(
                "`normalize` is deprecated and will be removed in v0.17. "
                "Please pass `use_activation` instead."
            )
            self.use_activation = self.normalize

        return PoolingParams(
            task=self.task,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            use_activation=self.use_activation,
            dimensions=self.dimensions,
        )

@@ -37,7 +37,6 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
    def to_pooling_params(self, task: PoolingTask = "score"):
        return PoolingParams(
            task=task,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            use_activation=self.use_activation,
        )

@@ -113,7 +112,6 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
    def to_pooling_params(self, task: PoolingTask = "score"):
        return PoolingParams(
            task=task,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            use_activation=self.use_activation,
        )

@@ -31,7 +31,7 @@ from vllm.entrypoints.pooling.score.utils import (
    ScoreInputs,
    _cosine_similarity,
    compress_token_type_ids,
    compute_maxsim_score,
    compute_maxsim_scores,
    get_score_prompt,
    parse_score_data_single,
    validate_score_input,
@@ -56,6 +57,7 @@ class ServingScores(OpenAIServing):
        request_logger: RequestLogger | None,
        score_template: str | None = None,
        log_error_stack: bool = False,
        use_gpu_for_pooling_score: bool = False,
    ) -> None:
        super().__init__(
            engine_client=engine_client,
@@ -64,6 +65,7 @@ class ServingScores(OpenAIServing):
            log_error_stack=log_error_stack,
        )
        self.score_template = score_template
        self.use_gpu_for_pooling_score = use_gpu_for_pooling_score

        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

@@ -311,19 +313,18 @@ class ServingScores(OpenAIServing):
        # Compute MaxSim scores
        from vllm.outputs import PoolingOutput

        maxsim_scores = compute_maxsim_scores(
            [emb.outputs.data for emb in emb_data_1],
            [emb.outputs.data for emb in emb_data_2],
            use_gpu_for_pooling_score=self.use_gpu_for_pooling_score,
        )

        scores: list[PoolingRequestOutput] = []
        padding: list[int] = []
        if (pad_token_id := tokenizer.pad_token_id) is not None:
            padding = [pad_token_id]

        for emb_1, emb_2 in zip(emb_data_1, emb_data_2):
            # emb_1.outputs.data: [query_len, dim]
            # emb_2.outputs.data: [doc_len, dim]
            q_emb = emb_1.outputs.data
            d_emb = emb_2.outputs.data

            maxsim_score = compute_maxsim_score(q_emb, d_emb)

        for emb_1, emb_2, maxsim_score in zip(emb_data_1, emb_data_2, maxsim_scores):
            tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids

            scores.append(

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from collections.abc import Iterable, Sequence
from typing import Any, TypeAlias, cast

import torch
@@ -25,6 +25,7 @@ from vllm.inputs.data import PromptType, TextPrompt
from vllm.model_executor.models.interfaces import supports_score_template
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
from vllm.outputs import PoolingRequestOutput
from vllm.platforms import current_platform
from vllm.renderers.hf import safe_apply_chat_template
from vllm.tokenizers import TokenizerLike

@@ -53,6 +54,91 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor:
    return token_scores.amax(dim=-1).sum()


def _should_use_gpu_for_maxsim(use_gpu_for_pooling_score: bool) -> bool:
    return use_gpu_for_pooling_score and not current_platform.is_cpu()


def compute_maxsim_scores(
    q_embs: Sequence[torch.Tensor],
    d_embs: Sequence[torch.Tensor],
    max_batch_size: int = 16,
    max_score_matrix_elements: int = 16_000_000,
    use_gpu_for_pooling_score: bool = False,
) -> list[torch.Tensor]:
    """Compute ColBERT MaxSim scores in padded mini-batches."""
    if len(q_embs) != len(d_embs):
        raise ValueError("q_embs and d_embs must have the same length")

    num_pairs = len(q_embs)
    if num_pairs == 0:
        return []

    for q_emb, d_emb in zip(q_embs, d_embs):
        if q_emb.ndim != 2 or d_emb.ndim != 2:
            raise ValueError("Each embedding tensor must be 2-D")
        if q_emb.shape[1] != d_emb.shape[1]:
            raise ValueError("Query and document embeddings must have same dim")

    compute_device = torch.device(
        current_platform.device_type
        if _should_use_gpu_for_maxsim(use_gpu_for_pooling_score)
        else "cpu"
    )
    scores: list[torch.Tensor] = []
    start = 0
    while start < num_pairs:
        end = min(start + max_batch_size, num_pairs)
        max_q = max(int(x.shape[0]) for x in q_embs[start:end])
        max_d = max(int(x.shape[0]) for x in d_embs[start:end])

        # keep score matrix bounded to avoid oversized allocations.
        while (
            end - start > 1
            and (end - start) * max_q * max_d > max_score_matrix_elements
        ):
            end -= 1
            max_q = max(int(x.shape[0]) for x in q_embs[start:end])
            max_d = max(int(x.shape[0]) for x in d_embs[start:end])

        batch_q = q_embs[start:end]
        batch_d = d_embs[start:end]
        batch_size = end - start
        dim = int(batch_q[0].shape[1])
        dtype = batch_q[0].dtype

        q_batch = torch.zeros(
            (batch_size, max_q, dim), dtype=dtype, device=compute_device
        )
        d_batch = torch.zeros(
            (batch_size, max_d, dim), dtype=dtype, device=compute_device
        )
        q_mask = torch.zeros(
            (batch_size, max_q), dtype=torch.bool, device=compute_device
        )
        d_mask = torch.zeros(
            (batch_size, max_d), dtype=torch.bool, device=compute_device
        )

        # copy to padded tensors
        for i, (q_emb, d_emb) in enumerate(zip(batch_q, batch_d)):
            q_len = int(q_emb.shape[0])
            d_len = int(d_emb.shape[0])
            q_batch[i, :q_len] = q_emb.to(device=compute_device, dtype=dtype)
            d_batch[i, :d_len] = d_emb.to(device=compute_device, dtype=dtype)
            q_mask[i, :q_len] = True
            d_mask[i, :d_len] = True

        token_scores = torch.bmm(q_batch, d_batch.transpose(1, 2))
        token_scores.masked_fill_(~d_mask.unsqueeze(1), float("-inf"))
        max_per_query = token_scores.amax(dim=-1)
        max_per_query.masked_fill_(~q_mask, 0)
        batch_scores = max_per_query.sum(dim=-1).to("cpu")
        scores.extend(batch_scores.unbind(0))
        start = end

    return [cast(torch.Tensor, score) for score in scores]
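A small self-contained check of the MaxSim semantics implemented above: for each (query, document) pair, every query token takes its best-matching document token score and the maxima are summed, so the new batched path should agree with the existing single-pair helper. The tensor shapes below are made up for illustration.

# Hypothetical consistency check; not part of this PR.
import torch

from vllm.entrypoints.pooling.score.utils import (
    compute_maxsim_score,
    compute_maxsim_scores,
)

q = torch.randn(4, 8)   # [query_len, dim]
d = torch.randn(6, 8)   # [doc_len, dim]

single = compute_maxsim_score(q, d)
batched = compute_maxsim_scores([q], [d])[0]

# Both paths should agree; the batched version pads then masks.
assert torch.allclose(single, batched, atol=1e-5)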


class ScoreMultiModalParam(TypedDict, total=False):
    """
    A specialized parameter type for scoring multimodal content
vllm/entrypoints/pooling/typing.py (new file, 51 lines)
@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TypeAlias

from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
    EmbeddingBytesResponse,
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
    PoolingChatRequest,
    PoolingCompletionRequest,
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
    ScoreRequest,
    ScoreResponse,
)

PoolingCompletionLikeRequest: TypeAlias = (
    EmbeddingCompletionRequest
    | ClassificationCompletionRequest
    | RerankRequest
    | ScoreRequest
    | PoolingCompletionRequest
)

PoolingChatLikeRequest: TypeAlias = (
    EmbeddingChatRequest | ClassificationChatRequest | PoolingChatRequest
)

AnyPoolingRequest: TypeAlias = (
    PoolingCompletionLikeRequest | PoolingChatLikeRequest | IOProcessorRequest
)

AnyPoolingResponse: TypeAlias = (
    ClassificationResponse
    | EmbeddingResponse
    | EmbeddingBytesResponse
    | PoolingResponse
    | ScoreResponse
)