147 lines
4.4 KiB
Python
147 lines
4.4 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from collections.abc import Sequence
|
|
from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict, cast
|
|
|
|
from typing_extensions import TypeIs
|
|
|
|
from vllm.utils.collection_utils import is_list_of
|
|
|
|
from .data import (
|
|
EmbedsPrompt,
|
|
ExplicitEncoderDecoderPrompt,
|
|
ProcessorInputs,
|
|
PromptType,
|
|
SingletonInputs,
|
|
SingletonPrompt,
|
|
TextPrompt,
|
|
TokensPrompt,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
import torch
|
|
|
|
|
|
def parse_raw_prompts(
    prompt: str | list[str] | list[int] | list[list[int]],
) -> Sequence[TextPrompt] | Sequence[TokensPrompt]:
    """Convert a raw prompt value into a list of prompt dicts.

    Four input shapes are recognised:

    1. a single string           -> one ``TextPrompt``
    2. a list of strings         -> one ``TextPrompt`` per string
    3. a list of token ids       -> one ``TokensPrompt`` holding the list
    4. a list of token-id lists  -> one ``TokensPrompt`` per inner list

    List element types are classified with ``is_list_of``.

    Raises:
        ValueError: If ``prompt`` is an empty list, or (shape 4) its first
            inner list is empty.
        TypeError: If an inner list of shape 4 is rejected by
            ``is_list_of(..., int)``, or if ``prompt`` matches none of the
            shapes above.
    """
    # Shape 1: a string
    if isinstance(prompt, str):
        return [TextPrompt(prompt=prompt)]

    if isinstance(prompt, list):
        if not prompt:
            raise ValueError("please provide at least one prompt")

        # Shape 2: array of strings
        if is_list_of(prompt, str):
            texts = cast(list[str], prompt)
            return [TextPrompt(prompt=text) for text in texts]

        # Shape 3: array of tokens (the whole list is a single prompt)
        if is_list_of(prompt, int):
            token_ids = cast(list[int], prompt)
            return [TokensPrompt(prompt_token_ids=token_ids)]

        # Shape 4: array of token arrays
        if is_list_of(prompt, list):
            head = prompt[0]
            if not isinstance(head, list):
                raise ValueError("prompt expected to be a list of lists")

            if not head:
                raise ValueError("Please provide at least one prompt")

            # Strict validation: every nested list must be list[int]
            for candidate in prompt:
                if not is_list_of(candidate, int):
                    raise TypeError("Nested lists must contain only integers")

            batches = cast(list[list[int]], prompt)
            return [TokensPrompt(prompt_token_ids=ids) for ids in batches]

    # Neither a string nor a list of a recognised element type.
    raise TypeError(
        "prompt must be a string, array of strings, "
        "array of tokens, or array of token arrays"
    )
|
|
|
|
|
|
class ParsedStrPrompt(TypedDict):
    """Tagged form of a singleton prompt that was given as a bare string."""

    # Discriminator; always "str" for this variant.
    type: Literal["str"]
    # The prompt string itself.
    content: str
|
|
|
|
|
|
class ParsedTextPrompt(TypedDict):
    """Tagged form of a singleton prompt that is a ``TextPrompt`` dict."""

    # Discriminator; always "text" for this variant.
    type: Literal["text"]
    # The original ``TextPrompt`` dict.
    content: TextPrompt
|
|
|
|
|
|
class ParsedTokensPrompt(TypedDict):
    """Tagged form of a singleton prompt that is a ``TokensPrompt`` dict."""

    # Discriminator; always "tokens" for this variant.
    type: Literal["tokens"]
    # The original ``TokensPrompt`` dict.
    content: TokensPrompt
|
|
|
|
|
|
class ParsedEmbedsPrompt(TypedDict):
    """Tagged form of a singleton prompt that is an ``EmbedsPrompt`` dict."""

    # Discriminator; always "embeds" for this variant.
    type: Literal["embeds"]
    # The original ``EmbedsPrompt`` dict.
    content: EmbedsPrompt
|
|
|
|
|
|
# Union of every tagged variant that `parse_singleton_prompt` can return;
# the "type" field tells the variants apart.
ParsedSingletonPrompt: TypeAlias = (
    ParsedStrPrompt
    | ParsedTextPrompt
    | ParsedTokensPrompt
    | ParsedEmbedsPrompt
)
|
|
|
|
|
|
def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
    """Wrap a singleton prompt in a tagged dict describing its kind.

    A ``str`` becomes a ``"str"`` variant. A ``dict`` is classified by the
    first of these keys it contains, checked in this order:
    ``"prompt_embeds"`` -> ``"embeds"``, ``"prompt_token_ids"`` ->
    ``"tokens"``, ``"prompt"`` -> ``"text"``. The input is stored unchanged
    under ``"content"``.

    Raises:
        TypeError: If ``prompt`` is neither a ``str`` nor a ``dict``
            containing one of the keys above.
    """
    if isinstance(prompt, str):
        return ParsedStrPrompt(type="str", content=prompt)

    if isinstance(prompt, dict):
        # Type ignores are because mypy does not correctly infer the TypedDicts
        # Pyright does succeed.
        if "prompt_embeds" in prompt:
            return ParsedEmbedsPrompt(type="embeds", content=prompt)  # type: ignore[typeddict-item]

        if "prompt_token_ids" in prompt:
            return ParsedTokensPrompt(type="tokens", content=prompt)  # type: ignore[typeddict-item]

        if "prompt" in prompt:
            return ParsedTextPrompt(type="text", content=prompt)

    # Not a str, and not a dict carrying any recognised key.
    raise TypeError(
        "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt"
    )
|
|
|
|
|
|
def is_explicit_encoder_decoder_prompt(
    prompt: PromptType,
) -> TypeIs[ExplicitEncoderDecoderPrompt]:
    """Return whether ``prompt`` is a dict that has an ``"encoder_prompt"`` key.

    Only the presence of the key is checked; its value is not inspected.
    """
    if not isinstance(prompt, dict):
        return False

    return "encoder_prompt" in prompt
|
|
|
|
|
|
def split_enc_dec_inputs(
    inputs: ProcessorInputs,
) -> tuple[SingletonInputs | None, SingletonInputs]:
    """Split processor inputs into an ``(encoder, decoder)`` pair.

    When ``inputs`` contains both an ``"encoder"`` and a ``"decoder"`` key,
    the two corresponding values are returned. Otherwise ``inputs`` itself is
    returned as the decoder part and the encoder part is ``None``.
    """
    has_both_parts = "encoder" in inputs and "decoder" in inputs
    if not has_both_parts:
        return None, inputs

    # NOTE: This passes pyright but not mypy
    encoder_inputs = inputs["encoder"]  # type: ignore[typeddict-item]
    decoder_inputs = inputs["decoder"]  # type: ignore[typeddict-item]
    return encoder_inputs, decoder_inputs
|
|
|
|
|
|
class PromptComponents(NamedTuple):
    """The text, token-id and embedding parts extracted from a prompt.

    Every field defaults to ``None``, meaning that part was not provided.
    """

    # Prompt text: the prompt itself when it is a ``str``, otherwise the
    # value stored under the ``"prompt"`` key.
    text: str | None = None
    # Value stored under the ``"prompt_token_ids"`` key.
    token_ids: list[int] | None = None
    # Value stored under the ``"prompt_embeds"`` key.
    embeds: "torch.Tensor | None" = None


def get_prompt_components(prompt: PromptType) -> PromptComponents:
    """Extract the text / token ids / embeddings carried by ``prompt``.

    * A ``str`` prompt yields ``PromptComponents(text=prompt)``.
    * A dict whose ``"encoder_prompt"`` value is not ``None`` is resolved by
      recursing into that encoder prompt.
    * Any other dict is read through its ``"prompt"``,
      ``"prompt_token_ids"`` and ``"prompt_embeds"`` keys; missing keys
      become ``None``.

    Returns:
        The components found; fields that are absent are ``None``.
    """
    if isinstance(prompt, str):
        return PromptComponents(text=prompt)

    # Compare against None rather than relying on truthiness: an empty-string
    # encoder prompt is still an encoder prompt and must yield ``text=""``
    # instead of falling through to the plain-dict lookup below.
    if (encoder_prompt := prompt.get("encoder_prompt")) is not None:
        return get_prompt_components(encoder_prompt)  # type: ignore[arg-type]

    return PromptComponents(
        text=prompt.get("prompt"),  # type: ignore[arg-type]
        token_ids=prompt.get("prompt_token_ids"),  # type: ignore[arg-type]
        embeds=prompt.get("prompt_embeds"),
    )
|