import asyncio
import io
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Annotated, Optional, Union

import pybase64
import torch
from pydantic import Field

from vllm.config import ModelConfig
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AsyncMicrobatchTokenizer


class BaseRenderer(ABC):
    """
    Base class for unified input processing and rendering.

    The Renderer serves as a unified input processor that consolidates
    tokenization, chat template formatting, and multimodal input handling
    into a single component. It converts high-level API requests
    (OpenAI-style JSON) into token IDs and multimodal features ready for
    engine consumption.

    Key responsibilities:
    - Convert text prompts to token sequences with proper special tokens
    - Apply chat templates and format conversations
    - Handle multimodal inputs (images, audio, etc.) when applicable
    - Manage prompt truncation and length validation
    - Provide clean separation between the API layer and the engine core
    """

    @classmethod
    def load_prompt_embeds(
        cls,
        prompt_embeds: Union[bytes, list[bytes]],
        deepstack_input_embeds: Optional[Union[bytes, str]] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=0)]] = None,
        cache_salt: Optional[str] = None,
    ) -> list[EngineEmbedsPrompt]:
        """Load and validate base64-encoded embeddings into prompt objects."""

        def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt:
            tensor = torch.load(
                io.BytesIO(pybase64.b64decode(embed, validate=True)),
                weights_only=True,
                map_location=torch.device("cpu"),
            )
            assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
                torch.float32,
                torch.bfloat16,
                torch.float16,
            )
            tensor = tensor.to_dense()
            if tensor.dim() > 2:
                tensor = tensor.squeeze(0)
            assert tensor.dim() == 2
            if truncate_prompt_tokens is not None:
                tensor = tensor[-truncate_prompt_tokens:]
            embeds_prompt = EngineEmbedsPrompt(prompt_embeds=tensor)
            if cache_salt is not None:
                embeds_prompt["cache_salt"] = cache_salt

            if deepstack_input_embeds is not None:
                # Decode the auxiliary payload the same way as the primary
                # embeddings; restrict torch.load to plain tensors since the
                # payload comes from an untrusted request.
                tensor_dict = torch.load(
                    io.BytesIO(
                        pybase64.b64decode(deepstack_input_embeds,
                                           validate=True)),
                    weights_only=True,
                    map_location=torch.device("cpu"),
                )
                # Stack the per-layer tensors along a new leading dimension.
                all_tensor = torch.cat(
                    [tensor_dict[k].unsqueeze(0) for k in tensor_dict], dim=0)
                embeds_prompt["deepstack_input_embeds"] = all_tensor
                # Alternative representation:
                # IntermediateTensors(tensors=tensor_dict)

            return embeds_prompt

        if isinstance(prompt_embeds, list):
            return [_load_and_validate_embed(embed) for embed in prompt_embeds]

        return [_load_and_validate_embed(prompt_embeds)]
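

# Illustrative helper, not part of vLLM's API: round-trips a tensor through
# the base64 encoding that `load_prompt_embeds` expects. The function name,
# shape, and dtype below are arbitrary placeholders for demonstration.
def _demo_prompt_embeds_roundtrip() -> None:
    buf = io.BytesIO()
    # Clients send a 2-D (seq_len, hidden_size) tensor, serialized with
    # torch.save and then base64-encoded.
    torch.save(torch.randn(8, 4096, dtype=torch.float16), buf)
    payload = pybase64.b64encode(buf.getvalue())
    prompts = BaseRenderer.load_prompt_embeds(payload, cache_salt="demo")
    assert prompts[0]["prompt_embeds"].shape == (8, 4096)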


class CompletionRenderer(BaseRenderer):

    async def render_prompt_and_embeds(
        self,
        *,
        prompt_or_prompts: Optional[Union[str, list[str], list[int],
                                          list[list[int]]]] = None,
        prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
        deepstack_input_embeds: Optional[Union[bytes, str]] = None,
        config: "RenderConfig",
    ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
        """
        Render text/token prompts and/or precomputed embedding prompts. At
        least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
        """
        truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
            config.truncate_prompt_tokens, config.max_length)
        if truncate_prompt_tokens == 0:
            return []

        rendered: list[Union[EngineTokensPrompt, EngineEmbedsPrompt]] = []

        if prompt_embeds is not None:
            rendered.extend(
                self.load_prompt_embeds(prompt_embeds, deepstack_input_embeds,
                                        truncate_prompt_tokens,
                                        config.cache_salt))
        if prompt_or_prompts is None or prompt_or_prompts == "":
            return rendered

        token_prompts = await self.render_prompt(
            prompt_or_prompts=prompt_or_prompts,
            config=config,
        )
        rendered.extend(token_prompts)

        return rendered
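

# Usage sketch (not runnable from this file alone): assumes a `RenderConfig`
# with the `truncate_prompt_tokens`, `max_length`, and `cache_salt` fields
# referenced above, plus whatever constructor arguments CompletionRenderer
# takes; its exact signature is outside this excerpt.
#
#     renderer = CompletionRenderer(...)
#     prompts = await renderer.render_prompt_and_embeds(
#         prompt_or_prompts="Hello, world",
#         prompt_embeds=payload,  # optional base64 tensor payload
#         config=RenderConfig(truncate_prompt_tokens=None,
#                             max_length=8192,
#                             cache_salt=None),
#     )
#     # `prompts` lists EngineEmbedsPrompt entries (from embeds) followed by
#     # EngineTokensPrompt entries (from text), ready for the engine.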