from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast

import torch
from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar

if TYPE_CHECKING:
    from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalInputs,
                                        MultiModalUUIDDict)
    from vllm.sequence import IntermediateTensors
class EmbedsPrompt(TypedDict):
|
|
"""Schema for a prompt provided via token embeddings."""
|
|
|
|
prompt_embeds: torch.Tensor
|
|
"""The embeddings of the prompt."""
|
|
from vllm.sequence import IntermediateTensors
|
|
deepstack_input_embeds: Optional[IntermediateTensors]
|
|
cache_salt: NotRequired[str]
|
|
"""
|
|
Optional cache salt to be used for prefix caching.
|
|
"""
class EmbedsInputs(TypedDict):
|
|
"""Represents embeddings-based inputs."""
|
|
|
|
type: Literal["embeds"]
|
|
"""The type of inputs."""
|
|
|
|
prompt_embeds: torch.Tensor
|
|
"""The embeddings of the prompt."""
|
|
deepstack_input_embeds: torch.Tensor
|
|
|
|
cache_salt: NotRequired[str]
|
|
"""
|
|
Optional cache salt to be used for prefix caching.
|
|
"""
def embeds_inputs(
    prompt_embeds: torch.Tensor,
    deepstack_input_embeds: torch.Tensor,
    cache_salt: Optional[str] = None,
) -> EmbedsInputs:
    """Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional
    values."""
    # TypedDicts are plain dicts at runtime, so a dict literal with the
    # required fields is equivalent to calling the TypedDict constructor.
    result: EmbedsInputs = {
        "type": "embeds",
        "prompt_embeds": prompt_embeds,
        "deepstack_input_embeds": deepstack_input_embeds,
    }
    # `cache_salt` is NotRequired: only include the key when a salt was given.
    if cache_salt is not None:
        result["cache_salt"] = cache_salt
    return result