# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
from collections.abc import AsyncGenerator, Mapping
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import ClassVar, Generic, TypeVar

from fastapi import Request
from pydantic import ConfigDict
from starlette.datastructures import Headers
from starlette.responses import JSONResponse

from vllm import (
    PoolingParams,
    PoolingRequestOutput,
    PromptType,
    SamplingParams,
    envs,
)
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
    ChatTemplateConfig,
    ChatTemplateContentFormatOption,
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.typing import AnyPoolingRequest, AnyPoolingResponse
from vllm.inputs import ProcessorInputs
from vllm.lora.request import LoRARequest
from vllm.renderers import BaseRenderer
from vllm.renderers.inputs.preprocess import extract_prompt_components
from vllm.sampling_params import BeamSearchParams
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
from vllm.utils import random_uuid
from vllm.utils.async_utils import merge_async_iterators

from ...utils import create_error_response
from .io_processor import PoolingIOProcessor

PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest)


@dataclass(kw_only=True)
class PoolingServeContext(Generic[PoolingRequestT]):
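    """Mutable per-request state threaded through the pooling pipeline.

    Each stage fills in the fields the next one reads: ``engine_prompts``
    after preprocessing, ``result_generator`` once the encode calls are
    started, and ``final_res_batch`` after collection.
    """
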
    request: PoolingRequestT
    raw_request: Request | None = None
    model_name: str
    request_id: str
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None
    engine_prompts: list[ProcessorInputs] | None = None

    result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
        None
    )
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    model_config = ConfigDict(arbitrary_types_allowed=True)


class PoolingServing:
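    """Base class for OpenAI-style pooling endpoints.

    Subclasses set ``request_id_prefix`` and override ``init_io_processor``
    and ``_build_response``; the shared ``__call__`` pipeline handles model
    checks, validation, LoRA resolution, preprocessing, generation, and
    batch collection.
    """
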
    request_id_prefix: ClassVar[str]

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        trust_request_chat_template: bool = False,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        super().__init__()
        self.engine_client = engine_client
        self.models = models
        self.model_config = models.model_config
        self.max_model_len = self.model_config.max_model_len
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.log_error_stack = log_error_stack
        self.chat_template_config = ChatTemplateConfig(
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            trust_request_chat_template=trust_request_chat_template,
        )
        self.io_processor = self.init_io_processor(
            model_config=models.model_config,
            renderer=models.renderer,
            chat_template_config=self.chat_template_config,
        )

    def init_io_processor(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ) -> PoolingIOProcessor:
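        """Build the IO processor that adapts requests and outputs for this
        pooling task; must be overridden by subclasses."""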
        raise NotImplementedError

    async def __call__(
        self,
        request: AnyPoolingRequest,
        raw_request: Request,
    ) -> JSONResponse:
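        """Serve one pooling request end to end.

        Checks the model, builds the serve context, then runs the pipeline:
        validate, resolve adapters, preprocess, fan out per-prompt encode
        calls, collect the batch, and build the task-specific response. Any
        exception is converted into a JSON error response.
        """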
        try:
            model_name = self.models.model_name()
            request_id = (
                f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
            )

            await self._check_model(request)

            ctx = PoolingServeContext(
                request=request,
                raw_request=raw_request,
                model_name=model_name,
                request_id=request_id,
            )

            self._validate_request(ctx)
            self._maybe_get_adapters(ctx)
            await self._preprocess(ctx)
            await self._prepare_generators(ctx)
            await self._collect_batch(ctx)
            response = await self._build_response(ctx)
            return JSONResponse(content=response.model_dump())
        except Exception as e:
            error_response = create_error_response(e)
            return JSONResponse(
                content=error_response.model_dump(),
                status_code=error_response.error.code,
            )

    async def _preprocess(
        self,
        ctx: PoolingServeContext,
    ):
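        """Render the request into engine prompts via the IO processor."""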
        ctx.engine_prompts = await self.io_processor.pre_process_online_async(
            ctx.request
        )

    async def _prepare_generators(
        self,
        ctx: PoolingServeContext,
    ):
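        """Start one encode call per engine prompt and merge the outputs
        into a single async stream of ``(prompt_index, output)`` pairs."""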
        if ctx.engine_prompts is None:
            raise ValueError("Engine prompts not available")

        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

        trace_headers = (
            None
            if ctx.raw_request is None
            else await self._get_trace_headers(ctx.raw_request.headers)
        )

        pooling_params = self.io_processor.create_pooling_params(ctx.request)

        for i, engine_prompt in enumerate(ctx.engine_prompts):
            request_id_item = f"{ctx.request_id}-{i}"

            self._log_inputs(
                request_id_item,
                engine_prompt,
                params=pooling_params,
                lora_request=ctx.lora_request,
            )

            generator = self.engine_client.encode(
                engine_prompt,
                pooling_params,
                request_id_item,
                lora_request=ctx.lora_request,
                trace_headers=trace_headers,
                priority=getattr(ctx.request, "priority", 0),
            )

            generators.append(generator)

        ctx.result_generator = merge_async_iterators(*generators)

    async def _collect_batch(
        self,
        ctx: PoolingServeContext,
    ):
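        """Drain the merged result stream into ``ctx.final_res_batch``,
        restoring the original prompt order by index."""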
        if ctx.engine_prompts is None:
            raise ValueError("Engine prompts not available")

        if ctx.result_generator is None:
            raise ValueError("Result generator not available")

        num_prompts = len(ctx.engine_prompts)
        final_res_batch: list[PoolingRequestOutput | None] = [None] * num_prompts

        async for i, res in ctx.result_generator:
            final_res_batch[i] = res

        if None in final_res_batch:
            raise ValueError("Failed to generate results for all prompts")

        ctx.final_res_batch = [res for res in final_res_batch if res is not None]

    async def _build_response(
        self,
        ctx: PoolingServeContext,
    ) -> AnyPoolingResponse:
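        """Convert ``ctx.final_res_batch`` into the task-specific response;
        must be overridden by subclasses."""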
        raise NotImplementedError

    @staticmethod
    def _base_request_id(
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
        """Pull the request ID from the `X-Request-Id` header, if provided;
        otherwise fall back to `default` or a fresh random UUID."""
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id

        return random_uuid() if default is None else default

    def _is_model_supported(self, model_name: str | None) -> bool:
        if not model_name:
            return True
        return self.models.is_base_model(model_name)

    async def _check_model(
        self,
        request: AnyPoolingRequest,
    ) -> ErrorResponse | None:
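        """Accept base-model and known-LoRA requests; when
        VLLM_ALLOW_RUNTIME_LORA_UPDATING is set, also try to resolve unknown
        adapters at runtime, raising on a clearly invalid adapter."""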
        if self._is_model_supported(request.model):
            return None
        if request.model in self.models.lora_requests:
            return None
        if (
            envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
            and request.model
            and (load_result := await self.models.resolve_lora(request.model))
        ):
            if isinstance(load_result, LoRARequest):
                return None
            if (
                isinstance(load_result, ErrorResponse)
                and load_result.error.code == HTTPStatus.BAD_REQUEST.value
            ):
                raise ValueError(load_result.error.message)
        return None

    def _validate_request(self, ctx: PoolingServeContext) -> None:
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)

        if (
            truncate_prompt_tokens is not None
            and truncate_prompt_tokens > self.max_model_len
        ):
            raise ValueError(
                "truncate_prompt_tokens value is greater than "
                "max_model_len. Please select a smaller truncation size."
            )

    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        is_tracing_enabled = await self.engine_client.is_tracing_enabled()

        if is_tracing_enabled:
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

    def _maybe_get_adapters(
        self,
        ctx: PoolingServeContext,
        supports_default_mm_loras: bool = False,
    ):
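        """Attach the LoRA adapter named by the request, falling back to a
        default multimodal adapter when exactly one matches."""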
        request = ctx.request
        if request.model in self.models.lora_requests:
            ctx.lora_request = self.models.lora_requests[request.model]

        # Currently we only support default modality-specific LoRAs
        # when exactly one LoRA matches the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
                ctx.lora_request = default_mm_lora

        if self._is_model_supported(request.model):
            return None

        # If _check_model was called earlier, this point is unreachable.
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _get_active_default_mm_loras(
        self, request: AnyPoolingRequest
    ) -> LoRARequest | None:
        """Determine if there are any active default multimodal LoRAs."""
        # TODO: Currently this is only enabled for chat completions, to stay
        # aligned with offline mode, where it is only enabled for .generate.
        # It would be nice to support additional task types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best-effort match for default multimodal LoRA adapters;
            # there is probably a better way to do this, but currently
            # we match the LoRA name against the set of 'types' in any
            # content lists, truncated at '_' (e.g., audio_url -> audio).
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently we only support default modality-specific LoRAs if
        # exactly one LoRA matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

    def _get_message_types(self, request: AnyPoolingRequest) -> set[str]:
        """Retrieve the set of types in message content dicts, truncated
        at the first `_`; we use this to match potential multimodal data
        with default per-modality LoRAs.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
            if (
                isinstance(message, dict)
                and "content" in message
                and isinstance(message["content"], list)
            ):
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

    def _log_inputs(
        self,
        request_id: str,
        inputs: PromptType | ProcessorInputs,
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
    ) -> None:
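        """Log the extracted prompt components for one engine request, if a
        request logger is configured."""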
        if self.request_logger is None:
            return

        components = extract_prompt_components(self.model_config, inputs)

        self.request_logger.log_inputs(
            request_id,
            components.text,
            components.token_ids,
            components.embeds,
            params=params,
            lora_request=lora_request,
        )
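
# The two NotImplementedError hooks above are the extension points. A minimal
# subclass sketch (illustrative only; `MyIOProcessor`, `MyPoolingResponse`,
# and `from_final_res_batch` are hypothetical names, not part of this module):
#
#     class MyPoolingServing(PoolingServing):
#         request_id_prefix = "pool"
#
#         def init_io_processor(self, model_config, renderer, chat_template_config):
#             # Return a PoolingIOProcessor for this task.
#             return MyIOProcessor(model_config, renderer, chat_template_config)
#
#         async def _build_response(self, ctx) -> MyPoolingResponse:
#             # Serialize the collected PoolingRequestOutput batch.
#             return MyPoolingResponse.from_final_res_batch(ctx.final_res_batch)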