bi_150-vllm/vllm/v1/engine/input_processor.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
import warnings
from collections.abc import Mapping
from typing import Any, Literal

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.inputs.data import (
    ProcessorInputs,
    PromptType,
    SingletonInputs,
)
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.encoder_budget import MultiModalBudget
from vllm.multimodal.inputs import (
    MultiModalFeatureSpec,
)
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer, renderer_from_config
from vllm.sampling_params import SamplingParams
from vllm.tasks import GENERATION_TASKS, POOLING_TASKS, SupportedTask
from vllm.tokenizers import TokenizerLike
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
from vllm.utils.func_utils import supports_kw
from vllm.utils.jsontree import json_iter_leaves
from vllm.v1.engine import EngineCoreRequest

logger = init_logger(__name__)


class InputProcessor:
    def __init__(
        self,
        vllm_config: VllmConfig,
        renderer: BaseRenderer | None = None,
        *,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        self.vllm_config = vllm_config
        self.model_config = model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.structured_outputs_config = vllm_config.structured_outputs_config
        self.observability_config = vllm_config.observability_config

        self.generation_config_fields = model_config.try_get_generation_config()

        self.renderer = renderer or renderer_from_config(vllm_config)
        self.supports_mm_inputs = mm_registry.supports_multimodal_inputs(model_config)

        self.mm_encoder_cache_size = 0
        self.skip_prompt_length_check = False
        if self.supports_mm_inputs:
            mm_budget = MultiModalBudget(vllm_config, mm_registry)

            self.mm_encoder_cache_size = mm_budget.encoder_cache_size
            self.skip_prompt_length_check = (
                mm_budget.processor.info.skip_prompt_length_check
            )

            mm_budget.reset_cache()  # Not used anymore

        self.input_preprocessor = InputPreprocessor(
            vllm_config,
            renderer=renderer,
            mm_registry=mm_registry,
        )

        from vllm.platforms import current_platform

        platform_validate_request = current_platform.validate_request
        if supports_kw(platform_validate_request, "prompt"):
            logger.warning_once(
                "The signature of Platform.validate_request has changed from "
                "`(cls, prompt, params, processed_inputs) -> None` to "
                "`(cls, processed_inputs, params) -> None`. The old signature "
                "will no longer be supported starting from v0.18."
            )

            orig_validate_request = platform_validate_request

            def compat_validate_request(
                processed_inputs: ProcessorInputs,
                params: SamplingParams | PoolingParams,
            ):
                return orig_validate_request(
                    processed_inputs,
                    params,
                    processed_inputs,  # type: ignore
                )  # type: ignore

            platform_validate_request = compat_validate_request

        self._platform_validate_request = platform_validate_request

    @property
    def tokenizer(self) -> TokenizerLike | None:
        return self.renderer.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        return self.renderer.get_tokenizer()

    def _validate_params(
        self,
        params: SamplingParams | PoolingParams,
        supported_tasks: tuple[SupportedTask, ...],
    ) -> None:
        """Raise `ValueError` if SamplingParams or PoolingParams is not valid."""
        if params.truncate_prompt_tokens is not None:
            params_type = type(params).__name__
            warnings.warn(
                f"The `truncate_prompt_tokens` parameter in `{params_type}` "
                "is deprecated and will be removed in v0.17. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(params, SamplingParams):
            supported_generation_tasks = [
                task for task in supported_tasks if task in GENERATION_TASKS
            ]
            if not supported_generation_tasks:
                raise ValueError("This model does not support generation")

            params.verify(
                self.model_config,
                self.speculative_config,
                self.structured_outputs_config,
                self.tokenizer,
            )
        elif isinstance(params, PoolingParams):
            supported_pooling_tasks = [
                task for task in supported_tasks if task in POOLING_TASKS
            ]
            if not supported_pooling_tasks:
                raise ValueError("This model does not support pooling")

            if params.task is None:
                if "token_embed" in supported_pooling_tasks:
                    params.task = "token_embed"
                elif "token_classify" in supported_pooling_tasks:
                    params.task = "token_classify"
                elif "plugin" in supported_pooling_tasks:
                    params.task = "plugin"

            if params.task not in supported_pooling_tasks:
                raise ValueError(
                    f"Unsupported task: {params.task!r} "
                    f"Supported tasks: {supported_pooling_tasks}"
                )

            params.verify(self.model_config)
        else:
            raise TypeError(
                f"params must be either SamplingParams or PoolingParams, "
                f"but got {type(params).__name__}"
            )
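
    # For illustration (hypothetical setup, not from the original source): for
    # a pooling model whose supported tasks are ("token_embed",
    # "token_classify"), a PoolingParams(task=None) is defaulted above to
    # task="token_embed", while any task outside that list raises ValueError.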

    def _validate_lora(self, lora_request: LoRARequest | None) -> None:
        if lora_request is None:
            return

        # LoRA request passed in while LoRA is not enabled
        if not self.lora_config:
            raise ValueError(
                f"Got lora_request {lora_request} but LoRA is not enabled!"
            )

        if self.tokenizer is not None:
            logger.warning_once(
                "vLLM has deprecated support for using different tokenizers "
                "for different LoRAs. By default, vLLM uses the base model's "
                "tokenizer. If you are using a LoRA with its own tokenizer, "
                "consider specifying `--tokenizer [lora_path]` to use the "
                "LoRA tokenizer."
            )

    def _get_mm_identifier(
        self,
        mm_hash: str,
        lora_request: LoRARequest | None,
    ) -> str:
        """
        When enable_tower_connector_lora is True, multi-modal embeddings
        vary depending on the LoRA request. Therefore, the mm_hash must be
        generated based on the LoRA request to prevent incorrect cache hits.
        """
        if (
            lora_request is None
            or self.lora_config is None
            or not self.lora_config.enable_tower_connector_lora
        ):
            return mm_hash

        return f"{lora_request.lora_name}:{mm_hash}"
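
    # For illustration (hypothetical values, not from the original source):
    # with enable_tower_connector_lora=True and a LoRA named "my_lora",
    # _get_mm_identifier("abc123", lora_request) returns "my_lora:abc123";
    # without a LoRA request, or with the flag disabled, it returns "abc123".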

    @staticmethod
    def assign_request_id(request: EngineCoreRequest):
        """Replace the externally supplied request ID with an internal request ID
        that adds 8 random characters in order to ensure uniqueness.
        """
        if request.external_req_id is not None:
            raise ValueError(
                "The external_req_id field should not be set on EngineCoreRequests"
                " passed to vLLM; use the request_id field."
            )

        request.external_req_id = request.request_id
        if envs.VLLM_DISABLE_REQUEST_ID_RANDOMIZATION:
            logger.warning_once(
                "VLLM_DISABLE_REQUEST_ID_RANDOMIZATION is set and will be "
                "removed in a future release. Duplicate externally-provided "
                "request IDs may cause failures and/or subtle correctness errors."
            )
        else:
            request.request_id = f"{request.external_req_id}-{random_uuid():.8}"
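
    # For illustration (hypothetical values, not from the original source): a
    # request arriving with request_id="req-42" keeps "req-42" in
    # external_req_id and gets an internal ID such as "req-42-1a2b3c4d" (the
    # first 8 characters of a random UUID appended), unless
    # VLLM_DISABLE_REQUEST_ID_RANDOMIZATION is set.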

    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType | ProcessorInputs,
        params: SamplingParams | PoolingParams,
        supported_tasks: tuple[SupportedTask, ...],
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        resumable: bool = False,
    ) -> EngineCoreRequest:
        self._validate_params(params, supported_tasks)
        self._validate_lora(lora_request)

        parallel_config = self.vllm_config.parallel_config
        dp_size = parallel_config.data_parallel_size
        dp_local_size = parallel_config.data_parallel_size_local
        num_ranks = dp_local_size if parallel_config.local_engines_only else dp_size
        if data_parallel_rank is not None and not (
            0 <= data_parallel_rank < num_ranks
        ):
            raise ValueError(
                f"data_parallel_rank {data_parallel_rank} "
                f"is out of range [0, {num_ranks})."
            )
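
        # For illustration (hypothetical values, not from the original source):
        # with data_parallel_size=4 and local_engines_only=False, the accepted
        # data_parallel_rank values are 0, 1, 2 and 3; passing 4 raises the
        # error above.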

        if isinstance(prompt, dict) and "type" in prompt:
            if tokenization_kwargs:
                logger.warning_once(
                    "Passing tokenization_kwargs to InputProcessor is deprecated "
                    "and will be removed in v0.18. You should instead pass "
                    "them to Renderer.render_cmpl() or Renderer.render_chat()."
                )

            if arrival_time is None:
                arrival_time = prompt.get("arrival_time", time.time())  # type: ignore[assignment]

            processed_inputs: ProcessorInputs = prompt  # type: ignore[assignment]
        else:
            logger.warning_once(
                "Passing raw prompts to InputProcessor is deprecated "
                "and will be removed in v0.18. You should instead pass "
                "the outputs of Renderer.render_cmpl() or Renderer.render_chat()."
            )

            if arrival_time is None:
                arrival_time = time.time()

            processed_inputs = self.input_preprocessor.preprocess(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
            )

        self._platform_validate_request(processed_inputs, params)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
        self._validate_model_inputs(encoder_inputs, decoder_inputs)

        # Mypy can be conservative for TypedDict unions; normalize access.
        if decoder_inputs["type"] == "embeds":
            prompt_token_ids = None
            prompt_embeds = decoder_inputs["prompt_embeds"]
        else:
            prompt_token_ids = decoder_inputs["prompt_token_ids"]
            prompt_embeds = None

        sampling_params = None
        pooling_params = None
        if isinstance(params, SamplingParams):
            # TODO: can we avoid cloning here in multiproc case?
            sampling_params = params.clone()
            # If max_tokens is unset, generate up to the max_model_len.
            if sampling_params.max_tokens is None:
                seq_len = length_from_prompt_token_ids_or_embeds(
                    prompt_token_ids, prompt_embeds
                )
                sampling_params.max_tokens = self.model_config.max_model_len - seq_len
            sampling_params.update_from_generation_config(
                self.generation_config_fields,
                self.renderer.get_eos_token_id(),
            )
            if self.tokenizer is not None:
                sampling_params.update_from_tokenizer(self.tokenizer)
        else:
            pooling_params = params.clone()
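
        # For illustration (hypothetical values, not from the original source):
        # with max_model_len=4096 and a 1000-token prompt, an unset max_tokens
        # defaults to 4096 - 1000 = 3096.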

        # Multimodal related.
        mm_features: list[MultiModalFeatureSpec] | None = None
        if decoder_inputs["type"] == "multimodal":
            decoder_mm_inputs = decoder_inputs["mm_kwargs"]
            decoder_mm_positions = decoder_inputs["mm_placeholders"]
            decoder_mm_hashes = decoder_inputs["mm_hashes"]
            if not all(
                isinstance(leaf, str) for leaf in json_iter_leaves(decoder_mm_hashes)
            ):
                raise ValueError(
                    f"mm_hashes must contain only strings, got: {decoder_mm_hashes}. "
                    "This is likely due to an incorrect custom implementation of "
                    "MultiModalProcessor.apply method."
                )

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's position
            # in the input sequence.
            sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)

            mm_features = []
            for modality, idx in sorted_mm_idxs:
                base_mm_hash = decoder_mm_hashes[modality][idx]
                mm_features.append(
                    MultiModalFeatureSpec(
                        data=decoder_mm_inputs[modality][idx],
                        modality=modality,
                        identifier=self._get_mm_identifier(
                            base_mm_hash,
                            lora_request,
                        ),
                        mm_position=decoder_mm_positions[modality][idx],
                        mm_hash=base_mm_hash,
                    )
                )
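
        # For illustration (hypothetical example, not from the original
        # source): a prompt with an image placeholder at offset 5 and an audio
        # placeholder at offset 40 yields mm_features ordered [image, audio],
        # following argsort_mm_positions.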

        return EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=prompt_token_ids,
            prompt_embeds=prompt_embeds,
            mm_features=mm_features,
            sampling_params=sampling_params,
            pooling_params=pooling_params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            cache_salt=decoder_inputs.get("cache_salt"),
            priority=priority,
            data_parallel_rank=data_parallel_rank,
            trace_headers=trace_headers,
            resumable=resumable,
        )

    def _validate_prompt_len(
        self,
        prompt_len: int,
        prompt_type: Literal["encoder", "decoder"],
    ):
        if self.skip_prompt_length_check:
            return

        if prompt_len == 0 and prompt_type == "decoder":
            raise ValueError(f"The {prompt_type} prompt cannot be empty")

        model_config = self.model_config
        max_prompt_len = (
            model_config.max_model_len
            if prompt_type == "decoder"
            else self.mm_encoder_cache_size
        )
        if prompt_len > max_prompt_len:
            if self.supports_mm_inputs:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well."
                )
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens."
                )

            raise ValueError(
                f"The {prompt_type} prompt (length {prompt_len}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}"
            )
        elif prompt_len == max_prompt_len and model_config.runner_type == "generate":
            suggestion = (
                "Make sure that `max_model_len` is no smaller than the "
                "number of text tokens (prompt + requested output tokens)."
            )

            raise ValueError(
                f"The {prompt_type} prompt (length {prompt_len}) plus the number of "
                f"requested output tokens (at least 1) is longer than the maximum "
                f"model length of {max_prompt_len}. {suggestion}"
            )
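
    # For illustration (hypothetical values, not from the original source): on
    # a "generate" model with max_model_len=4096, a 4097-token decoder prompt
    # fails the first check above, and a 4096-token prompt fails the second,
    # since at least one output token must still fit within the model length.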

    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        prompt_type: Literal["encoder", "decoder"],
    ) -> None:
        model_config = self.model_config
        tokenizer = self.tokenizer

        prompt_ids = (
            None
            if prompt_inputs["type"] == "embeds"
            else prompt_inputs["prompt_token_ids"]
        )
        prompt_embeds = (
            prompt_inputs["prompt_embeds"]
            if prompt_inputs["type"] == "embeds"
            else None
        )

        prompt_len = length_from_prompt_token_ids_or_embeds(prompt_ids, prompt_embeds)
        self._validate_prompt_len(prompt_len, prompt_type)

        if prompt_inputs["type"] == "multimodal":
            decoder_mm_positions = prompt_inputs["mm_placeholders"]
            for modality, mm_positions in decoder_mm_positions.items():
                for mm_position in mm_positions:
                    embed_length = mm_position.get_num_embeds()
                    if embed_length > self.mm_encoder_cache_size:
                        raise ValueError(
                            f"The {prompt_type} prompt contains a(n) {modality} item "
                            f"with length {embed_length}, which exceeds the "
                            f"pre-allocated encoder cache size "
                            f"{self.mm_encoder_cache_size}. Please reduce the input "
                            f"size or increase the encoder cache size "
                            f"by setting --limit-mm-per-prompt at startup."
                        )

        if prompt_ids and tokenizer is not None:
            max_input_id = max(prompt_ids, default=0)

            # NOTE: tokenizer.max_token_id is the tokenizer's vocab size while
            # self.model_config.get_vocab_size() is the model's vocab size.
            # For Qwen3 models, the language model has extra tokens that do
            # not exist in the tokenizer, and vice versa for multimodal
            # placeholder tokens in some multimodal models.
            # See https://github.com/QwenLM/Qwen3/issues/29#issuecomment-1933720399 # noqa: E501
            # and https://github.com/vllm-project/vllm/pull/22471#discussion_r2312251421 # noqa: E501
            # Here we take the max of the two to determine if a token id is
            # truly out-of-vocabulary.
            model_vocab_size = model_config.get_vocab_size()
            if max_input_id > max(tokenizer.max_token_id, model_vocab_size - 1):
                raise ValueError(f"Token id {max_input_id} is out of vocabulary")
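
    # For illustration (hypothetical values, not from the original source): if
    # the tokenizer's max_token_id is 151_669 and the model vocabulary size is
    # 151_936, a prompt token id of 160_000 exceeds both bounds and is rejected
    # as out of vocabulary.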

    def _validate_model_inputs(
        self,
        encoder_inputs: SingletonInputs | None,
        decoder_inputs: SingletonInputs,
    ):
        if encoder_inputs is not None:
            self._validate_model_input(encoder_inputs, prompt_type="encoder")

        self._validate_model_input(decoder_inputs, prompt_type="decoder")