Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
189
vllm/entrypoints/pooling/base/io_processor.py
Normal file
189
vllm/entrypoints/pooling/base/io_processor.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Final
|
||||
|
||||
from vllm import PoolingRequestOutput, PromptType
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatTemplateConfig,
|
||||
ChatTemplateContentFormatOption,
|
||||
ConversationMessage,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.serving import RendererChatRequest, RendererRequest
|
||||
from vllm.inputs import ProcessorInputs, SingletonPrompt
|
||||
from vllm.renderers import BaseRenderer, merge_kwargs
|
||||
from vllm.renderers.inputs import TokPrompt
|
||||
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers import ToolParser
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
|
||||
|
||||
class PoolingIOProcessor:
    """Input/output processing for pooling (embedding/encoding) endpoints.

    Converts incoming completion- or chat-style requests into tokenized
    engine prompts via the configured renderer, and post-processes the
    resulting ``PoolingRequestOutput`` batches.  Subclasses are expected to
    implement the ``pre_process_online`` / ``pre_process_offline`` hooks;
    the ``*_async`` variants simply delegate to the sync versions by default.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ):
        # Single-worker executor reserved for tokenizer work.
        # NOTE(review): nothing in this class visibly submits to it —
        # presumably used by subclasses; confirm before removing.
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

        self.model_config = model_config
        self.renderer = renderer

        # Chat-template settings are unpacked from the config object;
        # the content format is Final (never reassigned after init).
        self.chat_template = chat_template_config.chat_template
        self.chat_template_content_format: Final = (
            chat_template_config.chat_template_content_format
        )
        self.trust_request_chat_template = (
            chat_template_config.trust_request_chat_template
        )

    def pre_process_online(self, *args, **kwargs):
        """Hook: turn an online (HTTP) request into engine prompts.

        Subclasses must override; the base class provides no behavior.
        """
        raise NotImplementedError

    async def pre_process_online_async(self, *args, **kwargs):
        """Async wrapper; by default delegates to ``pre_process_online``."""
        return self.pre_process_online(*args, **kwargs)

    def pre_process_offline(self, *args, **kwargs):
        """Hook: turn an offline (in-process) request into engine prompts.

        Subclasses must override; the base class provides no behavior.
        """
        raise NotImplementedError

    async def pre_process_offline_async(self, *args, **kwargs):
        """Async wrapper; by default delegates to ``pre_process_offline``."""
        return self.pre_process_offline(*args, **kwargs)

    def post_process(
        self, outputs: list[PoolingRequestOutput]
    ) -> list[PoolingRequestOutput]:
        """Transform engine outputs before response building (identity here)."""
        return outputs

    async def post_process_async(
        self, outputs: list[PoolingRequestOutput]
    ) -> list[PoolingRequestOutput]:
        """Async wrapper; by default delegates to ``post_process``."""
        return self.post_process(outputs)

    def create_pooling_params(self, request):
        """Build the ``PoolingParams`` for a request via its own converter."""
        return request.to_pooling_params()

    def _preprocess_completion_online(
        self,
        request: RendererRequest,
        prompt_input: str | list[str] | list[int] | list[list[int]] | None,
        prompt_embeds: bytes | list[bytes] | None,
    ) -> list[TokPrompt]:
        """Render completion-style online inputs into tokenized prompts.

        ``prompt_embeds`` (raw embedding bytes) and ``prompt_input``
        (text / token ids) may both be supplied; embeds are placed first
        in the batch.  Byte prompts bypass ``parse_model_prompt`` and are
        handed to the renderer as-is.
        """
        renderer = self.renderer
        model_config = self.model_config

        prompts = list[SingletonPrompt | bytes]()
        if prompt_embeds is not None:  # embeds take higher priority
            prompts.extend(prompt_to_seq(prompt_embeds))
        if prompt_input is not None:
            prompts.extend(prompt_to_seq(prompt_input))

        parsed_prompts = [
            (
                prompt
                if isinstance(prompt, bytes)
                else parse_model_prompt(model_config, prompt)
            )
            for prompt in prompts
        ]
        tok_params = request.build_tok_params(model_config)

        return renderer.render_cmpl(
            parsed_prompts,
            tok_params,
            # Forward only the optional per-request extras that are set.
            prompt_extras={
                k: v
                for k in ("mm_processor_kwargs", "cache_salt")
                if (v := getattr(request, k, None)) is not None
            },
        )

    def _preprocess_chat_online(
        self,
        request: RendererChatRequest,
        messages: list[ChatCompletionMessageParam],
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
        default_template_kwargs: dict[str, Any] | None,
        tool_dicts: list[dict[str, Any]] | None = None,
        # NOTE(review): tool_parser is accepted but not used anywhere in
        # this method — kept for interface compatibility; confirm callers.
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
    ) -> tuple[list[ConversationMessage], list[TokPrompt]]:
        """Render a single chat conversation into one tokenized prompt.

        Returns the resolved conversation messages together with the
        (single-element) list of engine prompts produced by the renderer.
        """
        renderer = self.renderer

        # Merge caller-supplied template kwargs with tool definitions;
        # mistral tokenizers tokenize inside the template itself.
        default_template_kwargs = merge_kwargs(
            default_template_kwargs,
            dict(
                tools=tool_dicts,
                tokenize=is_mistral_tokenizer(renderer.tokenizer),
            ),
        )

        tok_params = request.build_tok_params(self.model_config)
        chat_params = request.build_chat_params(
            default_template, default_template_content_format
        ).with_defaults(default_template_kwargs)

        # Exactly one conversation is rendered, hence the 1-tuple unpacking.
        (conversation,), (engine_prompt,) = renderer.render_chat(
            [messages],
            chat_params,
            tok_params,
            prompt_extras={
                k: v
                for k in ("mm_processor_kwargs", "cache_salt")
                if (v := getattr(request, k, None)) is not None
            },
        )

        return conversation, [engine_prompt]

    def _preprocess_completion_offline(
        self,
        prompts: PromptType | Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        """Render offline prompts using the renderer's default completion
        tokenization params, optionally extended with ``tokenization_kwargs``.
        """
        renderer = self.renderer
        model_config = self.model_config

        prompts = prompt_to_seq(prompts)

        # Byte prompts (pre-computed embeds) skip model-prompt parsing.
        parsed_prompts = [
            (
                prompt
                if isinstance(prompt, bytes)
                else parse_model_prompt(model_config, prompt)
            )
            for prompt in prompts
        ]
        tok_params = renderer.default_cmpl_tok_params.with_kwargs(
            **(tokenization_kwargs or {})
        )

        return renderer.render_cmpl(
            parsed_prompts,
            tok_params,
        )

    def _validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ):
        """Reject requests that carry their own chat template unless the
        server was started with ``--trust-request-chat-template``.

        Raises:
            ValueError: if a template arrives via the request (directly or
                inside ``chat_template_kwargs``) while trust is disabled.
        """
        if not trust_request_chat_template and (
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
            raise ValueError(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
                "Refused request with untrusted chat template."
            )
        return None
|
||||
@@ -190,10 +190,6 @@ class EmbedRequestMixin(EncodingRequestMixin):
|
||||
description="Whether to use activation for the pooler outputs. "
|
||||
"`None` uses the pooler's default, which is `True` in most cases.",
|
||||
)
|
||||
normalize: bool | None = Field(
|
||||
default=None,
|
||||
description="Deprecated; please pass `use_activation` instead",
|
||||
)
|
||||
# --8<-- [end:embed-extra-params]
|
||||
|
||||
|
||||
|
||||
378
vllm/entrypoints/pooling/base/serving.py
Normal file
378
vllm/entrypoints/pooling/base/serving.py
Normal file
@@ -0,0 +1,378 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from http import HTTPStatus
|
||||
from typing import ClassVar, Generic, TypeVar
|
||||
|
||||
from fastapi import Request
|
||||
from pydantic import ConfigDict
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from vllm import (
|
||||
PoolingParams,
|
||||
PoolingRequestOutput,
|
||||
PromptType,
|
||||
SamplingParams,
|
||||
envs,
|
||||
)
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatTemplateConfig,
|
||||
ChatTemplateContentFormatOption,
|
||||
)
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.pooling.typing import AnyPoolingRequest, AnyPoolingResponse
|
||||
from vllm.inputs import ProcessorInputs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.renderers.inputs.preprocess import extract_prompt_components
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.tracing import (
|
||||
contains_trace_headers,
|
||||
extract_trace_headers,
|
||||
log_tracing_disabled_warning,
|
||||
)
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
|
||||
from ...utils import create_error_response
|
||||
from .io_processor import PoolingIOProcessor
|
||||
|
||||
# Type variable for the concrete pooling request carried by a context.
PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest)


@dataclass(kw_only=True)
class PoolingServeContext(Generic[PoolingRequestT]):
    """Mutable per-request state threaded through the serving pipeline.

    Created once per request in ``PoolingServing.__call__`` and filled in
    by the successive pipeline stages (preprocess, generator setup, batch
    collection).  ``kw_only=True`` lets required fields (``model_name``,
    ``request_id``) follow defaulted ones without ordering errors.
    """

    request: PoolingRequestT
    raw_request: Request | None = None  # original FastAPI request, if any
    model_name: str
    request_id: str
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None
    engine_prompts: list[ProcessorInputs] | None = None  # set by _preprocess

    # Merged async iterator over (prompt index, output); set by
    # _prepare_generators and consumed by _collect_batch.
    result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
        None
    )
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    # NOTE(review): pydantic ConfigDict on a stdlib dataclass — nothing in
    # this file reads it; presumably consumed by an external validator.
    # Confirm it is actually needed before relying on it.
    model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class PoolingServing:
    """Base handler for OpenAI-compatible pooling endpoints.

    Drives the full request pipeline in ``__call__``: model check ->
    validation -> LoRA adapter resolution -> preprocessing -> engine
    encode fan-out -> batch collection -> response build.  Subclasses
    must set ``request_id_prefix`` and implement ``init_io_processor``
    and ``_build_response``.
    """

    # Prefix used when composing per-request ids; set by subclasses.
    request_id_prefix: ClassVar[str]

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        trust_request_chat_template: bool = False,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        super().__init__()
        self.engine_client = engine_client
        self.models = models
        self.model_config = models.model_config
        self.max_model_len = self.model_config.max_model_len
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.log_error_stack = log_error_stack
        # Bundle the chat-template knobs and hand them to the IO processor.
        self.chat_template_config = ChatTemplateConfig(
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            trust_request_chat_template=trust_request_chat_template,
        )
        self.io_processor = self.init_io_processor(
            model_config=models.model_config,
            renderer=models.renderer,
            chat_template_config=self.chat_template_config,
        )

    def init_io_processor(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ) -> PoolingIOProcessor:
        """Hook: build the endpoint-specific IO processor. Must override."""
        raise NotImplementedError

    async def __call__(
        self,
        request: AnyPoolingRequest,
        raw_request: Request,
    ) -> JSONResponse:
        """Handle one pooling request end-to-end.

        Any exception raised by a pipeline stage is converted into an
        error JSON response at this boundary (broad catch is intentional).
        """
        try:
            model_name = self.models.model_name()
            request_id = (
                f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
            )

            await self._check_model(request)

            ctx = PoolingServeContext(
                request=request,
                raw_request=raw_request,
                model_name=model_name,
                request_id=request_id,
            )

            # Pipeline stages; each mutates ctx in place.
            self._validate_request(ctx)
            self._maybe_get_adapters(ctx)
            await self._preprocess(ctx)
            await self._prepare_generators(ctx)
            await self._collect_batch(ctx)
            response = await self._build_response(ctx)
            return JSONResponse(content=response.model_dump())
        except Exception as e:
            error_response = create_error_response(e)
            return JSONResponse(
                content=error_response.model_dump(),
                status_code=error_response.error.code,
            )

    async def _preprocess(
        self,
        ctx: PoolingServeContext,
    ):
        """Populate ``ctx.engine_prompts`` via the IO processor."""
        ctx.engine_prompts = await self.io_processor.pre_process_online_async(
            ctx.request
        )

    async def _prepare_generators(
        self,
        ctx: PoolingServeContext,
    ):
        """Start one ``encode`` stream per engine prompt and merge them.

        Raises:
            ValueError: if ``_preprocess`` has not populated the prompts.
        """
        if ctx.engine_prompts is None:
            raise ValueError("Engine prompts not available")

        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

        trace_headers = (
            None
            if ctx.raw_request is None
            else await self._get_trace_headers(ctx.raw_request.headers)
        )

        pooling_params = self.io_processor.create_pooling_params(ctx.request)

        for i, engine_prompt in enumerate(ctx.engine_prompts):
            # Each prompt in the batch gets its own sub-request id.
            request_id_item = f"{ctx.request_id}-{i}"

            self._log_inputs(
                request_id_item,
                engine_prompt,
                params=pooling_params,
                lora_request=ctx.lora_request,
            )

            generator = self.engine_client.encode(
                engine_prompt,
                pooling_params,
                request_id_item,
                lora_request=ctx.lora_request,
                trace_headers=trace_headers,
                # Not every request type carries a priority field.
                priority=getattr(ctx.request, "priority", 0),
            )

            generators.append(generator)

        # Merged iterator yields (original prompt index, output).
        ctx.result_generator = merge_async_iterators(*generators)

    async def _collect_batch(
        self,
        ctx: PoolingServeContext,
    ):
        """Drain the merged generator into ``ctx.final_res_batch``,
        preserving the original prompt order via the yielded index.

        Raises:
            ValueError: if prerequisites are missing or any prompt
                produced no result.
        """
        if ctx.engine_prompts is None:
            raise ValueError("Engine prompts not available")

        if ctx.result_generator is None:
            raise ValueError("Result generator not available")

        num_prompts = len(ctx.engine_prompts)
        final_res_batch: list[PoolingRequestOutput | None]
        final_res_batch = [None] * num_prompts

        async for i, res in ctx.result_generator:
            final_res_batch[i] = res

        if None in final_res_batch:
            raise ValueError("Failed to generate results for all prompts")

        # The filter is only to narrow the type; all entries are non-None here.
        ctx.final_res_batch = [res for res in final_res_batch if res is not None]

    async def _build_response(
        self,
        ctx: PoolingServeContext,
    ) -> AnyPoolingResponse:
        """Hook: turn the collected batch into a response. Must override."""
        raise NotImplementedError

    @staticmethod
    def _base_request_id(
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
        """Pulls the request id to use from a header, if provided"""
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id

        # No header: fall back to the supplied default or a fresh uuid.
        return random_uuid() if default is None else default

    def _is_model_supported(self, model_name: str | None) -> bool:
        """True for empty/None names (treated as the base model) or a
        name matching the served base model."""
        if not model_name:
            return True
        return self.models.is_base_model(model_name)

    async def _check_model(
        self,
        request: AnyPoolingRequest,
    ) -> ErrorResponse | None:
        """Verify the requested model is servable, resolving LoRA adapters
        at runtime when ``VLLM_ALLOW_RUNTIME_LORA_UPDATING`` is set.

        NOTE(review): annotated as returning ``ErrorResponse | None`` but
        every path returns None or raises — confirm whether the annotation
        is a leftover from an earlier error-returning design.

        Raises:
            ValueError: when runtime LoRA resolution fails with a
                BAD_REQUEST error.
        """
        if self._is_model_supported(request.model):
            return None
        if request.model in self.models.lora_requests:
            return None
        if (
            envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
            and request.model
            and (load_result := await self.models.resolve_lora(request.model))
        ):
            if isinstance(load_result, LoRARequest):
                return None
            if (
                isinstance(load_result, ErrorResponse)
                and load_result.error.code == HTTPStatus.BAD_REQUEST.value
            ):
                raise ValueError(load_result.error.message)
        return None

    def _validate_request(self, ctx: PoolingServeContext) -> None:
        """Reject truncation sizes larger than the model context window.

        Raises:
            ValueError: if ``truncate_prompt_tokens`` exceeds
                ``max_model_len``.
        """
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)

        if (
            truncate_prompt_tokens is not None
            and truncate_prompt_tokens > self.max_model_len
        ):
            raise ValueError(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
                " Please, select a smaller truncation size."
            )
        return None

    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Extract trace headers when engine tracing is enabled; otherwise
        warn once if trace headers were sent while tracing is off."""
        is_tracing_enabled = await self.engine_client.is_tracing_enabled()

        if is_tracing_enabled:
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

    def _maybe_get_adapters(
        self,
        ctx: PoolingServeContext,
        supports_default_mm_loras: bool = False,
    ):
        """Resolve ``ctx.lora_request`` from the request's model name, with
        an optional default multimodal-LoRA override.

        Raises:
            ValueError: if the model is neither the base model nor a
                known LoRA adapter.
        """
        request = ctx.request
        if request.model in self.models.lora_requests:
            ctx.lora_request = self.models.lora_requests[request.model]

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
                ctx.lora_request = default_mm_lora

        if self._is_model_supported(request.model):
            return None

        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _get_active_default_mm_loras(
        self, request: AnyPoolingRequest
    ) -> LoRARequest | None:
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

    def _get_message_types(self, request: AnyPoolingRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
            if (
                isinstance(message, dict)
                and "content" in message
                and isinstance(message["content"], list)
            ):
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        # "audio_url" -> "audio", etc.
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

    def _log_inputs(
        self,
        request_id: str,
        inputs: PromptType | ProcessorInputs,
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
    ) -> None:
        """Log the decomposed prompt (text/token ids/embeds) when a
        request logger is configured; no-op otherwise."""
        if self.request_logger is None:
            return

        components = extract_prompt_components(self.model_config, inputs)

        self.request_logger.log_inputs(
            request_id,
            components.text,
            components.token_ids,
            components.embeds,
            params=params,
            lora_request=lora_request,
        )
|
||||
Reference in New Issue
Block a user