import time
from collections.abc import Mapping
from typing import Any, Literal, Optional, Union

import torch

from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.structured_output.backend_guidance import (
    validate_guidance_grammar)
from vllm.v1.structured_output.backend_lm_format_enforcer import (
    validate_structured_output_request_lm_format_enforcer)
from vllm.v1.structured_output.backend_outlines import (
    validate_structured_output_request_outlines)
from vllm.v1.structured_output.backend_xgrammar import (
    validate_xgrammar_grammar)

logger = init_logger(__name__)


class Processor:

    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> tuple[Optional[str], EngineCoreRequest]:

        # TODO(woosuk): Support pooling models.
        self._validate_lora(lora_request)
        self._validate_params(params)

        data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if data_parallel_rank is not None and not (0 <= data_parallel_rank <
                                                   data_parallel_size):
            raise ValueError(f"data_parallel_rank {data_parallel_rank} "
                             f"is out of range [0, {data_parallel_size}).")

        if arrival_time is None:
            arrival_time = time.time()

        # Optionally generate multimodal hash overrides to avoid hashing
        # multimodal data items by their content as their identifiers.
        # NOTE: when users explicitly turn off BOTH prefix caching and input
        # processing caching, no multimodal features or embeddings will be
        # reused across requests, therefore identifying multimodal data items
        # by their content is no longer necessary, and we create uuids with
        # request id-modality-index as multimodal hash overrides.
        if (self.model_config.multimodal_config and
                self.model_config.multimodal_config.mm_processor_cache_gb == 0
                and not self.cache_config.enable_prefix_caching):
            mm_uuids = self._maybe_build_mm_uuids(request_id, prompt)
        else:
            # Otherwise, use user-provided uuids as multimodal hash overrides
            # if provided.
            self._validate_multi_modal_uuids(prompt)
            if isinstance(prompt, dict):
                mm_uuids = prompt.get("multi_modal_uuids")
            else:
                mm_uuids = None

        # Process inputs, which includes:
        # 1. Tokenize text prompt, with LoRA request if one exists.
        # 2. For multimodal models with a merged preprocessor, preprocess
        #    multimodal data and expand prompt token ids accordingly.
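        # Illustrative note (assumption, not from the original code): the
        # preprocessed inputs are discriminated TypedDicts. A text-only
        # prompt comes back roughly as
        #     {"type": "token", "prompt_token_ids": [...], "prompt": "..."}
        # while multimodal prompts additionally carry "mm_kwargs",
        # "mm_hashes" and "mm_placeholders", which are consumed further
        # below.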
        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )
        from vllm.platforms import current_platform
        current_platform.validate_request(
            prompt=prompt,
            params=params,
            processed_inputs=processed_inputs,
        )
        eos_token_id = self.input_preprocessor.get_eos_token_id()

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        self._validate_model_inputs(encoder_inputs, decoder_inputs)

        # Mypy does not always properly infer the types of some elements of
        # discriminated unions of TypedDicts, because of how it handles
        # inheritance of TypedDict. If we explicitly extract the items we want
        # we can avoid type errors from using `dict.get` later in the method.
        prompt_str: Optional[str] = None if decoder_inputs[
            "type"] == "embeds" else decoder_inputs.get("prompt")
        prompt_token_ids = decoder_inputs[
            "prompt_token_ids"] if decoder_inputs["type"] != "embeds" else None
        prompt_embeds = decoder_inputs["prompt_embeds"] if decoder_inputs[
            "type"] == "embeds" else None
        deepstack_input_embeds = decoder_inputs[
            "deepstack_input_embeds"] if decoder_inputs[
                "type"] == "embeds" else None

        # Deepstack embeddings passed through llm.generate may arrive as a
        # dict of per-layer tensors; stack them into a single tensor along a
        # new leading dimension.
        if isinstance(deepstack_input_embeds, dict):
            all_tensors = []
            for key in deepstack_input_embeds:
                if isinstance(deepstack_input_embeds[key], torch.Tensor):
                    all_tensors.append(
                        deepstack_input_embeds[key].unsqueeze(0))
            if len(all_tensors) > 0:
                deepstack_input_embeds = torch.concatenate(all_tensors, 0)

        sampling_params = None
        pooling_params = None
        if isinstance(params, SamplingParams):
            # TODO: can we avoid cloning here in multiproc case?
            sampling_params = params.clone()
            # If max_tokens is unset, generate up to the max_model_len.
            if sampling_params.max_tokens is None:
                seq_len = length_from_prompt_token_ids_or_embeds(
                    prompt_token_ids, prompt_embeds)
                sampling_params.max_tokens = \
                    self.model_config.max_model_len - seq_len
            sampling_params.update_from_generation_config(
                self.generation_config_fields, eos_token_id)
            if self.tokenizer is not None:
                sampling_params.update_from_tokenizer(self.tokenizer)
        else:
            pooling_params = params.clone()

        # Multimodal related.
        mm_features: Optional[list[MultiModalFeatureSpec]] = None
        if decoder_inputs["type"] == "multimodal":
            decoder_mm_inputs = decoder_inputs["mm_kwargs"]
            decoder_mm_positions = decoder_inputs["mm_placeholders"]
            decoder_mm_hashes = decoder_inputs["mm_hashes"]

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's
            # position in the input sequence.
            sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)

            mm_features = []
            for modality, idx in sorted_mm_idxs:
                mm_features.append(
                    MultiModalFeatureSpec(
                        data=decoder_mm_inputs[modality][idx],
                        modality=modality,
                        identifier=decoder_mm_hashes[modality][idx],
                        mm_position=decoder_mm_positions[modality][idx]))

        return prompt_str, EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=prompt_token_ids,
            prompt_embeds=prompt_embeds,
            deepstack_input_embeds=deepstack_input_embeds,
            mm_features=mm_features,
            sampling_params=sampling_params,
            pooling_params=pooling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
            cache_salt=decoder_inputs.get("cache_salt"),
            priority=priority,
            data_parallel_rank=data_parallel_rank,
            trace_headers=trace_headers,
        )
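

# ---------------------------------------------------------------------------
# Illustrative usage sketch (assumption, not part of the original module):
# shows how a frontend might drive ``process_inputs``. The ``processor``
# argument is assumed to be constructed elsewhere by the engine; the request
# id and sampling parameters below are example values only.
def _example_submit(processor: Processor,
                    prompt: PromptType) -> EngineCoreRequest:
    import uuid  # local import to keep the sketch self-contained

    params = SamplingParams(max_tokens=64, temperature=0.0)
    prompt_str, core_request = processor.process_inputs(
        request_id=str(uuid.uuid4()),
        prompt=prompt,
        params=params,
    )
    # ``core_request`` is what gets enqueued to the engine core;
    # ``prompt_str`` is the prompt text (None for embeds-only inputs).
    return core_request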