412 lines
12 KiB
Python
412 lines
12 KiB
Python
|
|
# SPDX-License-Identifier: Apache-2.0
|
||
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||
|
|
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
|
||
|
|
|
||
|
|
import torch
|
||
|
|
from typing_extensions import NotRequired, TypedDict, assert_never
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
from vllm.multimodal.inputs import (
|
||
|
|
MultiModalDataDict,
|
||
|
|
MultiModalEncDecInputs,
|
||
|
|
MultiModalInputs,
|
||
|
|
MultiModalUUIDDict,
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
MultiModalDataDict = object
|
||
|
|
MultiModalEncDecInputs = object
|
||
|
|
MultiModalInputs = object
|
||
|
|
MultiModalUUIDDict = object
|
||
|
|
|
||
|
|
|
||
|
|
# Inputs to LLM API
|
||
|
|
class _PromptOptions(TypedDict):
    """
    Additional options available to all
    [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt].

    Every key is declared `NotRequired`, so a prompt dict may omit any of
    them. Only `multi_modal_data` and `mm_processor_kwargs` are additionally
    typed to accept an explicit `None`.
    """

    multi_modal_data: NotRequired[MultiModalDataDict | None]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

    mm_processor_kwargs: NotRequired[dict[str, Any] | None]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

    # Unlike the two keys above, `None` is not part of this key's type:
    # leave the key out instead of passing `None`.
    multi_modal_uuids: NotRequired[MultiModalUUIDDict]
    """
    Optional user-specified UUIDs for multimodal items, mapped by modality.
    Lists must match the number of items per modality and may contain `None`.
    For `None` entries, the hasher will compute IDs automatically; non-None
    entries override the default hashes for caching, and MUST be unique per
    multimodal item.
    """

    # `None` is not part of this key's type either.
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """
|
||
|
|
|
||
|
|
|
||
|
|
class TextPrompt(_PromptOptions):
    """
    Schema for a text prompt.

    `prompt` is the only required key; the optional keys inherited from
    `_PromptOptions` may be supplied alongside it.
    """

    prompt: str
    """The input text to be tokenized before passing to the model."""
|
||
|
|
|
||
|
|
|
||
|
|
class TokensPrompt(_PromptOptions):
    """
    Schema for a tokenized prompt.

    `prompt_token_ids` is the only required key; `prompt`, `token_type_ids`
    and the keys inherited from `_PromptOptions` are all optional.
    """

    prompt_token_ids: list[int]
    """A list of token IDs to pass to the model."""

    prompt: NotRequired[str]
    """The prompt text corresponding to the token IDs, if available."""

    token_type_ids: NotRequired[list[int]]
    """A list of token type IDs to pass to the cross encoder model."""
|
||
|
|
|
||
|
|
|
||
|
|
class EmbedsPrompt(_PromptOptions):
    """
    Schema for a prompt provided via token embeddings.

    `prompt_embeds` is the only required key; `prompt` and the keys inherited
    from `_PromptOptions` are optional.
    """

    # NOTE(review): the expected tensor shape/dtype is not stated in this
    # module - confirm against the consumer before relying on one.
    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

    prompt: NotRequired[str]
    """The prompt text corresponding to the token embeddings, if available."""
|
||
|
|
|
||
|
|
|
||
|
|
# `EncoderPrompt` and `DecoderPrompt` accept the same member types as this
# alias except for `EmbedsPrompt`.
DecoderOnlyPrompt: TypeAlias = (
    str | TextPrompt | list[int] | TokensPrompt | EmbedsPrompt
)
"""
Schema of a prompt for a decoder-only model:

- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt (list of token IDs, or
  [`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])

For encoder-decoder models, passing a singleton prompt is shorthand for passing
`ExplicitEncoderDecoderPrompt(encoder_prompt=prompt, decoder_prompt=None)`.
"""
|
||
|
|
|
||
|
|
|
||
|
|
# Embedding prompts are deliberately absent from this union.
EncoderPrompt: TypeAlias = str | TextPrompt | list[int] | TokensPrompt
"""
Schema of a prompt for the encoder part of an encoder-decoder model:

- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt (list of token IDs, or
  [`TokensPrompt`][vllm.inputs.data.TokensPrompt])
"""
|
||
|
|
|
||
|
|
|
||
|
|
# Currently the same union of member types as `EncoderPrompt`.
DecoderPrompt: TypeAlias = str | TextPrompt | list[int] | TokensPrompt
"""
Schema of a prompt for the decoder part of an encoder-decoder model:

- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt (list of token IDs, or
  [`TokensPrompt`][vllm.inputs.data.TokensPrompt])

Note:
    Multi-modal inputs are not supported for decoder prompts.
"""
|
||
|
|
|
||
|
|
|
||
|
|
class ExplicitEncoderDecoderPrompt(TypedDict):
    """
    Schema for a pair of encoder and decoder singleton prompts.

    Both keys are required. This class derives from `TypedDict` directly, so
    the `_PromptOptions` keys are not part of the outer dict; they can only
    appear inside the nested prompt dicts.

    Note:
        This schema is not valid for decoder-only models.
    """

    encoder_prompt: EncoderPrompt
    """The prompt for the encoder part of the model."""

    # The key itself must be present; `None` is an accepted value.
    decoder_prompt: DecoderPrompt | None
    """
    The prompt for the decoder part of the model.

    Passing `None` will cause the prompt to be inferred automatically.
    """
|
||
|
|
|
||
|
|
|
||
|
|
# A bare `EncoderPrompt`, or the explicit encoder/decoder pair.
EncoderDecoderPrompt: TypeAlias = EncoderPrompt | ExplicitEncoderDecoderPrompt
"""
Schema for a prompt for an encoder-decoder model.

You can pass a singleton encoder prompt, in which case the decoder prompt is
considered to be `None` (i.e., infer automatically).
"""


# `EncoderPrompt` and `DecoderPrompt` only contain member types that are
# already in `DecoderOnlyPrompt`, so this union currently has the same members
# as `DecoderOnlyPrompt`.
SingletonPrompt: TypeAlias = DecoderOnlyPrompt | EncoderPrompt | DecoderPrompt
"""
Schema for a single prompt. This is as opposed to a data structure
which encapsulates multiple prompts, such as
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt].
"""


# Beyond `DecoderOnlyPrompt`, the only extra member contributed by
# `EncoderDecoderPrompt` is `ExplicitEncoderDecoderPrompt`.
PromptType: TypeAlias = DecoderOnlyPrompt | EncoderDecoderPrompt
"""
Schema for any prompt, regardless of model type.

This is the input format accepted by most [`LLM`][vllm.entrypoints.llm.LLM] APIs.
"""
|
||
|
|
|
||
|
|
|
||
|
|
class DataPrompt(_PromptOptions):
    """
    Represents generic inputs that are converted to
    [`PromptType`][vllm.inputs.data.PromptType] by IO processor plugins.

    `data` and `data_format` are both required; the keys inherited from
    `_PromptOptions` are optional.
    """

    # Typed as `Any`: this module places no constraint on the payload.
    data: Any
    """The input data."""

    # NOTE(review): presumably tells the plugin how to interpret `data`; the
    # set of accepted format strings is not defined in this module - confirm
    # against the IO processor plugin in use.
    data_format: str
    """The input data format."""
|
||
|
|
|
||
|
|
|
||
|
|
# Outputs of processor
|
||
|
|
class _InputOptions(TypedDict):
    """
    Additional options available to all input types.

    Both keys are `NotRequired`. In this module, `TokenInputs` and
    `EmbedsInputs` inherit from this class.
    """

    arrival_time: NotRequired[float]
    """The time when the input was received (before rendering)."""

    cache_salt: NotRequired[str]
    """Optional cache salt to be used for prefix caching."""
|
||
|
|
|
||
|
|
|
||
|
|
class TokenInputs(_InputOptions):
    """
    Represents token-based inputs.

    `type` and `prompt_token_ids` are required; `prompt` and the keys
    inherited from `_InputOptions` are optional.
    """

    # Discriminator tag; always the literal "token".
    type: Literal["token"]
    """The type of inputs."""

    prompt_token_ids: list[int]
    """The token IDs of the prompt."""

    prompt: NotRequired[str]
    """The prompt text corresponding to the token IDs, if available."""
|
||
|
|
|
||
|
|
|
||
|
|
def token_inputs(
    prompt_token_ids: list[int],
    *,
    prompt: str | None = None,
    cache_salt: str | None = None,
) -> TokenInputs:
    """
    Build a [`TokenInputs`][vllm.inputs.data.TokenInputs] dict.

    Args:
        prompt_token_ids: Token IDs of the prompt; stored as-is (not copied).
        prompt: Prompt text matching the token IDs. The `"prompt"` key is
            only set when this is not `None`.
        cache_salt: Cache salt. The `"cache_salt"` key is only set when this
            is not `None`.

    Returns:
        A dict tagged with `type="token"` holding the given values.
    """
    # A TypedDict is a plain dict at runtime, so an annotated literal yields
    # exactly what calling `TokenInputs(...)` would.
    result: TokenInputs = {
        "type": "token",
        "prompt_token_ids": prompt_token_ids,
    }

    # Optional keys are omitted entirely rather than stored as `None`.
    if prompt is not None:
        result["prompt"] = prompt
    if cache_salt is not None:
        result["cache_salt"] = cache_salt

    return result
|
||
|
|
|
||
|
|
|
||
|
|
class EmbedsInputs(_InputOptions):
    """
    Represents embeddings-based inputs.

    `type` and `prompt_embeds` are required; `prompt` and the keys inherited
    from `_InputOptions` are optional.
    """

    # Discriminator tag; always the literal "embeds".
    type: Literal["embeds"]
    """The type of inputs."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

    prompt: NotRequired[str]
    """The prompt text corresponding to the token embeddings, if available."""
|
||
|
|
|
||
|
|
|
||
|
|
def embeds_inputs(
    prompt_embeds: torch.Tensor,
    *,
    prompt: str | None = None,
    cache_salt: str | None = None,
) -> EmbedsInputs:
    """
    Build an [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] dict.

    Args:
        prompt_embeds: Embeddings of the prompt; stored as-is (not copied).
        prompt: Prompt text matching the embeddings. The `"prompt"` key is
            only set when this is not `None`.
        cache_salt: Cache salt. The `"cache_salt"` key is only set when this
            is not `None`.

    Returns:
        A dict tagged with `type="embeds"` holding the given values.
    """
    # A TypedDict is a plain dict at runtime, so an annotated literal yields
    # exactly what calling `EmbedsInputs(...)` would.
    result: EmbedsInputs = {
        "type": "embeds",
        "prompt_embeds": prompt_embeds,
    }

    # Optional keys are omitted entirely rather than stored as `None`.
    if prompt is not None:
        result["prompt"] = prompt
    if cache_salt is not None:
        result["cache_salt"] = cache_salt

    return result
|
||
|
|
|
||
|
|
|
||
|
|
# The only one of the three aliases in this group that includes `EmbedsInputs`.
DecoderOnlyInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs
"""
A processed prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for decoder-only models.
"""


# No `EmbedsInputs` member: `_validate_enc_inputs` rejects embedding inputs.
EncoderInputs: TypeAlias = TokenInputs | MultiModalEncDecInputs
"""
A processed encoder prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for encoder-decoder models.
"""


# No `EmbedsInputs` member: `_validate_dec_inputs` rejects embedding inputs.
DecoderInputs: TypeAlias = TokenInputs | MultiModalInputs
"""
A processed decoder prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for encoder-decoder models.
"""
|
||
|
|
|
||
|
|
|
||
|
|
class EncoderDecoderInputs(TypedDict):
    """
    A processed pair of encoder and decoder singleton prompts from
    [`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
    which can be passed to
    [`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
    for encoder-decoder models.
    """

    # Discriminator tag; always the literal "enc_dec".
    type: Literal["enc_dec"]

    encoder_prompt: EncoderInputs
    """The inputs for the encoder portion."""

    decoder_prompt: DecoderInputs
    """The inputs for the decoder portion."""

    # Declared here directly: this class does not inherit `_InputOptions`,
    # so it has no top-level `cache_salt` key.
    arrival_time: NotRequired[float]
    """The time when the input was received (before rendering)."""
|
||
|
|
|
||
|
|
|
||
|
|
# Either a single decoder-only input dict or the nested encoder/decoder pair.
ProcessorInputs: TypeAlias = DecoderOnlyInputs | EncoderDecoderInputs
"""
A processed prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor].
"""


# The decoder-only variants plus `MultiModalEncDecInputs`; this is the
# parameter type taken by `_validate_enc_inputs` / `_validate_dec_inputs`.
SingletonInputs: TypeAlias = DecoderOnlyInputs | MultiModalEncDecInputs
"""The inputs for a single encoder/decoder prompt."""
|
||
|
|
|
||
|
|
|
||
|
|
def _validate_enc_inputs(inputs: SingletonInputs) -> EncoderInputs:
    """
    Check that processed inputs are usable as an encoder prompt.

    Args:
        inputs: Processed singleton inputs, discriminated by their `"type"`
            key.

    Returns:
        The same `inputs` object, unchanged.

    Raises:
        ValueError: If `inputs` is an embeddings input.
        RuntimeError: If `inputs` is a multi-modal input without an
            `"encoder_prompt_token_ids"` key.
    """
    if inputs["type"] == "embeds":
        raise ValueError(
            "Embedding inputs are not supported for encoder-decoder models"
        )

    if inputs["type"] == "multimodal":
        # A missing key means the inputs did not come from an
        # encoder-decoder multi-modal processor.
        has_encoder_tokens = "encoder_prompt_token_ids" in inputs
        if not has_encoder_tokens:
            raise RuntimeError(
                "You should register an encoder-decoder multi-modal processor "
                "for encoder-decoder models."
            )

    return inputs  # type: ignore[return-value]
|
||
|
|
|
||
|
|
|
||
|
|
def _validate_dec_inputs(inputs: SingletonInputs) -> DecoderInputs:
    """
    Check that processed inputs are usable as a decoder prompt.

    Args:
        inputs: Processed singleton inputs, discriminated by their `"type"`
            key.

    Returns:
        The same `inputs` object, unchanged.

    Raises:
        ValueError: If `inputs` is an embeddings input.
    """
    # Anything other than an embeddings input is passed straight through.
    if inputs["type"] != "embeds":
        return inputs

    raise ValueError(
        "Embedding inputs are not supported for encoder-decoder models"
    )
|
||
|
|
|
||
|
|
|
||
|
|
def _prepare_decoder_input_ids_for_generation(
    decoder_input_ids: list[int],
    decoder_start_token_id: int,
) -> list[int]:
    """
    Ensure that the decoder token IDs begin with the decoder start token,
    following `GenerationMixin._prepare_decoder_input_ids_for_generation()`.

    Source:
    https://github.com/huggingface/transformers/blob/v5.1.0/src/transformers/generation/utils.py

    Args:
        decoder_input_ids: Token IDs of the decoder prompt; may be empty.
        decoder_start_token_id: Token ID expected at position 0.

    Returns:
        `decoder_input_ids` itself when it already starts with
        `decoder_start_token_id`; otherwise a new list with that token
        prepended. The argument is never modified in place.
    """
    starts_with_token = (
        bool(decoder_input_ids)
        and decoder_input_ids[0] == decoder_start_token_id
    )
    if starts_with_token:
        return decoder_input_ids

    # Covers both the empty prompt and a prompt with a different first token.
    return [decoder_start_token_id] + decoder_input_ids
|
||
|
|
|
||
|
|
|
||
|
|
def build_enc_dec_inputs(
    encoder_inputs: SingletonInputs,
    decoder_inputs: SingletonInputs | None,
    decoder_start_token_id: int,
) -> EncoderDecoderInputs:
    """
    Combine processed encoder and decoder inputs into
    [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs].

    The input dicts passed by the caller are not modified.

    Args:
        encoder_inputs: Processed inputs for the encoder prompt.
        decoder_inputs: Processed inputs for the decoder prompt. When `None`,
            the validated encoder inputs are used as the decoder inputs.
        decoder_start_token_id: Token ID that the decoder prompt must begin
            with; it is prepended when missing.

    Returns:
        A dict tagged with `type="enc_dec"` holding the encoder and decoder
        inputs. A truthy `cache_salt` on the encoder inputs is carried over
        to the decoder inputs.

    Raises:
        ValueError: If either input is an embeddings input.
        RuntimeError: If multi-modal encoder inputs have no
            `"encoder_prompt_token_ids"` key.
    """
    enc_inputs = _validate_enc_inputs(encoder_inputs)

    if decoder_inputs is None:
        dec_inputs: DecoderInputs = enc_inputs
    else:
        dec_inputs = _validate_dec_inputs(decoder_inputs)

    enc_inputs_new: EncoderInputs
    dec_inputs_new: DecoderInputs

    if enc_inputs["type"] == "multimodal":
        # Imported lazily; at module level these names are only imported
        # under `TYPE_CHECKING`.
        from vllm.multimodal.inputs import mm_inputs

        # The encoder receives only its own token IDs / text ...
        enc_inputs_new = token_inputs(
            enc_inputs["encoder_prompt_token_ids"],
            prompt=enc_inputs.get("encoder_prompt"),
        )
        # ... while the multi-modal payload is attached to the decoder side.
        dec_inputs_new = mm_inputs(
            prompt_token_ids=dec_inputs["prompt_token_ids"],
            prompt=dec_inputs.get("prompt"),
            mm_kwargs=enc_inputs["mm_kwargs"],
            mm_hashes=enc_inputs["mm_hashes"],
            mm_placeholders=enc_inputs["mm_placeholders"],
        )
    elif enc_inputs["type"] == "token":
        # NOTE(review): the encoder token IDs are replaced by an empty
        # prompt here; presumably intentional for the supported models -
        # confirm against the model runner.
        enc_inputs_new = token_inputs(prompt_token_ids=[])
        # Shallow-copy: `dec_inputs` is the caller's own dict (it is even
        # `encoder_inputs` when `decoder_inputs` is None), and the writes
        # below must not leak into it.
        dec_inputs_new = dec_inputs.copy()
    else:
        assert_never(enc_inputs)

    dec_inputs_new["prompt_token_ids"] = _prepare_decoder_input_ids_for_generation(
        dec_inputs_new["prompt_token_ids"],
        decoder_start_token_id,
    )

    # Only a truthy salt is propagated (a missing or empty salt is skipped).
    if cache_salt := enc_inputs.get("cache_salt"):
        dec_inputs_new["cache_salt"] = cache_salt

    return EncoderDecoderInputs(
        type="enc_dec",
        encoder_prompt=enc_inputs_new,
        decoder_prompt=dec_inputs_new,
    )
|