55 lines
2.1 KiB
Python
55 lines
2.1 KiB
Python
|
|
|
|
from collections.abc import Mapping
|
|
from typing import Any, Optional, Union, cast
|
|
|
|
from typing_extensions import assert_never
|
|
|
|
from vllm.config import ModelConfig
|
|
from vllm.logger import init_logger
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
|
from vllm.multimodal.cache import BaseMultiModalProcessorCache
|
|
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
|
|
MultiModalInputs, MultiModalUUIDDict)
|
|
from vllm.multimodal.processing import BaseMultiModalProcessor
|
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
|
|
|
from .data import EmbedsInputs, EmbedsPrompt, embeds_inputs
|
|
|
|
# Module-scoped logger, following vllm's per-module logging convention.
logger = init_logger(__name__)
|
|
|
|
|
|
class InputPreprocessor:

    def _process_embeds(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
        """Validate and normalize user-supplied prompt embeddings.

        Args:
            parsed_content: Parsed prompt dict containing ``prompt_embeds``
                (a tensor of shape ``(seq_len, hidden_size)`` or
                ``(1, seq_len, hidden_size)``), optionally
                ``deepstack_input_embeds`` and ``cache_salt``.

        Returns:
            An :class:`EmbedsInputs` built via :func:`embeds_inputs`, with
            all tensors moved to CPU.

        Raises:
            ValueError: If prompt embeddings are disabled in the model
                config, or if ``prompt_embeds`` does not reduce to a 2-D
                ``(seq_len, hidden_size)`` tensor.
        """
        # Prompt embeddings are an opt-in feature; reject early with a
        # message that names the CLI flag that enables it.
        if not self.model_config.enable_prompt_embeds:
            raise ValueError("You must set `--enable-prompt-embeds` to input "
                             "`prompt_embeds`.")

        prompt_embeds = parsed_content["prompt_embeds"]
        deepstack_input_embeds = None
        if "deepstack_input_embeds" in parsed_content:
            deepstack_input_embeds = parsed_content["deepstack_input_embeds"]

        # prompt_embeds must be (seq_len, hidden_size), but if the user
        # passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
        # we can unambiguously process the intent by squeezing the batch
        # dimension. squeeze(dim=0) is a no-op when the leading dim != 1,
        # so a real batch falls through to the shape check below.
        if prompt_embeds.ndim == 3:
            prompt_embeds = prompt_embeds.squeeze(dim=0)

        if prompt_embeds.ndim != 2:
            raise ValueError(
                "prompt_embeds must be of shape (seq_len, hidden_size).")

        # Tensors must be on CPU for serialization between processes
        # in the MsgpackEncoder. Casting to CPU here ensures that there is no
        # hidden device transfer in the critical path of generation.
        prompt_embeds = prompt_embeds.cpu()
        # The same serialization constraint applies to the auxiliary
        # deepstack embeddings, which previously bypassed the CPU transfer.
        # NOTE(review): assumes deepstack_input_embeds is a torch.Tensor
        # like prompt_embeds — confirm against the EmbedsPrompt definition.
        if deepstack_input_embeds is not None:
            deepstack_input_embeds = deepstack_input_embeds.cpu()

        return embeds_inputs(prompt_embeds=prompt_embeds,
                             deepstack_input_embeds=deepstack_input_embeds,
                             cache_salt=parsed_content.get("cache_salt"))
|