Sync from v0.13
This commit is contained in:
146
vllm/inputs/parse.py
Normal file
146
vllm/inputs/parse.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict, cast
|
||||
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
|
||||
from .data import (
|
||||
EmbedsPrompt,
|
||||
ExplicitEncoderDecoderPrompt,
|
||||
ProcessorInputs,
|
||||
PromptType,
|
||||
SingletonInputs,
|
||||
SingletonPrompt,
|
||||
TextPrompt,
|
||||
TokensPrompt,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
|
||||
def parse_raw_prompts(
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
) -> Sequence[TextPrompt] | Sequence[TokensPrompt]:
|
||||
if isinstance(prompt, str):
|
||||
# case 1: a string
|
||||
return [TextPrompt(prompt=prompt)]
|
||||
|
||||
if isinstance(prompt, list):
|
||||
if len(prompt) == 0:
|
||||
raise ValueError("please provide at least one prompt")
|
||||
|
||||
# case 2: array of strings
|
||||
if is_list_of(prompt, str):
|
||||
prompt = cast(list[str], prompt)
|
||||
return [TextPrompt(prompt=elem) for elem in prompt]
|
||||
|
||||
# case 3: array of tokens
|
||||
if is_list_of(prompt, int):
|
||||
prompt = cast(list[int], prompt)
|
||||
return [TokensPrompt(prompt_token_ids=prompt)]
|
||||
|
||||
# case 4: array of token arrays
|
||||
if is_list_of(prompt, list):
|
||||
first = prompt[0]
|
||||
if not isinstance(first, list):
|
||||
raise ValueError("prompt expected to be a list of lists")
|
||||
|
||||
if len(first) == 0:
|
||||
raise ValueError("Please provide at least one prompt")
|
||||
|
||||
# strict validation: every nested list must be list[int]
|
||||
if not all(is_list_of(elem, int) for elem in prompt):
|
||||
raise TypeError("Nested lists must contain only integers")
|
||||
|
||||
prompt = cast(list[list[int]], prompt)
|
||||
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
|
||||
|
||||
raise TypeError(
|
||||
"prompt must be a string, array of strings, "
|
||||
"array of tokens, or array of token arrays"
|
||||
)
|
||||
|
||||
|
||||
class ParsedStrPrompt(TypedDict):
|
||||
type: Literal["str"]
|
||||
content: str
|
||||
|
||||
|
||||
class ParsedTextPrompt(TypedDict):
|
||||
type: Literal["text"]
|
||||
content: TextPrompt
|
||||
|
||||
|
||||
class ParsedTokensPrompt(TypedDict):
|
||||
type: Literal["tokens"]
|
||||
content: TokensPrompt
|
||||
|
||||
|
||||
class ParsedEmbedsPrompt(TypedDict):
|
||||
type: Literal["embeds"]
|
||||
content: EmbedsPrompt
|
||||
|
||||
|
||||
ParsedSingletonPrompt: TypeAlias = (
|
||||
ParsedStrPrompt | ParsedTextPrompt | ParsedTokensPrompt | ParsedEmbedsPrompt
|
||||
)
|
||||
|
||||
|
||||
def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
|
||||
if isinstance(prompt, str):
|
||||
return ParsedStrPrompt(type="str", content=prompt)
|
||||
elif isinstance(prompt, dict):
|
||||
# Type ignores are because mypy does not correctly infer the TypedDicts
|
||||
# Pyright does succeed.
|
||||
if "prompt_embeds" in prompt:
|
||||
return ParsedEmbedsPrompt(type="embeds", content=prompt) # type: ignore[typeddict-item]
|
||||
elif "prompt_token_ids" in prompt:
|
||||
return ParsedTokensPrompt(type="tokens", content=prompt) # type: ignore[typeddict-item]
|
||||
elif "prompt" in prompt:
|
||||
return ParsedTextPrompt(type="text", content=prompt)
|
||||
raise TypeError(
|
||||
"inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt"
|
||||
)
|
||||
|
||||
|
||||
def is_explicit_encoder_decoder_prompt(
|
||||
prompt: PromptType,
|
||||
) -> TypeIs[ExplicitEncoderDecoderPrompt]:
|
||||
return isinstance(prompt, dict) and "encoder_prompt" in prompt
|
||||
|
||||
|
||||
def split_enc_dec_inputs(
|
||||
inputs: ProcessorInputs,
|
||||
) -> tuple[SingletonInputs | None, SingletonInputs]:
|
||||
if "encoder" in inputs and "decoder" in inputs:
|
||||
# NOTE: This passes pyright but not mypy
|
||||
return (
|
||||
inputs["encoder"], # type: ignore[typeddict-item]
|
||||
inputs["decoder"], # type: ignore[typeddict-item]
|
||||
)
|
||||
|
||||
return None, inputs
|
||||
|
||||
|
||||
class PromptComponents(NamedTuple):
|
||||
text: str | None = None
|
||||
token_ids: list[int] | None = None
|
||||
embeds: "torch.Tensor | None" = None
|
||||
|
||||
|
||||
def get_prompt_components(prompt: PromptType) -> PromptComponents:
|
||||
if isinstance(prompt, str):
|
||||
return PromptComponents(text=prompt)
|
||||
|
||||
if encoder_prompt := prompt.get("encoder_prompt"):
|
||||
return get_prompt_components(encoder_prompt) # type: ignore[arg-type]
|
||||
|
||||
return PromptComponents(
|
||||
text=prompt.get("prompt"), # type: ignore[arg-type]
|
||||
token_ids=prompt.get("prompt_token_ids"), # type: ignore[arg-type]
|
||||
embeds=prompt.get("prompt_embeds"),
|
||||
)
|
||||
Reference in New Issue
Block a user