Files
2026-01-19 10:38:50 +08:00

147 lines
4.4 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict, cast
from typing_extensions import TypeIs
from vllm.utils.collection_utils import is_list_of
from .data import (
EmbedsPrompt,
ExplicitEncoderDecoderPrompt,
ProcessorInputs,
PromptType,
SingletonInputs,
SingletonPrompt,
TextPrompt,
TokensPrompt,
)
if TYPE_CHECKING:
import torch
def parse_raw_prompts(
    prompt: str | list[str] | list[int] | list[list[int]],
) -> Sequence[TextPrompt] | Sequence[TokensPrompt]:
    """Normalize a raw prompt argument into a list of prompt dicts.

    Accepted shapes, tried in this order:

    1. a single string        -> one ``TextPrompt``
    2. a list of strings      -> one ``TextPrompt`` per string
    3. a list of ints         -> a single ``TokensPrompt`` holding the list
    4. a list of lists of int -> one ``TokensPrompt`` per inner list

    List shapes are classified by ``is_list_of``.

    Raises:
        ValueError: If the outer list is empty, if the first element of a
            list-of-lists is not a list, or if that first inner list is
            empty.
        TypeError: If any nested list fails the ``is_list_of(..., int)``
            check, or if ``prompt`` matches none of the accepted shapes.
    """
    # case 1: a string
    if isinstance(prompt, str):
        return [TextPrompt(prompt=prompt)]

    if isinstance(prompt, list):
        if not prompt:
            raise ValueError("please provide at least one prompt")

        # case 2: array of strings
        if is_list_of(prompt, str):
            texts = cast(list[str], prompt)
            return [TextPrompt(prompt=text) for text in texts]

        # case 3: array of tokens -- the whole list is ONE prompt
        if is_list_of(prompt, int):
            token_ids = cast(list[int], prompt)
            return [TokensPrompt(prompt_token_ids=token_ids)]

        # case 4: array of token arrays
        if is_list_of(prompt, list):
            head = prompt[0]
            if not isinstance(head, list):
                raise ValueError("prompt expected to be a list of lists")
            if not head:
                raise ValueError("Please provide at least one prompt")

            # strict validation: every nested list must be list[int]
            if any(not is_list_of(row, int) for row in prompt):
                raise TypeError("Nested lists must contain only integers")

            token_rows = cast(list[list[int]], prompt)
            return [TokensPrompt(prompt_token_ids=row) for row in token_rows]

    # Reached for non-str/non-list input and for lists matching no case above.
    raise TypeError(
        "prompt must be a string, array of strings, "
        "array of tokens, or array of token arrays"
    )
class ParsedStrPrompt(TypedDict):
    """Tagged wrapper around a prompt given as a plain ``str``."""

    # Discriminator tag; always the literal "str" for this variant.
    type: Literal["str"]
    # The prompt string itself.
    content: str
class ParsedTextPrompt(TypedDict):
    """Tagged wrapper around a ``TextPrompt``."""

    # Discriminator tag; always the literal "text" for this variant.
    type: Literal["text"]
    # The wrapped text prompt.
    content: TextPrompt
class ParsedTokensPrompt(TypedDict):
    """Tagged wrapper around a ``TokensPrompt``."""

    # Discriminator tag; always the literal "tokens" for this variant.
    type: Literal["tokens"]
    # The wrapped token-ID prompt.
    content: TokensPrompt
class ParsedEmbedsPrompt(TypedDict):
    """Tagged wrapper around an ``EmbedsPrompt``."""

    # Discriminator tag; always the literal "embeds" for this variant.
    type: Literal["embeds"]
    # The wrapped embeddings prompt.
    content: EmbedsPrompt
# Union of every parsed singleton-prompt variant. Each member carries a
# distinct literal "type" tag, so callers can branch on that field.
ParsedSingletonPrompt: TypeAlias = (
    ParsedStrPrompt
    | ParsedTextPrompt
    | ParsedTokensPrompt
    | ParsedEmbedsPrompt
)
def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
    """Wrap a singleton prompt in its tagged ``Parsed*Prompt`` form.

    A ``str`` becomes a ``ParsedStrPrompt``. A ``dict`` is classified by the
    first key found, checked in this order: ``"prompt_embeds"``,
    ``"prompt_token_ids"``, ``"prompt"``.

    Raises:
        TypeError: If ``prompt`` is neither a ``str`` nor a ``dict``
            containing one of the recognized keys.
    """
    if isinstance(prompt, str):
        return ParsedStrPrompt(type="str", content=prompt)

    if isinstance(prompt, dict):
        # Type ignores are because mypy does not correctly infer the TypedDicts
        # Pyright does succeed.
        if "prompt_embeds" in prompt:
            return ParsedEmbedsPrompt(type="embeds", content=prompt)  # type: ignore[typeddict-item]
        if "prompt_token_ids" in prompt:
            return ParsedTokensPrompt(type="tokens", content=prompt)  # type: ignore[typeddict-item]
        if "prompt" in prompt:
            return ParsedTextPrompt(type="text", content=prompt)

    # Also reached for dicts that carry none of the recognized keys.
    raise TypeError(
        "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt"
    )
def is_explicit_encoder_decoder_prompt(
    prompt: PromptType,
) -> TypeIs[ExplicitEncoderDecoderPrompt]:
    """Return whether ``prompt`` is a dict with an ``"encoder_prompt"`` key.

    The ``TypeIs`` return annotation lets type checkers narrow ``prompt`` to
    ``ExplicitEncoderDecoderPrompt`` when the result is true.
    """
    if not isinstance(prompt, dict):
        return False
    return "encoder_prompt" in prompt
def split_enc_dec_inputs(
    inputs: ProcessorInputs,
) -> tuple[SingletonInputs | None, SingletonInputs]:
    """Split processor inputs into an ``(encoder, decoder)`` pair.

    When ``inputs`` contains both an ``"encoder"`` and a ``"decoder"`` key,
    those two values are returned. Otherwise ``None`` fills the encoder slot
    and ``inputs`` itself is returned unchanged in the decoder slot.
    """
    if "encoder" not in inputs or "decoder" not in inputs:
        return None, inputs

    # NOTE: This passes pyright but not mypy
    encoder_inputs = inputs["encoder"]  # type: ignore[typeddict-item]
    decoder_inputs = inputs["decoder"]  # type: ignore[typeddict-item]
    return encoder_inputs, decoder_inputs
class PromptComponents(NamedTuple):
text: str | None = None
token_ids: list[int] | None = None
embeds: "torch.Tensor | None" = None
def get_prompt_components(prompt: PromptType) -> PromptComponents:
    """Extract the text, token IDs and embeddings carried by ``prompt``.

    * A ``str`` yields ``PromptComponents(text=prompt)``.
    * A dict whose ``"encoder_prompt"`` entry is not ``None`` yields the
      components of that encoder prompt (resolved recursively).
    * Any other dict is read through its ``"prompt"``,
      ``"prompt_token_ids"`` and ``"prompt_embeds"`` keys; a missing key
      becomes ``None`` in the result.
    """
    if isinstance(prompt, str):
        return PromptComponents(text=prompt)

    # Compare against None instead of relying on truthiness: a falsy encoder
    # prompt such as "" is still an encoder prompt. A truthiness test would
    # skip it and read "prompt"-style keys from the outer dict instead.
    encoder_prompt = prompt.get("encoder_prompt")
    if encoder_prompt is not None:
        return get_prompt_components(encoder_prompt)  # type: ignore[arg-type]

    return PromptComponents(
        text=prompt.get("prompt"),  # type: ignore[arg-type]
        token_ids=prompt.get("prompt_token_ids"),  # type: ignore[arg-type]
        embeds=prompt.get("prompt_embeds"),
    )