from collections.abc import Mapping
from typing import Any, Optional, Union, cast

from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                                    MultiModalInputs, MultiModalUUIDDict)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer

from .data import EmbedsInputs, EmbedsPrompt, embeds_inputs

logger = init_logger(__name__)


class InputPreprocessor:

    def _process_embeds(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
        """Validate an embeddings prompt and convert it to ``EmbedsInputs``.

        Raises:
            ValueError: If prompt embeddings are disabled in the model
                config, or if the embeddings tensor is not 2-D (after an
                optional batch-of-one squeeze).
        """
        if not self.model_config.enable_prompt_embeds:
            raise ValueError("You must set `--enable-prompt-embeds` to input "
                             "`prompt_embeds`.")

        embeds = parsed_content["prompt_embeds"]
        # Optional key: ``None`` when the caller did not supply it.
        deepstack_embeds = parsed_content.get("deepstack_input_embeds")

        # The expected shape is (seq_len, hidden_size). A batch of size 1,
        # i.e. (1, seq_len, hidden_size), is unambiguous, so accept it by
        # dropping the leading batch dimension.
        if embeds.ndim == 3:
            embeds = embeds.squeeze(dim=0)
        if embeds.ndim != 2:
            raise ValueError(
                "prompt_embeds must be of shape (seq_len, hidden_size).")

        # Move to CPU now so MsgpackEncoder serialization between processes
        # never triggers a hidden device transfer on the generation hot path.
        embeds = embeds.cpu()

        return embeds_inputs(prompt_embeds=embeds,
                             deepstack_input_embeds=deepstack_embeds,
                             cache_salt=parsed_content.get("cache_salt"))