from collections.abc import Iterable
from typing import (TYPE_CHECKING, Any, Generic, Literal, Optional, Union,
                    cast)

import torch
from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar

if TYPE_CHECKING:
    from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalInputs,
                                        MultiModalUUIDDict)
    # Annotation-only import: `IntermediateTensors` is referenced solely in
    # type annotations below, so importing it lazily here avoids both a
    # runtime circular-import risk and an illegal statement inside a
    # TypedDict body (TypedDict class bodies may only contain field
    # annotations and a docstring).
    from vllm.sequence import IntermediateTensors


class EmbedsPrompt(TypedDict):
    """Schema for a prompt provided via token embeddings."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

    deepstack_input_embeds: Optional["IntermediateTensors"]
    """
    Optional deepstack input embeddings accompanying the prompt, or ``None``
    when not used.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """


class EmbedsInputs(TypedDict):
    """Represents embeddings-based inputs."""

    type: Literal["embeds"]
    """The type of inputs."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

    deepstack_input_embeds: torch.Tensor
    """Deepstack input embeddings accompanying the prompt."""

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """


def embeds_inputs(
    prompt_embeds: torch.Tensor,
    deepstack_input_embeds: torch.Tensor,
    cache_salt: Optional[str] = None,
) -> EmbedsInputs:
    """Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs].

    The ``cache_salt`` key is included only when a salt is provided, so that
    its absence (rather than a ``None`` value) signals "no salt" to
    downstream consumers of the ``NotRequired`` field.
    """
    inputs = EmbedsInputs(
        type="embeds",
        prompt_embeds=prompt_embeds,
        deepstack_input_embeds=deepstack_input_embeds,
    )
    if cache_salt is not None:
        inputs["cache_salt"] = cache_salt

    return inputs