This commit is contained in: root
2026-04-09 11:23:47 +08:00
commit 72387e4fa8 (parent 8082d5f4b2)
1885 changed files with 611521 additions and 1 deletions


@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .base import BaseRenderer
from .params import ChatParams, TokenizeParams, merge_kwargs
from .registry import RendererRegistry, renderer_from_config
__all__ = [
"BaseRenderer",
"RendererRegistry",
"renderer_from_config",
"ChatParams",
"TokenizeParams",
"merge_kwargs",
]
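
A hedged sketch of how these exports are meant to be used together; the exact `renderer_from_config` signature and the surrounding engine setup are assumptions, since neither appears in this diff:

from vllm.renderers import renderer_from_config

# Hypothetical wiring: build a renderer from an already-constructed VllmConfig,
# then turn raw prompts into engine inputs (see BaseRenderer.render_cmpl below).
renderer = renderer_from_config(vllm_config)  # assumed signature
engine_inputs = renderer.render_cmpl([{"prompt": "Hello, world!"}])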

vllm/renderers/base.py

@@ -0,0 +1,767 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, overload
from typing_extensions import TypeVar
from vllm.inputs import (
EmbedsInputs,
EmbedsPrompt,
EncoderDecoderInputs,
ProcessorInputs,
SingletonInputs,
TextPrompt,
TokenInputs,
TokensPrompt,
)
from vllm.inputs.data import build_enc_dec_inputs, embeds_inputs, token_inputs
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
from vllm.utils.counter import AtomicCounter
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.metrics.stats import MultiModalCacheStats
from .embed_utils import safe_load_prompt_embeds
from .inputs import (
DictPrompt,
EncoderDecoderDictPrompt,
EncoderDecoderTokPrompt,
SingletonDictPrompt,
SingletonTokPrompt,
TokPrompt,
)
from .inputs.preprocess import extract_target_prompt
from .params import ChatParams, TokenizeParams
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ConversationMessage,
)
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalInputs,
MultiModalUUIDDict,
)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalUUIDItems
from vllm.multimodal.processing import BaseMultiModalProcessor
logger = init_logger(__name__)
_T = TypeVar("_T", bound=TokenizerLike, default=TokenizerLike)
class BaseRenderer(ABC, Generic[_T]):
@classmethod
@abstractmethod
def from_config(
cls,
config: "VllmConfig",
tokenizer_kwargs: dict[str, Any],
) -> "BaseRenderer":
raise NotImplementedError
def __init__(self, config: "VllmConfig", tokenizer: _T | None) -> None:
super().__init__()
self.config = config
self.model_config = config.model_config
self.tokenizer = tokenizer
# Lazy initialization since offline LLM doesn't use async
self._async_tokenizer: AsyncMicrobatchTokenizer | None = None
self.mm_processor: BaseMultiModalProcessor | None = None
self._mm_cache_stats: MultiModalCacheStats | None = None
if config.model_config.is_multimodal_model:
from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry
from vllm.multimodal.registry import MultiModalTimingRegistry
mm_processor_cache = mm_registry.processor_cache_from_config(config)
with set_default_torch_num_threads():
self.mm_processor = mm_registry.create_processor(
config.model_config,
tokenizer=tokenizer,
cache=mm_processor_cache,
)
if mm_processor_cache:
self._mm_cache_stats = MultiModalCacheStats()
# This is used to generate internal request ID for MM processing
# It has no relation to the request ID for engine core
self._mm_req_counter = AtomicCounter()
self._mm_timing_registry = MultiModalTimingRegistry(
config.observability_config
)
def get_tokenizer(self) -> _T:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
return tokenizer
def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
if self._async_tokenizer is None:
self._async_tokenizer = AsyncMicrobatchTokenizer(self.get_tokenizer())
return self._async_tokenizer
def get_mm_processor(self) -> "BaseMultiModalProcessor":
if self.mm_processor is None:
raise ValueError("Multi-modal processor not available for text-only models")
return self.mm_processor
@property
def mm_processor_cache(self) -> "BaseMultiModalProcessorCache | None":
if self.mm_processor is None:
return None
return self.mm_processor.cache
def stat_mm_cache(self) -> MultiModalCacheStats | None:
mm_cache_stats = self._mm_cache_stats
if mm_cache_stats is None:
return None
self._mm_cache_stats = MultiModalCacheStats()
return mm_cache_stats
def update_mm_cache_stats(self) -> None:
mm_processor_cache = self.mm_processor_cache
mm_cache_stats = self._mm_cache_stats
if mm_processor_cache and mm_cache_stats:
delta = mm_processor_cache.make_stats(delta=True)
mm_cache_stats.record(delta.total, delta.hits)
def clear_mm_cache(self) -> None:
mm_processor_cache = self.mm_processor_cache
if mm_processor_cache is not None:
mm_processor_cache.clear_cache()
if self._mm_cache_stats is not None:
self._mm_cache_stats.reset = True
def shutdown(self) -> None:
mm_processor_cache = self.mm_processor_cache
if mm_processor_cache is not None:
mm_processor_cache.close()
def get_bos_token_id(self) -> int | None:
if self.tokenizer is None:
logger.warning_once(
"Using None for BOS token id because tokenizer is not initialized"
)
return None
return self.tokenizer.bos_token_id
def get_eos_token_id(self) -> int | None:
if self.tokenizer is None:
logger.warning_once(
"Using None for EOS token id because tokenizer is not initialized"
)
return None
return self.tokenizer.eos_token_id
def get_dec_start_token_id(self) -> int:
"""
Obtain the decoder start token id employed by an encoder/decoder model,
raising an error if it is not available.
"""
dec_start_token_id = getattr(
self.model_config.hf_config, "decoder_start_token_id", None
)
if dec_start_token_id is None:
logger.warning_once(
"Falling back on <BOS> for decoder start token id "
"because decoder start token id is not available."
)
dec_start_token_id = self.get_bos_token_id()
if dec_start_token_id is None:
raise RuntimeError("Cannot find decoder start token id or <BOS>")
return dec_start_token_id
@cached_property
def default_cmpl_tok_params(self) -> TokenizeParams:
mm_processor = self.mm_processor
if mm_processor is not None:
return mm_processor.info.default_tok_params
model_config = self.model_config
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=True,
)
@cached_property
def default_chat_tok_params(self) -> TokenizeParams:
mm_processor = self.mm_processor
if mm_processor is not None:
return mm_processor.info.default_tok_params
model_config = self.model_config
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=False,
)
# Step 1: Convert raw inputs to prompts
def render_prompt(
self,
prompt: DictPrompt | bytes,
) -> DictPrompt:
if isinstance(prompt, bytes):
embeds = safe_load_prompt_embeds(self.model_config, prompt)
prompt = EmbedsPrompt(prompt_embeds=embeds)
return prompt
def render_prompts(
self,
prompts: Sequence[DictPrompt | bytes],
) -> list[DictPrompt]:
if len(prompts) == 0:
raise ValueError("You must pass at least one prompt")
return [self.render_prompt(prompt) for prompt in prompts]
async def render_prompts_async(
self,
prompts: Sequence[DictPrompt | bytes],
) -> list[DictPrompt]:
return self.render_prompts(prompts)
@abstractmethod
def render_messages(
self,
messages: list["ChatCompletionMessageParam"],
params: ChatParams,
) -> tuple[list["ConversationMessage"], DictPrompt]:
raise NotImplementedError
async def render_messages_async(
self,
messages: list["ChatCompletionMessageParam"],
params: ChatParams,
) -> tuple[list["ConversationMessage"], DictPrompt]:
return self.render_messages(messages, params)
# Step 2: Tokenize prompts if necessary
def _tokenize_prompt(
self,
prompt: TextPrompt,
params: TokenizeParams,
) -> TokensPrompt:
tokenizer = self.get_tokenizer()
prompt_token_ids = tokenizer.encode(
prompt["prompt"],
**params.get_encode_kwargs(),
)
return TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)
async def _tokenize_prompt_async(
self,
prompt: TextPrompt,
params: TokenizeParams,
) -> TokensPrompt:
tokenizer = self.get_async_tokenizer()
prompt_token_ids = await tokenizer.encode(
prompt["prompt"],
**params.get_encode_kwargs(),
)
return TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)
def _detokenize_prompt(self, prompt: TokensPrompt) -> TokensPrompt:
tokenizer = self.get_tokenizer()
prompt["prompt"] = tokenizer.decode(prompt["prompt_token_ids"])
return prompt
async def _detokenize_prompt_async(self, prompt: TokensPrompt) -> TokensPrompt:
tokenizer = self.get_async_tokenizer()
prompt["prompt"] = await tokenizer.decode(prompt["prompt_token_ids"])
return prompt
@overload
def _tokenize_singleton_prompt(
self,
prompt: TextPrompt | TokensPrompt,
params: TokenizeParams,
) -> TokensPrompt: ...
@overload
def _tokenize_singleton_prompt( # type: ignore[misc]
self,
prompt: EmbedsPrompt,
params: TokenizeParams,
) -> EmbedsPrompt: ...
def _tokenize_singleton_prompt(
self,
prompt: SingletonDictPrompt,
params: TokenizeParams,
) -> SingletonTokPrompt:
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
prompt = self._tokenize_prompt(prompt, params)
if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings")
prompt = self._detokenize_prompt(prompt) # type: ignore[arg-type]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
@overload
async def _tokenize_singleton_prompt_async(
self,
prompt: TextPrompt | TokensPrompt,
params: TokenizeParams,
) -> TokensPrompt: ...
@overload
async def _tokenize_singleton_prompt_async( # type: ignore[misc]
self,
prompt: EmbedsPrompt,
params: TokenizeParams,
) -> EmbedsPrompt: ...
async def _tokenize_singleton_prompt_async(
self,
prompt: SingletonDictPrompt,
params: TokenizeParams,
) -> SingletonTokPrompt:
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
prompt = await self._tokenize_prompt_async(prompt, params)
if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings")
prompt = await self._detokenize_prompt_async(prompt) # type: ignore[arg-type]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
def _tokenize_enc_dec_prompt(
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt:
enc_prompt, dec_prompt = (
self._tokenize_singleton_prompt(prompt["encoder_prompt"], params),
(
None
if prompt["decoder_prompt"] is None
else self._tokenize_singleton_prompt(prompt["decoder_prompt"], params)
),
)
return EncoderDecoderTokPrompt(
encoder_prompt=enc_prompt,
decoder_prompt=dec_prompt,
)
async def _tokenize_enc_dec_prompt_async(
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt:
enc_prompt, dec_prompt = await asyncio.gather(
self._tokenize_singleton_prompt_async(prompt["encoder_prompt"], params),
(
asyncio.sleep(0)
if prompt["decoder_prompt"] is None
else self._tokenize_singleton_prompt_async(
prompt["decoder_prompt"], params
)
),
)
return EncoderDecoderTokPrompt(
encoder_prompt=enc_prompt,
decoder_prompt=dec_prompt,
)
def tokenize_prompt(
self,
prompt: DictPrompt,
params: TokenizeParams,
) -> TokPrompt:
if "encoder_prompt" in prompt:
return self._tokenize_enc_dec_prompt(prompt, params) # type: ignore[arg-type]
return self._tokenize_singleton_prompt(prompt, params)
def tokenize_prompts(
self,
prompts: Sequence[DictPrompt],
params: TokenizeParams,
) -> list[TokPrompt]:
return [self.tokenize_prompt(prompt, params) for prompt in prompts]
async def tokenize_prompt_async(
self,
prompt: DictPrompt,
params: TokenizeParams,
) -> TokPrompt:
if "encoder_prompt" in prompt:
return await self._tokenize_enc_dec_prompt_async(prompt, params) # type: ignore[arg-type]
return await self._tokenize_singleton_prompt_async(prompt, params)
async def tokenize_prompts_async(
self,
prompts: Sequence[DictPrompt],
params: TokenizeParams,
) -> list[TokPrompt]:
return await asyncio.gather(
*(self.tokenize_prompt_async(prompt, params) for prompt in prompts)
)
# Step 3: Add extra keys to the prompts
def _apply_prompt_extras(
self,
prompts: Sequence[TokPrompt],
prompt_extras: dict[str, Any] | None,
):
if not prompt_extras:
return
for prompt in prompts:
target_prompt = extract_target_prompt(self.model_config, prompt)
target_prompt.update(prompt_extras) # type: ignore[arg-type]
# Step 4: Convert to engine inputs
def _validate_mm_uuids(
self,
mm_data: "MultiModalDataDict",
mm_data_items: "MultiModalDataItems",
mm_uuid_items: "MultiModalUUIDItems",
) -> None:
# NOTE: Keys corresponding to `None` in `mm_data` don't appear in
# `mm_data_items`
modalities = mm_data.keys() | mm_uuid_items.keys()
for modality in modalities:
data_items = mm_data_items.get(modality)
uuid_items = mm_uuid_items.get(modality)
if data_items is None:
if uuid_items is None:
raise ValueError(
f"multi_modal_data[{modality!r}] is empty but "
f"multi_modal_uuids[{modality!r}] is missing."
)
elif uuid_items is not None:
if len(data_items) != len(uuid_items):
raise ValueError(
f"If given, multi_modal_uuids[{modality!r}] must have "
f"same length as multi_modal_data[{modality!r}], but "
f"got {len(uuid_items)} vs {len(data_items)}."
)
for i, item in enumerate(data_items):
if item is None and uuid_items[i] is None:
raise ValueError(
f"multi_modal_data[{modality!r}][{i}] is empty but "
f"multi_modal_uuids[{modality!r}][{i}] is missing."
)
def _process_mm_uuids(
self,
mm_data: "MultiModalDataDict",
mm_data_items: "MultiModalDataItems",
mm_uuid_items: "MultiModalUUIDItems",
mm_req_id: str,
):
model_config = self.model_config
# NOTE: When users explicitly turn off BOTH prefix caching and input
# processing caching, no multimodal features or embeddings will be
# reused across requests, therefore identifying multimodal data items
# by their content is no longer necessary, and we create uuids with
# `<mm_req_id>-<modality>-<index>`, overriding even user-provided ones.
if (
model_config.multimodal_config
and model_config.multimodal_config.mm_processor_cache_gb == 0
and not self.config.cache_config.enable_prefix_caching
):
mm_uuid_items = {
modality: [f"{mm_req_id}-{modality}-{i}" for i in range(data_count)]
for modality, data_count in mm_data_items.get_all_counts().items()
}
self._validate_mm_uuids(mm_data, mm_data_items, mm_uuid_items)
return mm_uuid_items
# TODO: Remove str and tokenization_kwargs after deprecating InputPreprocessor
def _process_multimodal(
self,
prompt: list[int] | str,
mm_data: "MultiModalDataDict",
mm_uuids: "MultiModalUUIDDict | None",
mm_processor_kwargs: Mapping[str, object] | None,
tokenization_kwargs: dict[str, Any] | None,
) -> "MultiModalInputs":
from vllm.multimodal.parse import parse_mm_uuids
from vllm.multimodal.processing import ProcessorInputs as MMProcessorInputs
mm_req_id = f"renderer-mm-{self._mm_req_counter.inc(1)}"
mm_processor = self.get_mm_processor()
mm_data_items = mm_processor.info.parse_mm_data(mm_data)
mm_uuid_items = parse_mm_uuids(mm_uuids)
mm_uuid_items = self._process_mm_uuids(
mm_data, mm_data_items, mm_uuid_items, mm_req_id
)
mm_processor_inputs = MMProcessorInputs(
prompt,
mm_data_items,
mm_uuid_items,
hf_processor_mm_kwargs=mm_processor_kwargs or {},
tokenization_kwargs=tokenization_kwargs or {},
)
mm_timing_ctx = self._mm_timing_registry.get(mm_req_id)
with set_default_torch_num_threads():
mm_inputs = mm_processor.apply(mm_processor_inputs, mm_timing_ctx)
self.update_mm_cache_stats()
return mm_inputs
def _process_tokens(
self,
prompt: TokensPrompt,
) -> "TokenInputs | MultiModalInputs":
prompt_token_ids = prompt["prompt_token_ids"]
inputs: TokenInputs | MultiModalInputs
if multi_modal_data := prompt.get("multi_modal_data"):
inputs = self._process_multimodal(
prompt_token_ids,
multi_modal_data,
mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
tokenization_kwargs=None, # Tokenization already done in Step 2
mm_uuids=prompt.get("multi_modal_uuids"),
)
else:
inputs = token_inputs(prompt_token_ids)
if prompt_text := prompt.get("prompt"):
inputs["prompt"] = prompt_text
if cache_salt := prompt.get("cache_salt"):
inputs["cache_salt"] = cache_salt
return inputs
def _process_embeds(
self,
prompt: EmbedsPrompt,
) -> EmbedsInputs:
if not self.model_config.enable_prompt_embeds:
raise ValueError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`."
)
prompt_embeds = prompt["prompt_embeds"]
# prompt_embeds must be (seq_len, hidden_size), but if the user
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
# we can unambiguously process the intent by squeezing the batch
# dimension.
if prompt_embeds.ndim == 3:
prompt_embeds = prompt_embeds.squeeze(dim=0)
if prompt_embeds.ndim != 2:
raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")
# Tensors must be on CPU for serialization between processes
# in the MsgpackEncoder. Casting to CPU here ensures that there is no
# hidden device transfer in the critical path of generation.
prompt_embeds = prompt_embeds.cpu()
return embeds_inputs(
prompt_embeds=prompt_embeds,
cache_salt=prompt.get("cache_salt"),
)
def _process_singleton(
self,
prompt: SingletonTokPrompt,
) -> SingletonInputs:
if "prompt_embeds" in prompt:
return self._process_embeds(prompt) # type: ignore[arg-type]
return self._process_tokens(prompt) # type: ignore[arg-type]
def _process_enc_dec(
self,
prompt: EncoderDecoderTokPrompt,
) -> EncoderDecoderInputs:
enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"]
return build_enc_dec_inputs(
encoder_inputs=self._process_singleton(enc_prompt),
decoder_inputs=(
None if dec_prompt is None else self._process_singleton(dec_prompt)
),
decoder_start_token_id=self.get_dec_start_token_id(),
)
def process_for_engine(
self, prompt: TokPrompt, arrival_time: float
) -> ProcessorInputs:
engine_prompt: ProcessorInputs
if "encoder_prompt" in prompt:
engine_prompt = self._process_enc_dec(prompt) # type: ignore[arg-type]
else:
engine_prompt = self._process_singleton(prompt)
engine_prompt["arrival_time"] = arrival_time
return engine_prompt
# Top-level methods
def render_cmpl(
self,
prompts: Sequence[DictPrompt | bytes],
tok_params: TokenizeParams | None = None,
*,
prompt_extras: dict[str, Any] | None = None,
):
arrival_time = time.time()
if tok_params is None:
tok_params = self.default_cmpl_tok_params
dict_prompts = self.render_prompts(prompts)
tok_prompts = self.tokenize_prompts(dict_prompts, tok_params)
self._apply_prompt_extras(tok_prompts, prompt_extras)
return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]
async def render_cmpl_async(
self,
prompts: Sequence[DictPrompt | bytes],
tok_params: TokenizeParams | None = None,
*,
prompt_extras: dict[str, Any] | None = None,
):
arrival_time = time.time()
if tok_params is None:
tok_params = self.default_cmpl_tok_params
dict_prompts = await self.render_prompts_async(prompts)
tok_prompts = await self.tokenize_prompts_async(dict_prompts, tok_params)
self._apply_prompt_extras(tok_prompts, prompt_extras)
return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]
def render_chat(
self,
conversations: Sequence[list["ChatCompletionMessageParam"]],
chat_params: ChatParams,
tok_params: TokenizeParams | None = None,
*,
prompt_extras: dict[str, Any] | None = None,
):
arrival_time = time.time()
if tok_params is None:
tok_params = self.default_chat_tok_params
rendered = [
self.render_messages(conversation, chat_params)
for conversation in conversations
]
out_conversations = list[list["ConversationMessage"]]()
dict_prompts = list[DictPrompt]()
for conv, prompt in rendered:
out_conversations.append(conv)
dict_prompts.append(prompt)
tok_prompts = self.tokenize_prompts(dict_prompts, tok_params)
self._apply_prompt_extras(tok_prompts, prompt_extras)
eng_prompts = [
self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
]
return out_conversations, eng_prompts
async def render_chat_async(
self,
conversations: Sequence[list["ChatCompletionMessageParam"]],
chat_params: ChatParams,
tok_params: TokenizeParams | None = None,
*,
prompt_extras: dict[str, Any] | None = None,
):
arrival_time = time.time()
if tok_params is None:
tok_params = self.default_chat_tok_params
rendered = [
self.render_messages_async(conversation, chat_params)
for conversation in conversations
]
out_conversations = list[list["ConversationMessage"]]()
dict_prompts = list[DictPrompt]()
for conv, prompt in await asyncio.gather(*rendered):
out_conversations.append(conv)
dict_prompts.append(prompt)
tok_prompts = await self.tokenize_prompts_async(dict_prompts, tok_params)
self._apply_prompt_extras(tok_prompts, prompt_extras)
eng_prompts = [
self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
]
return out_conversations, eng_prompts
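
The class above is a four-step pipeline: render raw inputs into dict prompts, tokenize them, attach prompt extras, and convert them to engine inputs. A hedged usage sketch of the completion entry point, assuming a concrete renderer instance (construction not shown here):

# `cache_salt` is a TokensPrompt key that `_process_tokens` reads later;
# passing it via `prompt_extras` stamps it onto every prompt in the batch.
engine_inputs = renderer.render_cmpl(
    [{"prompt": "Hello"}, {"prompt_token_ids": [1, 2, 3]}],
    prompt_extras={"cache_salt": "my-salt"},
)
# Each returned ProcessorInputs dict carries the shared arrival_time captured
# at the start of the call.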


@@ -0,0 +1,92 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ConversationMessage,
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
from .base import BaseRenderer
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams
logger = init_logger(__name__)
class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
@classmethod
def from_config( # type: ignore[override]
cls,
config: VllmConfig,
tokenizer_kwargs: dict[str, Any],
) -> "DeepseekV32Renderer":
model_config = config.model_config
if model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = cached_get_tokenizer(
tokenizer_cls=DeepseekV32Tokenizer,
**tokenizer_kwargs,
)
return cls(config, tokenizer)
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
self.model_config,
content_format="string",
)
prompt_raw = tokenizer.apply_chat_template(
conversation=conversation,
messages=messages,
**params.get_apply_chat_template_kwargs(),
)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
self.model_config,
content_format="string",
)
prompt_raw = tokenizer.apply_chat_template(
conversation=conversation,
messages=messages,
**params.get_apply_chat_template_kwargs(),
)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt


@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from io import BytesIO
from typing import TYPE_CHECKING
import pybase64
import torch
from vllm.exceptions import VLLMValidationError
if TYPE_CHECKING:
from vllm.config import ModelConfig
def safe_load_prompt_embeds(
model_config: "ModelConfig",
embed: bytes,
) -> torch.Tensor:
if not model_config.enable_prompt_embeds:
raise VLLMValidationError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`.",
parameter="prompt_embeds",
)
# Enable sparse tensor integrity checks to prevent out-of-bounds
# writes from maliciously crafted tensors
with torch.sparse.check_sparse_tensor_invariants():
tensor = torch.load(
BytesIO(pybase64.b64decode(embed, validate=True)),
weights_only=True,
map_location=torch.device("cpu"),
)
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
torch.float32,
torch.bfloat16,
torch.float16,
)
tensor = tensor.to_dense()
if tensor.dim() > 2:
tensor = tensor.squeeze(0)
assert tensor.dim() == 2
return tensor
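
For reference, a hedged sketch of the client-side encoding that `safe_load_prompt_embeds` reverses: base64 over a `torch.save` payload holding a float `(seq_len, hidden_size)` tensor. The helper name is illustrative only:

from io import BytesIO

import pybase64
import torch

def encode_prompt_embeds(embeds: torch.Tensor) -> bytes:
    # Inverse of the loader above: serialize with torch.save, then
    # base64-encode so the payload can travel inside a JSON request body.
    # `embeds` should be float16/bfloat16/float32 with shape (seq_len, hidden).
    buf = BytesIO()
    torch.save(embeds, buf)
    return pybase64.b64encode(buf.getvalue())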

vllm/renderers/grok2.py

@@ -0,0 +1,92 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ConversationMessage,
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.grok2 import Grok2Tokenizer
from .base import BaseRenderer
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams
logger = init_logger(__name__)
class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
@classmethod
def from_config( # type: ignore[override]
cls,
config: VllmConfig,
tokenizer_kwargs: dict[str, Any],
) -> "Grok2Renderer":
model_config = config.model_config
if model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = cached_get_tokenizer(
tokenizer_cls=Grok2Tokenizer,
**tokenizer_kwargs,
)
return cls(config, tokenizer)
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
self.model_config,
content_format="string",
)
prompt_raw = tokenizer.apply_chat_template(
conversation=conversation,
messages=messages,
**params.get_apply_chat_template_kwargs(),
)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
self.model_config,
content_format="string",
)
prompt_raw = tokenizer.apply_chat_template(
conversation=conversation,
messages=messages,
**params.get_apply_chat_template_kwargs(),
)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt

vllm/renderers/hf.py

@@ -0,0 +1,724 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
import itertools
from collections import defaultdict, deque
from collections.abc import Set
from functools import lru_cache
from typing import TYPE_CHECKING, Any, cast
import jinja2
import jinja2.ext
import jinja2.meta
import jinja2.nodes
import jinja2.parser
import jinja2.sandbox
from vllm.config import ModelConfig, VllmConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormat,
ChatTemplateContentFormatOption,
ChatTemplateResolutionError,
ConversationMessage,
load_chat_template,
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.func_utils import supports_kw
from .base import BaseRenderer
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams
if TYPE_CHECKING:
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
else:
MultiModalDataDict = dict[str, Any]
MultiModalUUIDDict = dict[str, Any]
logger = init_logger(__name__)
_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], str | None]()
"""
Used in `_try_get_processor_chat_template` to avoid calling
`cached_get_processor` again if the processor fails to be loaded.
This is needed because `lru_cache` does not cache when an exception happens.
"""
def _try_get_processor_chat_template(
tokenizer: HfTokenizer,
*,
trust_remote_code: bool,
) -> str | None:
cache_key = (tokenizer.name_or_path, trust_remote_code)
if cache_key in _PROCESSOR_CHAT_TEMPLATES:
return _PROCESSOR_CHAT_TEMPLATES[cache_key]
from transformers import (
PreTrainedTokenizer,
PreTrainedTokenizerFast,
ProcessorMixin,
)
try:
processor = cached_get_processor(
tokenizer.name_or_path,
processor_cls=(
PreTrainedTokenizer,
PreTrainedTokenizerFast,
ProcessorMixin,
),
trust_remote_code=trust_remote_code,
)
if (
isinstance(processor, ProcessorMixin)
and hasattr(processor, "chat_template")
and (chat_template := processor.chat_template) is not None
):
_PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template
return chat_template
except Exception:
logger.debug(
"Failed to load AutoProcessor chat template for %s",
tokenizer.name_or_path,
exc_info=True,
)
_PROCESSOR_CHAT_TEMPLATES[cache_key] = None
return None
def resolve_chat_template(
tokenizer: HfTokenizer,
chat_template: str | None,
tools: list[dict[str, Any]] | None,
*,
model_config: "ModelConfig",
) -> str | None:
# 1st priority: The given chat template
if chat_template is not None:
return chat_template
# 2nd priority: AutoProcessor chat template, unless tool calling is enabled
if tools is None:
chat_template = _try_get_processor_chat_template(
tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
if chat_template is not None:
return chat_template
# 3rd priority: AutoTokenizer chat template
try:
return tokenizer.get_chat_template(chat_template, tools=tools)
except Exception:
logger.debug(
"Failed to load AutoTokenizer chat template for %s",
tokenizer.name_or_path,
exc_info=True,
)
# 4th priority: Predefined fallbacks
path = get_chat_template_fallback_path(
model_type=model_config.hf_config.model_type,
tokenizer_name_or_path=tokenizer.name_or_path,
)
if path is not None:
logger.info_once(
"Loading chat template fallback for %s as there isn't one "
"defined on HF Hub.",
tokenizer.name_or_path,
)
chat_template = load_chat_template(path)
else:
logger.debug_once(
"There is no chat template fallback for %s", tokenizer.name_or_path
)
return chat_template
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
if isinstance(node, jinja2.nodes.Name):
return node.ctx == "load" and node.name == varname
return False
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
if isinstance(node, jinja2.nodes.Getitem):
return (
_is_var_access(node.node, varname)
and isinstance(node.arg, jinja2.nodes.Const)
and node.arg.value == key
)
if isinstance(node, jinja2.nodes.Getattr):
return _is_var_access(node.node, varname) and node.attr == key
return False
def _is_var_or_elems_access(
node: jinja2.nodes.Node,
varname: str,
key: str | None = None,
) -> bool:
if isinstance(node, jinja2.nodes.Filter):
return node.node is not None and _is_var_or_elems_access(
node.node, varname, key
)
if isinstance(node, jinja2.nodes.Test):
return _is_var_or_elems_access(node.node, varname, key)
if isinstance(node, jinja2.nodes.Getitem) and isinstance(
node.arg, jinja2.nodes.Slice
):
return _is_var_or_elems_access(node.node, varname, key)
return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)
def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
# Global variable that is implicitly defined at the root
yield root, varname
# Iterative BFS
related_varnames = deque([varname])
while related_varnames:
related_varname = related_varnames.popleft()
for assign_ast in root.find_all(jinja2.nodes.Assign):
lhs = assign_ast.target
rhs = assign_ast.node
if _is_var_or_elems_access(rhs, related_varname):
assert isinstance(lhs, jinja2.nodes.Name)
yield assign_ast, lhs.name
# Avoid infinite looping for self-assignment
if lhs.name != related_varname:
related_varnames.append(lhs.name)
# NOTE: The proper way to handle this is to build a CFG so that we can handle
# the scope in which each variable is defined, but that is too complicated
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
messages_varnames = [
varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
]
# Search for {%- for message in messages -%} loops
for loop_ast in root.find_all(jinja2.nodes.For):
loop_iter = loop_ast.iter
loop_target = loop_ast.target
for varname in messages_varnames:
if _is_var_or_elems_access(loop_iter, varname):
assert isinstance(loop_target, jinja2.nodes.Name)
yield loop_ast, loop_target.name
break
def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
message_varnames = [
varname for _, varname in _iter_nodes_assign_messages_item(root)
]
# Search for {%- for content in message['content'] -%} loops
for loop_ast in root.find_all(jinja2.nodes.For):
loop_iter = loop_ast.iter
loop_target = loop_ast.target
for varname in message_varnames:
if _is_var_or_elems_access(loop_iter, varname, "content"):
assert isinstance(loop_target, jinja2.nodes.Name)
yield loop_ast, loop_target.name
break
def _try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None:
import transformers.utils.chat_template_utils as hf_chat_utils
try:
jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
return jinja_compiled.environment.parse(chat_template)
except Exception:
logger.exception("Error when compiling Jinja template")
return None
@lru_cache(maxsize=32)
def _detect_content_format(
chat_template: str,
*,
default: ChatTemplateContentFormat,
) -> ChatTemplateContentFormat:
jinja_ast = _try_extract_ast(chat_template)
if jinja_ast is None:
return default
try:
next(_iter_nodes_assign_content_item(jinja_ast))
except StopIteration:
return "string"
except Exception:
logger.exception("Error when parsing AST of Jinja template")
return default
else:
return "openai"
def _resolve_chat_template_content_format(
chat_template: str | None,
tools: list[dict[str, Any]] | None,
tokenizer: HfTokenizer,
*,
model_config: "ModelConfig",
) -> ChatTemplateContentFormat:
resolved_chat_template = resolve_chat_template(
tokenizer,
chat_template=chat_template,
tools=tools,
model_config=model_config,
)
jinja_text = (
resolved_chat_template
if isinstance(resolved_chat_template, str)
else load_chat_template(chat_template, is_literal=True)
)
detected_format = (
"string"
if jinja_text is None
else _detect_content_format(jinja_text, default="string")
)
return detected_format
@lru_cache
def _log_chat_template_content_format(
chat_template: str | None, # For caching purposes
given_format: ChatTemplateContentFormatOption,
detected_format: ChatTemplateContentFormatOption,
):
logger.info(
"Detected the chat template content format to be '%s'. "
"You can set `--chat-template-content-format` to override this.",
detected_format,
)
if given_format != "auto" and given_format != detected_format:
logger.warning(
"You specified `--chat-template-content-format %s` "
"which is different from the detected format '%s'. "
"If our automatic detection is incorrect, please consider "
"opening a GitHub issue so that we can improve it: "
"https://github.com/vllm-project/vllm/issues/new/choose",
given_format,
detected_format,
)
def resolve_chat_template_content_format(
chat_template: str | None,
tools: list[dict[str, Any]] | None,
given_format: ChatTemplateContentFormatOption,
tokenizer: HfTokenizer,
*,
model_config: "ModelConfig",
) -> ChatTemplateContentFormat:
if given_format != "auto":
return given_format
detected_format = _resolve_chat_template_content_format(
chat_template,
tools,
tokenizer,
model_config=model_config,
)
_log_chat_template_content_format(
chat_template,
given_format=given_format,
detected_format=detected_format,
)
return detected_format
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# only preserve the parse function used to resolve chat template kwargs
class AssistantTracker(jinja2.ext.Extension):
tags = {"generation"}
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.Node:
lineno = next(parser.stream).lineno
body = parser.parse_statements(("name:endgeneration",), drop_needle=True)
call = self.call_method("_generation_support")
call_block = jinja2.nodes.CallBlock(call, [], [], body)
return call_block.set_lineno(lineno)
def _resolve_chat_template_kwargs(chat_template: str) -> Set[str]:
env = jinja2.sandbox.ImmutableSandboxedEnvironment(
trim_blocks=True,
lstrip_blocks=True,
extensions=[AssistantTracker, jinja2.ext.loopcontrols],
)
parsed_content = env.parse(chat_template)
template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
return template_vars
_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs)
@lru_cache
def _get_hf_base_chat_template_params() -> frozenset[str]:
from transformers import PreTrainedTokenizer
# Get standard parameters from HuggingFace's base tokenizer class.
# This dynamically extracts parameters from PreTrainedTokenizer's
# apply_chat_template method, ensuring compatibility with tokenizers
# that use **kwargs to receive standard parameters.
# Read signature from HF's base class - the single source of truth
base_sig = inspect.signature(PreTrainedTokenizer.apply_chat_template)
# Exclude VAR_KEYWORD (**kwargs) and VAR_POSITIONAL (*args) placeholders
return frozenset(
p.name
for p in base_sig.parameters.values()
if p.kind
not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL)
)
def resolve_chat_template_kwargs(
tokenizer: HfTokenizer,
chat_template: str,
chat_template_kwargs: dict[str, Any],
raise_on_unexpected: bool = True,
) -> dict[str, Any]:
# We exclude chat_template from kwargs here, because
# chat template has been already resolved at this stage
unexpected_vars = {"chat_template", "tokenize"}
if raise_on_unexpected and (
unexpected_in_kwargs := unexpected_vars & chat_template_kwargs.keys()
):
raise ValueError(
"Found unexpected chat template kwargs from request: "
f"{unexpected_in_kwargs}"
)
fn_kw = {
k
for k in chat_template_kwargs
if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
}
template_vars = _cached_resolve_chat_template_kwargs(chat_template)
# Allow standard HF parameters even if tokenizer uses **kwargs to receive them
hf_base_params = _get_hf_base_chat_template_params()
accept_vars = (fn_kw | template_vars | hf_base_params) - unexpected_vars
return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars}
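# A hedged illustration of the kwarg filtering above (illustration only, not
# used by this module): variables that a template reads but never assigns,
# such as a hypothetical `enable_thinking` toggle, are surfaced by
# `_resolve_chat_template_kwargs` and therefore survive the filter, while
# blocked names like `tokenize` are still rejected.
def _example_resolve_chat_template_kwargs() -> set[str]:
    template = (
        "{% if enable_thinking %}<think>{% endif %}{{ messages[-1]['content'] }}"
    )
    # Expected to contain both "enable_thinking" and "messages".
    return set(_resolve_chat_template_kwargs(template))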
def safe_apply_chat_template(
model_config: "ModelConfig",
tokenizer: HfTokenizer,
conversation: list[ConversationMessage],
*,
tools: list[dict[str, Any]] | None = None,
chat_template: str | None = None,
tokenize: bool = True,
**kwargs,
) -> str | list[int]:
chat_template = resolve_chat_template(
tokenizer,
chat_template=chat_template,
tools=tools,
model_config=model_config,
)
if chat_template is None:
raise ChatTemplateResolutionError(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
)
resolved_kwargs = resolve_chat_template_kwargs(
tokenizer=tokenizer,
chat_template=chat_template,
chat_template_kwargs=kwargs,
)
try:
return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type]
tools=tools, # type: ignore[arg-type]
chat_template=chat_template,
tokenize=tokenize,
**resolved_kwargs,
)
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
except Exception as e:
# Log and report any library-related exceptions for further
# investigation.
logger.exception(
"An error occurred in `transformers` while applying chat template"
)
raise ValueError(str(e)) from e
def rebuild_mm_uuids_from_mm_data(
mm_uuids: "MultiModalUUIDDict",
mm_data: "MultiModalDataDict",
) -> "MultiModalUUIDDict":
"""Rebuild mm_uuids after vision_chunk processing.
When videos are split into chunks, the original UUIDs need to be updated
to reflect the new UUIDs generated for each chunk.
Args:
mm_uuids: Original UUIDs dictionary
mm_data: Processed multimodal data with vision_chunk items
Returns:
Updated UUIDs dictionary with chunk UUIDs
"""
vision_chunks = mm_data.get("vision_chunk")
if vision_chunks is None:
return mm_uuids
assert all(isinstance(item, dict) for item in vision_chunks), (
"Expected all vision_chunk items to be dicts"
)
vision_chunks = cast(list[dict[str, Any]], vision_chunks)
vision_chunk_uuids = [
uuid_val for item in vision_chunks if (uuid_val := item.get("uuid")) is not None
]
if vision_chunk_uuids:
mm_uuids = dict(mm_uuids)
mm_uuids["vision_chunk"] = vision_chunk_uuids
return mm_uuids
def build_video_prompts_from_mm_data(
mm_data: "MultiModalDataDict",
) -> list[str]:
"""Build video prompts from vision_chunk data.
Collects prompts from video chunks and groups them by video_idx.
Args:
mm_data: Processed multimodal data with vision_chunk items
Returns:
List of video prompts, one per video.
"""
vision_chunks = mm_data.get("vision_chunk")
if vision_chunks is None:
return []
# Group chunks by video_idx
video_prompts_dict: dict[int, list[str]] = defaultdict(list)
for item in vision_chunks:
# vision_chunk items are always dicts (VisionChunkImage/VisionChunkVideo)
assert isinstance(item, dict)
if item.get("type") == "video_chunk":
video_idx = item.get("video_idx", 0)
prompt = item.get("prompt", "")
video_prompts_dict[video_idx].append(prompt)
# Build prompts in video order
video_prompts = [
"".join(video_prompts_dict[video_idx])
for video_idx in sorted(video_prompts_dict.keys())
]
return video_prompts
def replace_vision_chunk_video_placeholder(
prompt_raw: str | list[int],
mm_data: "MultiModalDataDict",
video_placeholder: str | None,
) -> str | list[int]:
# get video placeholder, replace it with runtime video-chunk prompts
if video_placeholder and isinstance(prompt_raw, str):
video_prompts = build_video_prompts_from_mm_data(mm_data)
# replace in order
prompt_raw_parts = prompt_raw.split(video_placeholder)
if len(prompt_raw_parts) == len(video_prompts) + 1:
prompt_raw = "".join(
itertools.chain.from_iterable(zip(prompt_raw_parts, video_prompts))
)
prompt_raw += prompt_raw_parts[-1]
else:
logger.warning(
"Number of video placeholders (%d) does not match "
"number of videos (%d) in the request.",
len(prompt_raw_parts) - 1,
len(video_prompts),
)
return prompt_raw
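# A minimal, self-contained sketch of the split/interleave step above, using a
# hypothetical placeholder and chunk prompts (illustration only, not tied to
# any real model config):
def _example_interleave_video_prompts() -> str:
    parts = "A<video>B<video>C".split("<video>")  # ['A', 'B', 'C']
    video_prompts = ["<chunk0><chunk1>", "<chunk2>"]
    out = "".join(itertools.chain.from_iterable(zip(parts, video_prompts)))
    out += parts[-1]
    return out  # "A<chunk0><chunk1>B<chunk2>C"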
class HfRenderer(BaseRenderer[HfTokenizer]):
@classmethod
def from_config( # type: ignore[override]
cls,
config: VllmConfig,
tokenizer_kwargs: dict[str, Any],
) -> "HfRenderer":
model_config = config.model_config
if model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = cast(
HfTokenizer,
cached_get_tokenizer(
tokenizer_cls=CachedHfTokenizer, # type: ignore[type-abstract]
**tokenizer_kwargs,
),
)
return cls(config, tokenizer)
def __init__(
self,
config: VllmConfig,
tokenizer: HfTokenizer | None,
) -> None:
super().__init__(config, tokenizer)
self.use_unified_vision_chunk = getattr(
config.model_config.hf_config, "use_unified_vision_chunk", False
)
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], DictPrompt]:
model_config = self.model_config
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
model_config,
content_format=resolve_chat_template_content_format(
chat_template=params.chat_template,
tools=params.chat_template_kwargs.get("tools"),
given_format=params.chat_template_content_format,
tokenizer=tokenizer,
model_config=model_config,
),
)
prompt_raw = safe_apply_chat_template(
model_config,
tokenizer,
conversation,
**params.get_apply_chat_template_kwargs(),
)
# NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5
# model which uses unified vision chunks for both images and videos.
if (
self.use_unified_vision_chunk
and mm_uuids is not None
and mm_data is not None
):
mm_uuids = rebuild_mm_uuids_from_mm_data(mm_uuids, mm_data)
# get video placeholder, replace it with runtime video-chunk prompts
video_placeholder = getattr(
model_config.hf_config, "video_placeholder", None
)
prompt_raw = replace_vision_chunk_video_placeholder(
prompt_raw,
mm_data,
video_placeholder,
)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], DictPrompt]:
model_config = self.model_config
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
model_config,
content_format=resolve_chat_template_content_format(
chat_template=params.chat_template,
tools=params.chat_template_kwargs.get("tools"),
given_format=params.chat_template_content_format,
tokenizer=tokenizer,
model_config=model_config,
),
)
prompt_raw = safe_apply_chat_template(
model_config,
tokenizer,
conversation,
**params.get_apply_chat_template_kwargs(),
)
# NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5
# model which uses unified vision chunks for both images and videos.
if (
self.use_unified_vision_chunk
and mm_uuids is not None
and mm_data is not None
):
# get video placeholder, replace it with runtime video-chunk prompts
video_placeholder = getattr(
model_config.hf_config, "video_placeholder", None
)
prompt_raw = replace_vision_chunk_video_placeholder(
prompt_raw,
mm_data,
video_placeholder,
)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt


@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .preprocess import (
DecoderDictPrompt,
DecoderOnlyDictPrompt,
DictPrompt,
EncoderDecoderDictPrompt,
EncoderDictPrompt,
SingletonDictPrompt,
)
from .tokenize import (
DecoderOnlyTokPrompt,
DecoderTokPrompt,
EncoderDecoderTokPrompt,
EncoderTokPrompt,
SingletonTokPrompt,
TokPrompt,
)
__all__ = [
"DecoderOnlyDictPrompt",
"EncoderDictPrompt",
"DecoderDictPrompt",
"EncoderDecoderDictPrompt",
"SingletonDictPrompt",
"DictPrompt",
"DecoderOnlyTokPrompt",
"EncoderTokPrompt",
"DecoderTokPrompt",
"EncoderDecoderTokPrompt",
"SingletonTokPrompt",
"TokPrompt",
]


@@ -0,0 +1,258 @@
"""
Schemas and utilities for preprocessing inputs.
"""
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypedDict, overload
from vllm.inputs import (
EmbedsPrompt,
ExplicitEncoderDecoderPrompt,
ProcessorInputs,
PromptType,
SingletonPrompt,
TextPrompt,
TokensPrompt,
)
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.collection_utils import is_list_of
if TYPE_CHECKING:
import torch
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
@overload
def prompt_to_seq(
prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
) -> Sequence[SingletonPrompt]: ...
@overload
def prompt_to_seq( # type: ignore[misc]
prompt_or_prompts: ExplicitEncoderDecoderPrompt
| Sequence[ExplicitEncoderDecoderPrompt],
) -> Sequence[ExplicitEncoderDecoderPrompt]: ...
@overload
def prompt_to_seq( # type: ignore[misc]
prompt_or_prompts: PromptType | Sequence[PromptType],
) -> Sequence[PromptType]: ...
def prompt_to_seq(
prompt_or_prompts: PromptType | bytes | Sequence[PromptType | bytes],
) -> Sequence[PromptType]:
if isinstance(prompt_or_prompts, (dict, str, bytes)) or (
len(prompt_or_prompts) > 0 and is_list_of(prompt_or_prompts, int)
):
return [prompt_or_prompts] # type: ignore[list-item]
return prompt_or_prompts # type: ignore[return-value]
def conversation_to_seq(
conversation_or_conversations: list["ChatCompletionMessageParam"]
| Sequence[list["ChatCompletionMessageParam"]],
) -> Sequence[list["ChatCompletionMessageParam"]]:
if len(conversation_or_conversations) > 0 and is_list_of(
conversation_or_conversations, dict
):
return [conversation_or_conversations] # type: ignore[list-item]
return conversation_or_conversations # type: ignore[return-value]
DecoderOnlyDictPrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
"""
A [`DecoderOnlyPrompt`][vllm.inputs.data.DecoderOnlyPrompt]
that has been standardized into a dictionary.
"""
EncoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt
"""
An [`EncoderPrompt`][vllm.inputs.data.EncoderPrompt]
that has been standardized into a dictionary.
"""
DecoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt
"""
A [`DecoderPrompt`][vllm.inputs.data.DecoderPrompt]
that has been standardized into a dictionary.
"""
class EncoderDecoderDictPrompt(TypedDict):
"""
An [`EncoderDecoderPrompt`][vllm.inputs.data.EncoderDecoderPrompt]
that has been standardized into a dictionary.
"""
encoder_prompt: EncoderDictPrompt
decoder_prompt: DecoderDictPrompt | None
SingletonDictPrompt: TypeAlias = (
DecoderOnlyDictPrompt | EncoderDictPrompt | DecoderDictPrompt
)
"""
A [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt]
that has been standardized into a dictionary.
"""
DictPrompt: TypeAlias = DecoderOnlyDictPrompt | EncoderDecoderDictPrompt
"""
A [`PromptType`][vllm.inputs.data.PromptType]
that has been standardized into a dictionary.
"""
def parse_dec_only_prompt(prompt: PromptType | object) -> DecoderOnlyDictPrompt:
"""
Parse a prompt for a decoder-only model and normalize it to a dictionary.
"""
if isinstance(prompt, str):
return TextPrompt(prompt=prompt)
if isinstance(prompt, list):
if not is_list_of(prompt, int):
raise TypeError("Token prompt should be a list of integers")
return TokensPrompt(prompt_token_ids=prompt)
if isinstance(prompt, dict):
if "encoder_prompt" in prompt:
raise TypeError("Cannot pass encoder-decoder prompt to decoder-only models")
if (
"prompt" in prompt
or "prompt_token_ids" in prompt
or "prompt_embeds" in prompt
):
return prompt # type: ignore[return-value]
raise TypeError("Prompt dictionary must contain text, tokens, or embeddings")
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
def _parse_enc_prompt(prompt: PromptType | object) -> EncoderDictPrompt:
if isinstance(prompt, str):
return TextPrompt(prompt=prompt)
if isinstance(prompt, list):
if not is_list_of(prompt, int):
raise TypeError("Token prompt should be a list of integers")
return TokensPrompt(prompt_token_ids=prompt)
if isinstance(prompt, dict):
if "prompt_embeds" in prompt:
raise TypeError("Cannot pass embeddings prompt to encoder-decoder models")
if "prompt" in prompt or "prompt_token_ids" in prompt:
return prompt # type: ignore[return-value]
raise TypeError("Prompt dictionary must contain text or tokens")
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
def _parse_dec_prompt(prompt: PromptType | object) -> DecoderDictPrompt:
if isinstance(prompt, str):
return TextPrompt(prompt=prompt)
if isinstance(prompt, list):
if not is_list_of(prompt, int):
raise TypeError("Token prompt should be a list of integers")
return TokensPrompt(prompt_token_ids=prompt)
if isinstance(prompt, dict):
if "prompt_embeds" in prompt:
raise TypeError("Cannot pass embeddings prompt to encoder-decoder models")
if (
"multi_modal_data" in prompt
or "mm_processor_kwargs" in prompt
or "multi_modal_uuids" in prompt
):
raise TypeError("Cannot pass multi-modal inputs to decoder prompt")
if "prompt" in prompt or "prompt_token_ids" in prompt:
return prompt # type: ignore[return-value]
raise TypeError("Prompt dictionary must contain text or tokens")
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
def parse_enc_dec_prompt(prompt: PromptType | object) -> EncoderDecoderDictPrompt:
"""
Parse a prompt for an encoder-decoder model and normalize it to a dictionary.
"""
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
enc_prompt = prompt["encoder_prompt"] # type: ignore[typeddict-item]
dec_prompt = prompt["decoder_prompt"] # type: ignore[typeddict-item]
else:
enc_prompt = prompt
dec_prompt = None
return EncoderDecoderDictPrompt(
encoder_prompt=_parse_enc_prompt(enc_prompt),
decoder_prompt=None if dec_prompt is None else _parse_dec_prompt(dec_prompt),
)
def parse_model_prompt(model_config: "ModelConfig", prompt: object):
if model_config.is_encoder_decoder:
return parse_enc_dec_prompt(prompt)
return parse_dec_only_prompt(prompt)
class PromptComponents(NamedTuple):
text: str | None = None
token_ids: list[int] | None = None
embeds: "torch.Tensor | None" = None
def extract_target_prompt(model_config: "ModelConfig", prompt: object):
return (
parse_enc_dec_prompt(prompt)["encoder_prompt"]
if model_config.is_encoder_decoder
else parse_dec_only_prompt(prompt)
)
def extract_prompt_components(
model_config: "ModelConfig",
prompt: PromptType | ProcessorInputs,
) -> PromptComponents:
target_prompt = extract_target_prompt(model_config, prompt)
return PromptComponents(
text=target_prompt.get("prompt"),
token_ids=target_prompt.get("prompt_token_ids"),
embeds=target_prompt.get("prompt_embeds"),
)
def extract_prompt_len(
model_config: "ModelConfig", prompt: PromptType | ProcessorInputs
):
target_prompt = extract_target_prompt(model_config, prompt)
return length_from_prompt_token_ids_or_embeds(
target_prompt.get("prompt_token_ids"),
target_prompt.get("prompt_embeds"),
)
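
A hedged, pure-Python illustration of the normalization helpers above; the literal prompts are arbitrary and the inline comments show the expected results:

# str, token-list, and dict inputs all normalize to the same TypedDict shapes:
parse_dec_only_prompt("Hello")              # -> {"prompt": "Hello"}
parse_dec_only_prompt([1, 2, 3])            # -> {"prompt_token_ids": [1, 2, 3]}
parse_dec_only_prompt({"prompt": "Hi"})     # -> returned unchanged
# A single prompt and a batch are both accepted upstream of parsing:
prompt_to_seq("Hello")                      # -> ["Hello"]
prompt_to_seq(["Hello", {"prompt": "Hi"}])  # -> returned unchanged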


@@ -0,0 +1,57 @@
"""
Schemas and utilities for tokenization inputs.
"""
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TypeAlias, TypedDict
from vllm.inputs import EmbedsPrompt, TokensPrompt
DecoderOnlyTokPrompt: TypeAlias = TokensPrompt | EmbedsPrompt
"""
A [`DecoderOnlyDictPrompt`][vllm.renderers.inputs.preprocess.DecoderOnlyDictPrompt]
that has been tokenized.
"""
EncoderTokPrompt: TypeAlias = TokensPrompt
"""
An [`EncoderDictPrompt`][vllm.renderers.inputs.preprocess.EncoderDictPrompt]
that has been tokenized.
"""
DecoderTokPrompt: TypeAlias = TokensPrompt
"""
A [`DecoderDictPrompt`][vllm.renderers.inputs.preprocess.DecoderDictPrompt]
that has been tokenized.
"""
class EncoderDecoderTokPrompt(TypedDict):
"""
An
[`EncoderDecoderDictPrompt`][vllm.renderers.inputs.preprocess.EncoderDecoderDictPrompt]
that has been tokenized.
"""
encoder_prompt: EncoderTokPrompt
decoder_prompt: DecoderTokPrompt | None
SingletonTokPrompt: TypeAlias = (
DecoderOnlyTokPrompt | EncoderTokPrompt | DecoderTokPrompt
)
"""
A [`SingletonDictPrompt`][vllm.renderers.inputs.preprocess.SingletonDictPrompt]
that has been tokenized.
"""
TokPrompt: TypeAlias = DecoderOnlyTokPrompt | EncoderDecoderTokPrompt
"""
A [`DictPrompt`][vllm.renderers.inputs.preprocess.DictPrompt]
that has been tokenized.
"""

vllm/renderers/mistral.py

@@ -0,0 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ConversationMessage,
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.async_utils import make_async
from .base import BaseRenderer
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams
logger = init_logger(__name__)
def safe_apply_chat_template(
tokenizer: MistralTokenizer,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> str | list[int]:
from mistral_common.exceptions import MistralCommonException
try:
return tokenizer.apply_chat_template(messages, **kwargs)
# mistral-common uses assert statements to stop processing when the
# input does not comply with the expected format.
# We convert those assertion errors to ValueErrors so they can be
# properly caught during input preprocessing.
except (AssertionError, MistralCommonException) as e:
raise ValueError(str(e)) from e
# The external library can still raise other exceptions that are not
# covered by its own error handling.
except Exception as e:
# Log and report any library-related exceptions for further
# investigation.
logger.exception(
"An error occurred in `mistral_common` while applying chat template"
)
raise ValueError(str(e)) from e
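# Illustrative sketch (hypothetical helper, not part of the original module):
# callers can rely on malformed conversations surfacing as ValueError rather
# than AssertionError or MistralCommonException.
def _example_safe_template(tokenizer: MistralTokenizer) -> None:
    messages: list[ChatCompletionMessageParam] = [{"role": "user", "content": "Hi"}]
    try:
        rendered = safe_apply_chat_template(tokenizer, messages, return_dict=False)
    except ValueError:
        rendered = []
    logger.debug("Rendered prompt of length %d", len(rendered))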
class MistralRenderer(BaseRenderer[MistralTokenizer]):
@classmethod
def from_config( # type: ignore[override]
cls,
config: VllmConfig,
tokenizer_kwargs: dict[str, Any],
) -> "MistralRenderer":
model_config = config.model_config
if model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = cached_get_tokenizer(
tokenizer_cls=MistralTokenizer,
**tokenizer_kwargs,
)
return cls(config, tokenizer)
def __init__(
self,
config: VllmConfig,
tokenizer: MistralTokenizer | None,
) -> None:
super().__init__(config, tokenizer)
self._apply_chat_template_executor = ThreadPoolExecutor(max_workers=1)
self._apply_chat_template_async = make_async(
safe_apply_chat_template, executor=self._apply_chat_template_executor
)
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
self.model_config,
content_format="string",
)
prompt_raw = safe_apply_chat_template(
tokenizer,
messages,
**params.get_apply_chat_template_kwargs(),
)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
self.model_config,
content_format="string",
)
prompt_raw = await self._apply_chat_template_async(
tokenizer,
messages,
**params.get_apply_chat_template_kwargs(),
)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt
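# Illustrative sketch (hypothetical usage, assuming a constructed
# `renderer: MistralRenderer`): `render_messages` returns the parsed
# conversation together with a decoder-only prompt dict for the engine.
def _example_render_messages(renderer: MistralRenderer) -> None:
    messages: list[ChatCompletionMessageParam] = [
        {"role": "user", "content": "Hello!"},
    ]
    conversation, prompt = renderer.render_messages(messages, ChatParams())
    # The Mistral tokenizer typically returns token IDs directly.
    assert "prompt_token_ids" in prompt or "prompt" in prompt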

383
vllm/renderers/params.py Normal file
View File

@@ -0,0 +1,383 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, TypeVar
from vllm.exceptions import VLLMValidationError
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.utils.import_utils import LazyLoader
if TYPE_CHECKING:
import torch
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
else:
torch = LazyLoader("torch", globals(), "torch")
ChatTemplateContentFormatOption = object
logger = init_logger(__name__)
_S = TypeVar("_S", list[int], "torch.Tensor")
def merge_kwargs(
defaults: dict[str, Any] | None,
overrides: dict[str, Any] | None,
/,
*,
unset_values: tuple[object, ...] = (None, "auto"),
) -> dict[str, Any]:
if defaults is None:
defaults = {}
if overrides is None:
overrides = {}
return defaults | {k: v for k, v in overrides.items() if v not in unset_values}
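# Illustrative sketch: overrides win over defaults except when the override is
# an "unset" value (`None` or `"auto"` by default), in which case the default
# is kept.
def _example_merge_kwargs() -> None:
    merged = merge_kwargs(
        {"add_generation_prompt": True, "chat_template": "default"},
        {"add_generation_prompt": False, "chat_template": None},
    )
    assert merged == {"add_generation_prompt": False, "chat_template": "default"}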
@dataclass(frozen=True)
class ChatParams:
"""Configuration to control how to parse chat messages."""
chat_template: str | None = None
"""The chat template to apply."""
chat_template_content_format: "ChatTemplateContentFormatOption" = "auto"
"""The format of the chat template."""
chat_template_kwargs: dict[str, Any] = field(default_factory=dict)
"""The kwargs to pass to the chat template."""
def with_defaults(self, default_chat_template_kwargs: dict[str, Any] | None):
if not default_chat_template_kwargs:
return self
return ChatParams(
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
chat_template_kwargs=merge_kwargs(
default_chat_template_kwargs,
self.chat_template_kwargs,
),
)
def get_apply_chat_template_kwargs(self) -> dict[str, Any]:
"""The arguments to pass to `tokenizer.apply_chat_template`."""
return merge_kwargs(
self.chat_template_kwargs,
dict(chat_template=self.chat_template, return_dict=False),
)
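# Illustrative sketch (the chat template kwarg names are made up): per-request
# kwargs take precedence over server-level defaults supplied via
# `with_defaults`.
def _example_chat_params() -> None:
    params = ChatParams(chat_template_kwargs={"enable_thinking": False})
    merged = params.with_defaults(
        {"enable_thinking": True, "add_generation_prompt": True}
    )
    assert merged.chat_template_kwargs == {
        "enable_thinking": False,
        "add_generation_prompt": True,
    }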
@dataclass(frozen=True)
class TokenizeParams:
"""Configuration to control how prompts are tokenized."""
max_total_tokens: int | None
"""
Maximum allowed number of input + output tokens.
Usually, this refers to the model's context length.
"""
max_output_tokens: int = 0
"""Maximum requested number of output tokens."""
pad_prompt_tokens: int | None = None
"""
Number of tokens to pad to:
- `None` means no padding.
- `-1` maps to `max_input_tokens`.
"""
truncate_prompt_tokens: int | None = None
"""
Number of tokens to keep:
- `None` means no truncation.
- `-1` maps to `max_input_tokens`.
"""
do_lower_case: bool = False
"""Whether to normalize text to lower case before tokenization."""
add_special_tokens: bool = True
"""Whether to add special tokens."""
needs_detokenization: bool = False
"""
Whether the tokenized prompt needs to contain the original text.
Not to be confused with `SamplingParams.detokenize` which deals
with the output generated by the model.
"""
max_total_tokens_param: str = "max_total_tokens"
"""Override this to edit the message for validation errors."""
max_output_tokens_param: str = "max_output_tokens"
"""Override this to edit the message for validation errors."""
truncate_prompt_tokens_param: str = "truncate_prompt_tokens"
"""Override this to edit the message for validation errors."""
@property
def max_input_tokens(self) -> int | None:
"""Maximum allowed number of input tokens."""
if self.max_total_tokens is None:
return None
return self.max_total_tokens - self.max_output_tokens
def __post_init__(self) -> None:
max_total_tokens = self.max_total_tokens
max_output_tokens = self.max_output_tokens
max_input_tokens = self.max_input_tokens
truncate_prompt_tokens = self.truncate_prompt_tokens
if (
max_output_tokens is not None
and max_total_tokens is not None
and max_output_tokens > max_total_tokens
):
raise VLLMValidationError(
f"{self.max_output_tokens_param}={max_output_tokens}"
f"cannot be greater than "
f"{self.max_total_tokens_param}={max_total_tokens=}. "
f"Please request fewer output tokens.",
parameter=self.max_output_tokens_param,
value=max_output_tokens,
)
if (
max_input_tokens is not None
and truncate_prompt_tokens is not None
and truncate_prompt_tokens > max_input_tokens
):
raise VLLMValidationError(
f"{self.truncate_prompt_tokens_param}={truncate_prompt_tokens} "
f"cannot be greater than {self.max_total_tokens_param} - "
f"{self.max_output_tokens_param} = {max_input_tokens}. "
f"Please request a smaller truncation size.",
parameter=self.truncate_prompt_tokens_param,
value=truncate_prompt_tokens,
)
def with_kwargs(self, **tokenization_kwargs: Any):
max_length = tokenization_kwargs.pop("max_length", self.max_input_tokens)
pad_prompt_tokens = tokenization_kwargs.pop(
"pad_prompt_tokens", self.pad_prompt_tokens
)
truncate_prompt_tokens = tokenization_kwargs.pop(
"truncate_prompt_tokens", self.truncate_prompt_tokens
)
do_lower_case = tokenization_kwargs.pop("do_lower_case", self.do_lower_case)
add_special_tokens = tokenization_kwargs.pop(
"add_special_tokens", self.add_special_tokens
)
needs_detokenization = tokenization_kwargs.pop(
"needs_detokenization", self.needs_detokenization
)
# https://huggingface.co/docs/transformers/en/pad_truncation
if padding := tokenization_kwargs.pop("padding", None):
if padding == "max_length":
pad_prompt_tokens = max_length
elif padding in (False, "do_not_pad"):
pad_prompt_tokens = None
else:
# To emit the below warning
tokenization_kwargs["padding"] = padding
if truncation := tokenization_kwargs.pop("truncation", None):
if truncation in (True, "longest_first"):
truncate_prompt_tokens = max_length
elif truncation in (False, "do_not_truncate"):
truncate_prompt_tokens = None
else:
# To emit the below warning
tokenization_kwargs["truncation"] = truncation
if tokenization_kwargs:
logger.warning(
"The following tokenization arguments are not supported "
"by vLLM Renderer and will be ignored: %s",
tokenization_kwargs,
)
max_total_tokens = self.max_total_tokens
return TokenizeParams(
max_total_tokens=max_total_tokens,
max_output_tokens=(
0
if max_total_tokens is None or max_length is None
else max_total_tokens - max_length
),
pad_prompt_tokens=pad_prompt_tokens,
truncate_prompt_tokens=truncate_prompt_tokens,
do_lower_case=do_lower_case,
add_special_tokens=add_special_tokens,
needs_detokenization=needs_detokenization,
)
def get_encode_kwargs(self) -> dict[str, Any]:
"""The arguments to pass to `tokenizer.encode`."""
max_length = self.truncate_prompt_tokens
if max_length is not None and max_length < 0:
max_length = self.max_input_tokens
elif max_length is None and self.max_input_tokens is not None:
# This prevents tokenization from taking up more resources than necessary
# while still failing `self._token_len_check` as expected by users
max_length = self.max_input_tokens + 1
return dict(
truncation=max_length is not None,
max_length=max_length,
add_special_tokens=self.add_special_tokens,
)
def _text_len_check(self, tokenizer: TokenizerLike | None, text: str) -> str:
"""Apply length checks to prompt text if necessary."""
max_input_tokens = self.max_input_tokens
if max_input_tokens is None:
return text
if self.truncate_prompt_tokens is None and tokenizer is not None:
max_input_chars = max_input_tokens * tokenizer.max_chars_per_token
if len(text) > max_input_chars:
# To save resources, fail the request outright without even
# attempting tokenization
raise VLLMValidationError(
f"You passed {len(text)} input characters "
f"and requested {self.max_output_tokens} output tokens. "
f"However, the model's context length is only "
f"{self.max_total_tokens} tokens, resulting in a maximum "
f"input length of {max_input_tokens} tokens "
f"(at most {max_input_chars} characters). "
f"Please reduce the length of the input prompt.",
parameter="input_text",
value=len(text),
)
return text
def _text_lowercase(self, tokenizer: TokenizerLike | None, text: str) -> str:
"""Apply lowercase to prompt text if necessary."""
return text.lower() if self.do_lower_case else text
def _validate_text(self, tokenizer: TokenizerLike | None, text: str) -> str:
"""Apply all validators to prompt text."""
for validator in (
self._text_len_check,
self._text_lowercase,
):
text = validator(tokenizer, text)
return text
def apply_pre_tokenization(
self,
tokenizer: TokenizerLike | None,
prompt: TextPrompt,
) -> TextPrompt:
"""
Ensure that the prompt meets the requirements set out by this config.
If that is not possible, raise a `VLLMValidationError`.
This method is run before tokenization occurs.
"""
prompt["prompt"] = self._validate_text(tokenizer, prompt["prompt"])
return prompt
def _token_padding(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply padding to prompt tokens if necessary."""
pad_length = self.pad_prompt_tokens
if pad_length is not None and pad_length < 0:
pad_length = self.max_input_tokens
if pad_length is None or pad_length <= len(tokens):
return tokens
if tokenizer is None:
raise ValueError("Cannot pad tokens when `skip_tokenizer_init=True`")
if not isinstance(tokens, list):
raise ValueError("Cannot pad tokens for embedding inputs")
return tokens + [tokenizer.pad_token_id] * (pad_length - len(tokens))
def _token_truncation(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply truncation to prompt tokens if necessary."""
max_length = self.truncate_prompt_tokens
if max_length is not None and max_length < 0:
max_length = self.max_input_tokens
if max_length is None or max_length >= len(tokens):
return tokens
if max_length == 0:
return tokens[:0]
if getattr(tokenizer, "truncation_side", "left") == "left":
return tokens[-max_length:]
return tokens[:max_length]
def _token_len_check(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply length checks to prompt tokens if necessary."""
max_input_tokens = self.max_input_tokens
if max_input_tokens is None:
return tokens
if len(tokens) > max_input_tokens:
raise VLLMValidationError(
f"You passed {len(tokens)} input tokens "
f"and requested {self.max_output_tokens} output tokens. "
f"However, the model's context length is only "
f"{self.max_total_tokens} tokens, resulting in a maximum "
f"input length of {max_input_tokens} tokens. "
f"Please reduce the length of the input prompt.",
parameter="input_tokens",
value=len(tokens),
)
return tokens
def _validate_tokens(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply all validators to a token sequence."""
for validator in (
self._token_padding,
self._token_truncation,
self._token_len_check,
):
tokens = validator(tokenizer, tokens)
return tokens
def apply_post_tokenization(
self,
tokenizer: TokenizerLike | None,
prompt: TokensPrompt | EmbedsPrompt,
) -> TokensPrompt | EmbedsPrompt:
"""
Ensure that the prompt meets the requirements set out by this config.
If that is not possible, raise a `VLLMValidationError`.
This method is run after tokenization occurs.
"""
if "prompt_token_ids" in prompt:
prompt["prompt_token_ids"] = self._validate_tokens( # type: ignore[typeddict-unknown-key]
tokenizer,
prompt["prompt_token_ids"], # type: ignore[typeddict-item]
)
if "prompt_embeds" in prompt:
prompt["prompt_embeds"] = self._validate_tokens( # type: ignore[typeddict-unknown-key]
tokenizer,
prompt["prompt_embeds"], # type: ignore[typeddict-item]
)
return prompt
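# Illustrative sketch: a 4096-token context with 512 requested output tokens
# leaves 3584 input tokens; HF-style `truncation=True` maps onto
# `truncate_prompt_tokens`, and over-long token lists are cut down (keeping
# the final tokens, since the default truncation side is "left") without
# needing a tokenizer.
def _example_tokenize_params() -> None:
    params = TokenizeParams(max_total_tokens=4096, max_output_tokens=512)
    assert params.max_input_tokens == 3584

    truncating = params.with_kwargs(truncation=True, max_length=8)
    prompt: TokensPrompt = {"prompt_token_ids": list(range(20))}
    validated = truncating.apply_post_tokenization(None, prompt)
    assert validated["prompt_token_ids"] == list(range(12, 20))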

View File

@@ -0,0 +1,92 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from vllm.logger import init_logger
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.utils.import_utils import resolve_obj_by_qualname
from .base import BaseRenderer
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
_VLLM_RENDERERS = {
"deepseek_v32": ("deepseek_v32", "DeepseekV32Renderer"),
"hf": ("hf", "HfRenderer"),
"grok2": ("grok2", "Grok2Renderer"),
"mistral": ("mistral", "MistralRenderer"),
"terratorch": ("terratorch", "TerratorchRenderer"),
}
@dataclass
class RendererRegistry:
# Renderer mode -> (renderer module, renderer class)
renderers: dict[str, tuple[str, str]] = field(default_factory=dict)
def register(self, renderer_mode: str, module: str, class_name: str) -> None:
if renderer_mode in self.renderers:
logger.warning(
"%s.%s is already registered for renderer_mode=%r. "
"It is overwritten by the new one.",
module,
class_name,
renderer_mode,
)
self.renderers[renderer_mode] = (module, class_name)
return None
def load_renderer_cls(self, renderer_mode: str) -> type[BaseRenderer]:
if renderer_mode not in self.renderers:
raise ValueError(f"No renderer registered for {renderer_mode=!r}.")
module, class_name = self.renderers[renderer_mode]
logger.debug_once(f"Loading {class_name} for {renderer_mode=!r}")
return resolve_obj_by_qualname(f"{module}.{class_name}")
def load_renderer(
self,
renderer_mode: str,
config: "VllmConfig",
tokenizer_kwargs: dict[str, Any],
) -> BaseRenderer:
renderer_cls = self.load_renderer_cls(renderer_mode)
return renderer_cls.from_config(config, tokenizer_kwargs)
RENDERER_REGISTRY = RendererRegistry(
{
mode: (f"vllm.renderers.{mod_relname}", cls_name)
for mode, (mod_relname, cls_name) in _VLLM_RENDERERS.items()
}
)
"""The global `RendererRegistry` instance."""
def renderer_from_config(config: "VllmConfig", **kwargs):
model_config = config.model_config
tokenizer_mode, tokenizer_name, args, kwargs = tokenizer_args_from_config(
model_config, **kwargs
)
if (
model_config.tokenizer_mode == "auto"
and model_config.model_impl == "terratorch"
):
renderer_mode = "terratorch"
else:
renderer_mode = tokenizer_mode
return RENDERER_REGISTRY.load_renderer(
renderer_mode,
config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
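# Illustrative sketch (the module path and class name are hypothetical): an
# out-of-tree renderer can be registered under a new mode and later resolved
# through the global registry.
def _example_register_custom_renderer() -> None:
    RENDERER_REGISTRY.register("my_mode", "my_package.renderers", "MyRenderer")
    renderer_cls = RENDERER_REGISTRY.load_renderer_cls("my_mode")
    assert issubclass(renderer_cls, BaseRenderer)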

View File

@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ConversationMessage,
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.logger import init_logger
from .base import BaseRenderer
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams
logger = init_logger(__name__)
class TerratorchRenderer(BaseRenderer):
@classmethod
def from_config(
cls,
config: VllmConfig, # type: ignore[override]
tokenizer_kwargs: dict[str, Any],
) -> "TerratorchRenderer":
model_config = config.model_config
if not model_config.skip_tokenizer_init:
raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`")
return cls(config, None)
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], DictPrompt]:
model_config = self.model_config
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
model_config,
content_format="string",
)
prompt = parse_dec_only_prompt([1]) # Dummy token IDs
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], DictPrompt]:
model_config = self.model_config
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
model_config,
content_format="string",
)
prompt = parse_dec_only_prompt([1]) # Dummy token IDs
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt