[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
151
vllm/inputs/parse.py
Normal file
151
vllm/inputs/parse.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, Optional, TypedDict, Union, cast, overload
|
||||
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs,
|
||||
PromptType, SingletonInputs, SingletonPrompt, TextPrompt,
|
||||
TokensPrompt)
|
||||
|
||||
|
||||
class ParsedText(TypedDict):
|
||||
content: str
|
||||
is_tokens: Literal[False]
|
||||
|
||||
|
||||
class ParsedTokens(TypedDict):
|
||||
content: list[int]
|
||||
is_tokens: Literal[True]
|
||||
|
||||
|
||||
@overload
|
||||
def parse_and_batch_prompt(
|
||||
prompt: Union[str, list[str]], ) -> Sequence[ParsedText]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def parse_and_batch_prompt(
|
||||
prompt: Union[list[int], list[list[int]]], ) -> Sequence[ParsedTokens]:
|
||||
...
|
||||
|
||||
|
||||
def parse_and_batch_prompt(
|
||||
prompt: Union[str, list[str], list[int], list[list[int]]],
|
||||
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
|
||||
if isinstance(prompt, str):
|
||||
# case 1: a string
|
||||
return [ParsedText(content=prompt, is_tokens=False)]
|
||||
|
||||
if isinstance(prompt, list):
|
||||
if len(prompt) == 0:
|
||||
raise ValueError("please provide at least one prompt")
|
||||
|
||||
if is_list_of(prompt, str):
|
||||
# case 2: array of strings
|
||||
prompt = cast(list[str], prompt)
|
||||
return [
|
||||
ParsedText(content=elem, is_tokens=False) for elem in prompt
|
||||
]
|
||||
if is_list_of(prompt, int):
|
||||
# case 3: array of tokens
|
||||
prompt = cast(list[int], prompt)
|
||||
return [ParsedTokens(content=prompt, is_tokens=True)]
|
||||
if is_list_of(prompt, list):
|
||||
prompt = cast(list[list[int]], prompt)
|
||||
if len(prompt[0]) == 0:
|
||||
raise ValueError("please provide at least one prompt")
|
||||
|
||||
if is_list_of(prompt[0], int):
|
||||
# case 4: array of token arrays
|
||||
return [
|
||||
ParsedTokens(content=elem, is_tokens=True)
|
||||
for elem in prompt
|
||||
]
|
||||
|
||||
raise TypeError("prompt must be a string, array of strings, "
|
||||
"array of tokens, or array of token arrays")
|
||||
|
||||
|
||||
class ParsedStrPrompt(TypedDict):
|
||||
type: Literal["str"]
|
||||
content: str
|
||||
|
||||
|
||||
class ParsedTextPrompt(TypedDict):
|
||||
type: Literal["text"]
|
||||
content: TextPrompt
|
||||
|
||||
|
||||
class ParsedTokensPrompt(TypedDict):
|
||||
type: Literal["tokens"]
|
||||
content: TokensPrompt
|
||||
|
||||
|
||||
class ParsedEmbedsPrompt(TypedDict):
|
||||
type: Literal["embeds"]
|
||||
content: EmbedsPrompt
|
||||
|
||||
|
||||
ParsedSingletonPrompt = Union[ParsedStrPrompt, ParsedTextPrompt,
|
||||
ParsedTokensPrompt, ParsedEmbedsPrompt]
|
||||
|
||||
|
||||
@overload
|
||||
def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
|
||||
...
|
||||
|
||||
|
||||
def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
|
||||
if isinstance(prompt, str):
|
||||
return ParsedStrPrompt(type="str", content=prompt)
|
||||
elif isinstance(prompt, dict):
|
||||
# Type ignores are because mypy does not correctly infer the TypedDicts
|
||||
# Pyright does succeed.
|
||||
if "prompt_embeds" in prompt:
|
||||
return ParsedEmbedsPrompt(
|
||||
type="embeds", content=prompt) # type: ignore[typeddict-item]
|
||||
elif "prompt_token_ids" in prompt:
|
||||
return ParsedTokensPrompt(
|
||||
type="tokens", content=prompt) # type: ignore[typeddict-item]
|
||||
elif "prompt" in prompt:
|
||||
return ParsedTextPrompt(type="text", content=prompt)
|
||||
raise TypeError(
|
||||
"inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt")
|
||||
|
||||
|
||||
def is_explicit_encoder_decoder_prompt(
|
||||
prompt: PromptType, ) -> TypeIs[ExplicitEncoderDecoderPrompt]:
|
||||
return isinstance(prompt, dict) and "encoder_prompt" in prompt
|
||||
|
||||
|
||||
def split_enc_dec_inputs(
|
||||
inputs: ProcessorInputs,
|
||||
) -> tuple[Optional[SingletonInputs], SingletonInputs]:
|
||||
if "encoder" in inputs and "decoder" in inputs:
|
||||
# NOTE: This passes pyright but not mypy
|
||||
return (
|
||||
inputs["encoder"], # type: ignore[typeddict-item]
|
||||
inputs["decoder"], # type: ignore[typeddict-item]
|
||||
)
|
||||
|
||||
return None, inputs
|
||||
Reference in New Issue
Block a user