""" Schemas and utilites for preprocessing inputs. """ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypedDict, overload from vllm.inputs import ( EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, SingletonPrompt, TextPrompt, TokensPrompt, ) from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.utils.collection_utils import is_list_of if TYPE_CHECKING: import torch from vllm.config import ModelConfig from vllm.entrypoints.chat_utils import ChatCompletionMessageParam @overload def prompt_to_seq( prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes], ) -> Sequence[SingletonPrompt]: ... @overload def prompt_to_seq( # type: ignore[misc] prompt_or_prompts: ExplicitEncoderDecoderPrompt | Sequence[ExplicitEncoderDecoderPrompt], ) -> Sequence[ExplicitEncoderDecoderPrompt]: ... @overload def prompt_to_seq( # type: ignore[misc] prompt_or_prompts: PromptType | Sequence[PromptType], ) -> Sequence[PromptType]: ... 
def prompt_to_seq(
    prompt_or_prompts: PromptType | bytes | Sequence[PromptType | bytes],
) -> Sequence[PromptType]:
    """
    Normalize one prompt or a sequence of prompts into a sequence of prompts.

    A single prompt — a dict, a string, bytes, or a non-empty list of token
    IDs — is wrapped in a one-element list. Anything else is assumed to
    already be a sequence of prompts and is returned unchanged.
    """
    if isinstance(prompt_or_prompts, (dict, str, bytes)) or (
        # A bare non-empty list of ints is one token prompt, not a batch.
        len(prompt_or_prompts) > 0 and is_list_of(prompt_or_prompts, int)
    ):
        return [prompt_or_prompts]  # type: ignore[list-item]

    return prompt_or_prompts  # type: ignore[return-value]


def conversation_to_seq(
    conversation_or_conversations: list["ChatCompletionMessageParam"]
    | Sequence[list["ChatCompletionMessageParam"]],
) -> Sequence[list["ChatCompletionMessageParam"]]:
    """
    Normalize one conversation or a sequence of conversations into a
    sequence of conversations.

    A single conversation (a non-empty list of message dicts) is wrapped in
    a one-element list. Anything else is assumed to already be a sequence of
    conversations and is returned unchanged.
    """
    if len(conversation_or_conversations) > 0 and is_list_of(
        conversation_or_conversations, dict
    ):
        return [conversation_or_conversations]  # type: ignore[list-item]

    return conversation_or_conversations  # type: ignore[return-value]


DecoderOnlyDictPrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
"""
A [`DecoderOnlyPrompt`][vllm.inputs.data.DecoderOnlyPrompt] that has been
standardized into a dictionary.
"""

EncoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt
"""
A [`EncoderPrompt`][vllm.inputs.data.EncoderPrompt] that has been
standardized into a dictionary.
"""

DecoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt
"""
A [`DecoderPrompt`][vllm.inputs.data.DecoderPrompt] that has been
standardized into a dictionary.
"""


class EncoderDecoderDictPrompt(TypedDict):
    """
    A [`EncoderDecoderPrompt`][vllm.inputs.data.EncoderDecoderPrompt] that
    has been standardized into a dictionary.
    """

    encoder_prompt: EncoderDictPrompt
    decoder_prompt: DecoderDictPrompt | None


SingletonDictPrompt: TypeAlias = (
    DecoderOnlyDictPrompt | EncoderDictPrompt | DecoderDictPrompt
)
"""
A [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] that has been
standardized into a dictionary.
"""

DictPrompt: TypeAlias = DecoderOnlyDictPrompt | EncoderDecoderDictPrompt
"""
A [`PromptType`][vllm.inputs.data.PromptType] that has been standardized
into a dictionary.
"""


def parse_dec_only_prompt(prompt: PromptType | object) -> DecoderOnlyDictPrompt:
    """
    Parse a prompt for a decoder-only model and normalize it to a dictionary.

    Raises:
        TypeError: If the prompt is not a string, a list of token IDs, or a
            dictionary containing text, tokens, or embeddings, or if it is an
            explicit encoder-decoder prompt.
    """
    if isinstance(prompt, str):
        return TextPrompt(prompt=prompt)

    if isinstance(prompt, list):
        if not is_list_of(prompt, int):
            raise TypeError("Token prompt should be a list of integers")

        return TokensPrompt(prompt_token_ids=prompt)

    if isinstance(prompt, dict):
        if "encoder_prompt" in prompt:
            raise TypeError("Cannot pass encoder-decoder prompt to decoder-only models")

        # Already a standardized dict prompt; pass it through untouched.
        if (
            "prompt" in prompt
            or "prompt_token_ids" in prompt
            or "prompt_embeds" in prompt
        ):
            return prompt  # type: ignore[return-value]

        raise TypeError("Prompt dictionary must contain text, tokens, or embeddings")

    raise TypeError("Prompt should be a string, list of tokens, or dictionary")


def _parse_enc_prompt(prompt: PromptType | object) -> EncoderDictPrompt:
    """Normalize the encoder half of an encoder-decoder prompt to a dict."""
    if isinstance(prompt, str):
        return TextPrompt(prompt=prompt)

    if isinstance(prompt, list):
        if not is_list_of(prompt, int):
            raise TypeError("Token prompt should be a list of integers")

        return TokensPrompt(prompt_token_ids=prompt)

    if isinstance(prompt, dict):
        # Encoder-decoder models cannot consume prompt embeddings.
        if "prompt_embeds" in prompt:
            raise TypeError("Cannot pass embeddings prompt to encoder-decoder models")

        if "prompt" in prompt or "prompt_token_ids" in prompt:
            return prompt  # type: ignore[return-value]

        raise TypeError("Prompt dictionary must contain text or tokens")

    raise TypeError("Prompt should be a string, list of tokens, or dictionary")


def _parse_dec_prompt(prompt: PromptType | object) -> DecoderDictPrompt:
    """Normalize the decoder half of an encoder-decoder prompt to a dict."""
    if isinstance(prompt, str):
        return TextPrompt(prompt=prompt)

    if isinstance(prompt, list):
        if not is_list_of(prompt, int):
            raise TypeError("Token prompt should be a list of integers")

        return TokensPrompt(prompt_token_ids=prompt)

    if isinstance(prompt, dict):
        if "prompt_embeds" in prompt:
            raise TypeError("Cannot pass embeddings prompt to encoder-decoder models")

        # Multi-modal inputs belong on the encoder prompt, not the decoder.
        if (
            "multi_modal_data" in prompt
            or "mm_processor_kwargs" in prompt
            or "multi_modal_uuids" in prompt
        ):
            raise TypeError("Cannot pass multi-modal inputs to decoder prompt")

        if "prompt" in prompt or "prompt_token_ids" in prompt:
            return prompt  # type: ignore[return-value]

        raise TypeError("Prompt dictionary must contain text or tokens")

    raise TypeError("Prompt should be a string, list of tokens, or dictionary")


def parse_enc_dec_prompt(prompt: PromptType | object) -> EncoderDecoderDictPrompt:
    """
    Parse a prompt for an encoder-decoder model and normalize it to a
    dictionary.

    A singleton prompt is treated as the encoder prompt with no decoder
    prompt; an explicit encoder-decoder prompt has both halves normalized.
    """
    if isinstance(prompt, dict) and "encoder_prompt" in prompt:
        enc_prompt = prompt["encoder_prompt"]  # type: ignore[typeddict-item]
        dec_prompt = prompt["decoder_prompt"]  # type: ignore[typeddict-item]
    else:
        enc_prompt = prompt
        dec_prompt = None

    return EncoderDecoderDictPrompt(
        encoder_prompt=_parse_enc_prompt(enc_prompt),
        decoder_prompt=None if dec_prompt is None else _parse_dec_prompt(dec_prompt),
    )


def parse_model_prompt(model_config: "ModelConfig", prompt: object) -> DictPrompt:
    """Normalize a prompt according to the model's architecture."""
    if model_config.is_encoder_decoder:
        return parse_enc_dec_prompt(prompt)

    return parse_dec_only_prompt(prompt)


class PromptComponents(NamedTuple):
    """The text, token, and embedding components extracted from a prompt."""

    text: str | None = None
    token_ids: list[int] | None = None
    embeds: "torch.Tensor | None" = None


def extract_target_prompt(
    model_config: "ModelConfig", prompt: object
) -> SingletonDictPrompt:
    """
    Return the prompt that is fed to the model backbone: the encoder prompt
    for encoder-decoder models, otherwise the prompt itself.
    """
    return (
        parse_enc_dec_prompt(prompt)["encoder_prompt"]
        if model_config.is_encoder_decoder
        else parse_dec_only_prompt(prompt)
    )


def extract_prompt_components(
    model_config: "ModelConfig",
    prompt: PromptType | ProcessorInputs,
) -> PromptComponents:
    """Extract the text, token IDs, and embeddings from the target prompt."""
    target_prompt = extract_target_prompt(model_config, prompt)

    return PromptComponents(
        text=target_prompt.get("prompt"),
        token_ids=target_prompt.get("prompt_token_ids"),
        embeds=target_prompt.get("prompt_embeds"),
    )


def extract_prompt_len(
    model_config: "ModelConfig", prompt: PromptType | ProcessorInputs
) -> int:
    """Return the length of the target prompt in tokens (or embedding rows)."""
    target_prompt = extract_target_prompt(model_config, prompt)

    return length_from_prompt_token_ids_or_embeds(
        target_prompt.get("prompt_token_ids"),
        target_prompt.get("prompt_embeds"),
    )