update
15
vllm/renderers/__init__.py
Normal file
@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from .base import BaseRenderer
from .params import ChatParams, TokenizeParams, merge_kwargs
from .registry import RendererRegistry, renderer_from_config

__all__ = [
    "BaseRenderer",
    "RendererRegistry",
    "renderer_from_config",
    "ChatParams",
    "TokenizeParams",
    "merge_kwargs",
]
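A minimal usage sketch (not part of this diff), assuming `vllm_config` is an already-constructed VllmConfig; `renderer_from_config` from `.registry` (not shown here) presumably dispatches to the appropriate renderer class, so the HF renderer is used directly to stay within what this diff defines:

from vllm.renderers.hf import HfRenderer

# `vllm_config` is a hypothetical, pre-built VllmConfig for the target model.
renderer = HfRenderer.from_config(vllm_config, tokenizer_kwargs={})

# render_cmpl runs the full pipeline defined in base.py below and returns
# engine-ready ProcessorInputs.
engine_inputs = renderer.render_cmpl([{"prompt": "Hello, world!"}])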
767
vllm/renderers/base.py
Normal file
@@ -0,0 +1,767 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping, Sequence
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, Generic, overload
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from vllm.inputs import (
|
||||
EmbedsInputs,
|
||||
EmbedsPrompt,
|
||||
EncoderDecoderInputs,
|
||||
ProcessorInputs,
|
||||
SingletonInputs,
|
||||
TextPrompt,
|
||||
TokenInputs,
|
||||
TokensPrompt,
|
||||
)
|
||||
from vllm.inputs.data import build_enc_dec_inputs, embeds_inputs, token_inputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
|
||||
from vllm.utils.counter import AtomicCounter
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
from vllm.v1.metrics.stats import MultiModalCacheStats
|
||||
|
||||
from .embed_utils import safe_load_prompt_embeds
|
||||
from .inputs import (
|
||||
DictPrompt,
|
||||
EncoderDecoderDictPrompt,
|
||||
EncoderDecoderTokPrompt,
|
||||
SingletonDictPrompt,
|
||||
SingletonTokPrompt,
|
||||
TokPrompt,
|
||||
)
|
||||
from .inputs.preprocess import extract_target_prompt
|
||||
from .params import ChatParams, TokenizeParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ConversationMessage,
|
||||
)
|
||||
from vllm.multimodal.cache import BaseMultiModalProcessorCache
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalInputs,
|
||||
MultiModalUUIDDict,
|
||||
)
|
||||
from vllm.multimodal.parse import MultiModalDataItems, MultiModalUUIDItems
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound=TokenizerLike, default=TokenizerLike)
|
||||
|
||||
|
||||
class BaseRenderer(ABC, Generic[_T]):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: "VllmConfig",
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> "BaseRenderer":
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(self, config: "VllmConfig", tokenizer: _T | None) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.model_config = config.model_config
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
# Lazy initialization since offline LLM doesn't use async
|
||||
self._async_tokenizer: AsyncMicrobatchTokenizer | None = None
|
||||
|
||||
self.mm_processor: BaseMultiModalProcessor | None = None
|
||||
self._mm_cache_stats: MultiModalCacheStats | None = None
|
||||
if config.model_config.is_multimodal_model:
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry
|
||||
from vllm.multimodal.registry import MultiModalTimingRegistry
|
||||
|
||||
mm_processor_cache = mm_registry.processor_cache_from_config(config)
|
||||
|
||||
with set_default_torch_num_threads():
|
||||
self.mm_processor = mm_registry.create_processor(
|
||||
config.model_config,
|
||||
tokenizer=tokenizer,
|
||||
cache=mm_processor_cache,
|
||||
)
|
||||
|
||||
if mm_processor_cache:
|
||||
self._mm_cache_stats = MultiModalCacheStats()
|
||||
|
||||
# This is used to generate internal request ID for MM processing
|
||||
# It has no relation to the request ID for engine core
|
||||
self._mm_req_counter = AtomicCounter()
|
||||
self._mm_timing_registry = MultiModalTimingRegistry(
|
||||
config.observability_config
|
||||
)
|
||||
|
||||
def get_tokenizer(self) -> _T:
|
||||
tokenizer = self.tokenizer
|
||||
if tokenizer is None:
|
||||
raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
|
||||
|
||||
return tokenizer
|
||||
|
||||
def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
|
||||
if self._async_tokenizer is None:
|
||||
self._async_tokenizer = AsyncMicrobatchTokenizer(self.get_tokenizer())
|
||||
|
||||
return self._async_tokenizer
|
||||
|
||||
def get_mm_processor(self) -> "BaseMultiModalProcessor":
|
||||
if self.mm_processor is None:
|
||||
raise ValueError("Multi-modal processor not available for text-only models")
|
||||
|
||||
return self.mm_processor
|
||||
|
||||
@property
|
||||
def mm_processor_cache(self) -> "BaseMultiModalProcessorCache | None":
|
||||
if self.mm_processor is None:
|
||||
return None
|
||||
|
||||
return self.mm_processor.cache
|
||||
|
||||
def stat_mm_cache(self) -> MultiModalCacheStats | None:
|
||||
mm_cache_stats = self._mm_cache_stats
|
||||
if mm_cache_stats is None:
|
||||
return None
|
||||
|
||||
self._mm_cache_stats = MultiModalCacheStats()
|
||||
|
||||
return mm_cache_stats
|
||||
|
||||
def update_mm_cache_stats(self) -> None:
|
||||
mm_processor_cache = self.mm_processor_cache
|
||||
mm_cache_stats = self._mm_cache_stats
|
||||
|
||||
if mm_processor_cache and mm_cache_stats:
|
||||
delta = mm_processor_cache.make_stats(delta=True)
|
||||
mm_cache_stats.record(delta.total, delta.hits)
|
||||
|
||||
def clear_mm_cache(self) -> None:
|
||||
mm_processor_cache = self.mm_processor_cache
|
||||
if mm_processor_cache is not None:
|
||||
mm_processor_cache.clear_cache()
|
||||
|
||||
if self._mm_cache_stats is not None:
|
||||
self._mm_cache_stats.reset = True
|
||||
|
||||
def shutdown(self) -> None:
|
||||
mm_processor_cache = self.mm_processor_cache
|
||||
if mm_processor_cache is not None:
|
||||
mm_processor_cache.close()
|
||||
|
||||
def get_bos_token_id(self) -> int | None:
|
||||
if self.tokenizer is None:
|
||||
logger.warning_once(
|
||||
"Using None for BOS token id because tokenizer is not initialized"
|
||||
)
|
||||
return None
|
||||
|
||||
return self.tokenizer.bos_token_id
|
||||
|
||||
def get_eos_token_id(self) -> int | None:
|
||||
if self.tokenizer is None:
|
||||
logger.warning_once(
|
||||
"Using None for EOS token id because tokenizer is not initialized"
|
||||
)
|
||||
return None
|
||||
|
||||
return self.tokenizer.eos_token_id
|
||||
|
||||
def get_dec_start_token_id(self) -> int:
|
||||
"""
|
||||
Obtain the decoder start token id employed by an encoder/decoder model,
|
||||
raising an error if it is not available.
|
||||
"""
|
||||
dec_start_token_id = getattr(
|
||||
self.model_config.hf_config, "decoder_start_token_id", None
|
||||
)
|
||||
|
||||
if dec_start_token_id is None:
|
||||
logger.warning_once(
|
||||
"Falling back on <BOS> for decoder start token id "
|
||||
"because decoder start token id is not available."
|
||||
)
|
||||
dec_start_token_id = self.get_bos_token_id()
|
||||
|
||||
if dec_start_token_id is None:
|
||||
raise RuntimeError("Cannot find decoder start token id or <BOS>")
|
||||
|
||||
return dec_start_token_id
|
||||
|
||||
@cached_property
|
||||
def default_cmpl_tok_params(self) -> TokenizeParams:
|
||||
mm_processor = self.mm_processor
|
||||
if mm_processor is not None:
|
||||
return mm_processor.info.default_tok_params
|
||||
|
||||
model_config = self.model_config
|
||||
encoder_config = model_config.encoder_config or {}
|
||||
|
||||
return TokenizeParams(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=True,
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def default_chat_tok_params(self) -> TokenizeParams:
|
||||
mm_processor = self.mm_processor
|
||||
if mm_processor is not None:
|
||||
return mm_processor.info.default_tok_params
|
||||
|
||||
model_config = self.model_config
|
||||
encoder_config = model_config.encoder_config or {}
|
||||
|
||||
return TokenizeParams(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=False,
|
||||
)
|
||||
|
||||
# Step 1: Convert raw inputs to prompts
|
||||
def render_prompt(
|
||||
self,
|
||||
prompt: DictPrompt | bytes,
|
||||
) -> DictPrompt:
|
||||
if isinstance(prompt, bytes):
|
||||
embeds = safe_load_prompt_embeds(self.model_config, prompt)
|
||||
prompt = EmbedsPrompt(prompt_embeds=embeds)
|
||||
|
||||
return prompt
|
||||
|
||||
def render_prompts(
|
||||
self,
|
||||
prompts: Sequence[DictPrompt | bytes],
|
||||
) -> list[DictPrompt]:
|
||||
if len(prompts) == 0:
|
||||
raise ValueError("You must pass at least one prompt")
|
||||
|
||||
return [self.render_prompt(prompt) for prompt in prompts]
|
||||
|
||||
async def render_prompts_async(
|
||||
self,
|
||||
prompts: Sequence[DictPrompt | bytes],
|
||||
) -> list[DictPrompt]:
|
||||
return self.render_prompts(prompts)
|
||||
|
||||
@abstractmethod
|
||||
def render_messages(
|
||||
self,
|
||||
messages: list["ChatCompletionMessageParam"],
|
||||
params: ChatParams,
|
||||
) -> tuple[list["ConversationMessage"], DictPrompt]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def render_messages_async(
|
||||
self,
|
||||
messages: list["ChatCompletionMessageParam"],
|
||||
params: ChatParams,
|
||||
) -> tuple[list["ConversationMessage"], DictPrompt]:
|
||||
return self.render_messages(messages, params)
|
||||
|
||||
# Step 2: Tokenize prompts if necessary
|
||||
def _tokenize_prompt(
|
||||
self,
|
||||
prompt: TextPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> TokensPrompt:
|
||||
tokenizer = self.get_tokenizer()
|
||||
prompt_token_ids = tokenizer.encode(
|
||||
prompt["prompt"],
|
||||
**params.get_encode_kwargs(),
|
||||
)
|
||||
|
||||
return TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)
|
||||
|
||||
async def _tokenize_prompt_async(
|
||||
self,
|
||||
prompt: TextPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> TokensPrompt:
|
||||
tokenizer = self.get_async_tokenizer()
|
||||
prompt_token_ids = await tokenizer.encode(
|
||||
prompt["prompt"],
|
||||
**params.get_encode_kwargs(),
|
||||
)
|
||||
|
||||
return TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)
|
||||
|
||||
def _detokenize_prompt(self, prompt: TokensPrompt) -> TokensPrompt:
|
||||
tokenizer = self.get_tokenizer()
|
||||
prompt["prompt"] = tokenizer.decode(prompt["prompt_token_ids"])
|
||||
|
||||
return prompt
|
||||
|
||||
async def _detokenize_prompt_async(self, prompt: TokensPrompt) -> TokensPrompt:
|
||||
tokenizer = self.get_async_tokenizer()
|
||||
prompt["prompt"] = await tokenizer.decode(prompt["prompt_token_ids"])
|
||||
|
||||
return prompt
|
||||
|
||||
@overload
|
||||
def _tokenize_singleton_prompt(
|
||||
self,
|
||||
prompt: TextPrompt | TokensPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> TokensPrompt: ...
|
||||
|
||||
@overload
|
||||
def _tokenize_singleton_prompt( # type: ignore[misc]
|
||||
self,
|
||||
prompt: EmbedsPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> EmbedsPrompt: ...
|
||||
|
||||
def _tokenize_singleton_prompt(
|
||||
self,
|
||||
prompt: SingletonDictPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> SingletonTokPrompt:
|
||||
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
|
||||
prompt = params.apply_pre_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
|
||||
prompt = self._tokenize_prompt(prompt, params)
|
||||
|
||||
if params.needs_detokenization and "prompt" not in prompt:
|
||||
if "prompt_token_ids" not in prompt:
|
||||
raise RuntimeError("Cannot run detokenization on embeddings")
|
||||
|
||||
prompt = self._detokenize_prompt(prompt) # type: ignore[arg-type]
|
||||
|
||||
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
|
||||
|
||||
@overload
|
||||
async def _tokenize_singleton_prompt_async(
|
||||
self,
|
||||
prompt: TextPrompt | TokensPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> TokensPrompt: ...
|
||||
|
||||
@overload
|
||||
async def _tokenize_singleton_prompt_async( # type: ignore[misc]
|
||||
self,
|
||||
prompt: EmbedsPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> EmbedsPrompt: ...
|
||||
|
||||
async def _tokenize_singleton_prompt_async(
|
||||
self,
|
||||
prompt: SingletonDictPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> SingletonTokPrompt:
|
||||
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
|
||||
prompt = params.apply_pre_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
|
||||
prompt = await self._tokenize_prompt_async(prompt, params)
|
||||
|
||||
if params.needs_detokenization and "prompt" not in prompt:
|
||||
if "prompt_token_ids" not in prompt:
|
||||
raise RuntimeError("Cannot run detokenization on embeddings")
|
||||
|
||||
prompt = await self._detokenize_prompt_async(prompt) # type: ignore[arg-type]
|
||||
|
||||
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
|
||||
|
||||
def _tokenize_enc_dec_prompt(
|
||||
self,
|
||||
prompt: EncoderDecoderDictPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> EncoderDecoderTokPrompt:
|
||||
enc_prompt, dec_prompt = (
|
||||
self._tokenize_singleton_prompt(prompt["encoder_prompt"], params),
|
||||
(
|
||||
None
|
||||
if prompt["decoder_prompt"] is None
|
||||
else self._tokenize_singleton_prompt(prompt["decoder_prompt"], params)
|
||||
),
|
||||
)
|
||||
|
||||
return EncoderDecoderTokPrompt(
|
||||
encoder_prompt=enc_prompt,
|
||||
decoder_prompt=dec_prompt,
|
||||
)
|
||||
|
||||
async def _tokenize_enc_dec_prompt_async(
|
||||
self,
|
||||
prompt: EncoderDecoderDictPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> EncoderDecoderTokPrompt:
|
||||
enc_prompt, dec_prompt = await asyncio.gather(
|
||||
self._tokenize_singleton_prompt_async(prompt["encoder_prompt"], params),
|
||||
(
|
||||
asyncio.sleep(0)
|
||||
if prompt["decoder_prompt"] is None
|
||||
else self._tokenize_singleton_prompt_async(
|
||||
prompt["decoder_prompt"], params
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
return EncoderDecoderTokPrompt(
|
||||
encoder_prompt=enc_prompt,
|
||||
decoder_prompt=dec_prompt,
|
||||
)
|
||||
|
||||
def tokenize_prompt(
|
||||
self,
|
||||
prompt: DictPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> TokPrompt:
|
||||
if "encoder_prompt" in prompt:
|
||||
return self._tokenize_enc_dec_prompt(prompt, params) # type: ignore[arg-type]
|
||||
|
||||
return self._tokenize_singleton_prompt(prompt, params)
|
||||
|
||||
def tokenize_prompts(
|
||||
self,
|
||||
prompts: Sequence[DictPrompt],
|
||||
params: TokenizeParams,
|
||||
) -> list[TokPrompt]:
|
||||
return [self.tokenize_prompt(prompt, params) for prompt in prompts]
|
||||
|
||||
async def tokenize_prompt_async(
|
||||
self,
|
||||
prompt: DictPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> TokPrompt:
|
||||
if "encoder_prompt" in prompt:
|
||||
return await self._tokenize_enc_dec_prompt_async(prompt, params) # type: ignore[arg-type]
|
||||
|
||||
return await self._tokenize_singleton_prompt_async(prompt, params)
|
||||
|
||||
async def tokenize_prompts_async(
|
||||
self,
|
||||
prompts: Sequence[DictPrompt],
|
||||
params: TokenizeParams,
|
||||
) -> list[TokPrompt]:
|
||||
return await asyncio.gather(
|
||||
*(self.tokenize_prompt_async(prompt, params) for prompt in prompts)
|
||||
)
|
||||
|
||||
# Step 3: Add extra keys to the prompts
|
||||
def _apply_prompt_extras(
|
||||
self,
|
||||
prompts: Sequence[TokPrompt],
|
||||
prompt_extras: dict[str, Any] | None,
|
||||
):
|
||||
if not prompt_extras:
|
||||
return
|
||||
|
||||
for prompt in prompts:
|
||||
target_prompt = extract_target_prompt(self.model_config, prompt)
|
||||
target_prompt.update(prompt_extras) # type: ignore[arg-type]
|
||||
|
||||
# Step 4: Convert to engine inputs
|
||||
def _validate_mm_uuids(
|
||||
self,
|
||||
mm_data: "MultiModalDataDict",
|
||||
mm_data_items: "MultiModalDataItems",
|
||||
mm_uuid_items: "MultiModalUUIDItems",
|
||||
) -> None:
|
||||
# NOTE: Keys corresponding to `None` in `mm_data` don't appear in
|
||||
# `mm_data_items`
|
||||
modalities = mm_data.keys() | mm_uuid_items.keys()
|
||||
|
||||
for modality in modalities:
|
||||
data_items = mm_data_items.get(modality)
|
||||
uuid_items = mm_uuid_items.get(modality)
|
||||
|
||||
if data_items is None:
|
||||
if uuid_items is None:
|
||||
raise ValueError(
|
||||
f"multi_modal_data[{modality!r}] is empty but "
|
||||
f"multi_modal_uuids[{modality!r}] is missing."
|
||||
)
|
||||
|
||||
elif uuid_items is not None:
|
||||
if len(data_items) != len(uuid_items):
|
||||
raise ValueError(
|
||||
f"If given, multi_modal_uuids[{modality!r}] must have "
|
||||
f"same length as multi_modal_data[{modality!r}], but "
|
||||
f"got {len(uuid_items)} vs {len(data_items)}."
|
||||
)
|
||||
|
||||
for i, item in enumerate(data_items):
|
||||
if item is None and uuid_items[i] is None:
|
||||
raise ValueError(
|
||||
f"multi_modal_data[{modality!r}][{i}] is empty but "
|
||||
f"multi_modal_uuids[{modality!r}][{i}] is missing."
|
||||
)
|
||||
|
||||
def _process_mm_uuids(
|
||||
self,
|
||||
mm_data: "MultiModalDataDict",
|
||||
mm_data_items: "MultiModalDataItems",
|
||||
mm_uuid_items: "MultiModalUUIDItems",
|
||||
mm_req_id: str,
|
||||
):
|
||||
model_config = self.model_config
|
||||
|
||||
# NOTE: When users explicitly turn off BOTH prefix caching and input
|
||||
# processing caching, no multimodal features or embeddings will be
|
||||
# reused across requests, therefore identifying multimodal data items
|
||||
# by their content is no longer necessary, and we create uuids with
|
||||
# `<mm_req_id>-<modality>-<index>`, overriding even user-provided ones.
|
||||
if (
|
||||
model_config.multimodal_config
|
||||
and model_config.multimodal_config.mm_processor_cache_gb == 0
|
||||
and not self.config.cache_config.enable_prefix_caching
|
||||
):
|
||||
mm_uuid_items = {
|
||||
modality: [f"{mm_req_id}-{modality}-{i}" for i in range(data_count)]
|
||||
for modality, data_count in mm_data_items.get_all_counts().items()
|
||||
}
|
||||
|
||||
self._validate_mm_uuids(mm_data, mm_data_items, mm_uuid_items)
|
||||
|
||||
return mm_uuid_items
|
||||
|
||||
# TODO: Remove str and tokenization_kwargs after deprecating InputPreprocessor
|
||||
def _process_multimodal(
|
||||
self,
|
||||
prompt: list[int] | str,
|
||||
mm_data: "MultiModalDataDict",
|
||||
mm_uuids: "MultiModalUUIDDict | None",
|
||||
mm_processor_kwargs: Mapping[str, object] | None,
|
||||
tokenization_kwargs: dict[str, Any] | None,
|
||||
) -> "MultiModalInputs":
|
||||
from vllm.multimodal.parse import parse_mm_uuids
|
||||
from vllm.multimodal.processing import ProcessorInputs as MMProcessorInputs
|
||||
|
||||
mm_req_id = f"renderer-mm-{self._mm_req_counter.inc(1)}"
|
||||
|
||||
mm_processor = self.get_mm_processor()
|
||||
|
||||
mm_data_items = mm_processor.info.parse_mm_data(mm_data)
|
||||
mm_uuid_items = parse_mm_uuids(mm_uuids)
|
||||
|
||||
mm_uuid_items = self._process_mm_uuids(
|
||||
mm_data, mm_data_items, mm_uuid_items, mm_req_id
|
||||
)
|
||||
|
||||
mm_processor_inputs = MMProcessorInputs(
|
||||
prompt,
|
||||
mm_data_items,
|
||||
mm_uuid_items,
|
||||
hf_processor_mm_kwargs=mm_processor_kwargs or {},
|
||||
tokenization_kwargs=tokenization_kwargs or {},
|
||||
)
|
||||
mm_timing_ctx = self._mm_timing_registry.get(mm_req_id)
|
||||
|
||||
with set_default_torch_num_threads():
|
||||
mm_inputs = mm_processor.apply(mm_processor_inputs, mm_timing_ctx)
|
||||
|
||||
self.update_mm_cache_stats()
|
||||
|
||||
return mm_inputs
|
||||
|
||||
def _process_tokens(
|
||||
self,
|
||||
prompt: TokensPrompt,
|
||||
) -> "TokenInputs | MultiModalInputs":
|
||||
prompt_token_ids = prompt["prompt_token_ids"]
|
||||
|
||||
inputs: TokenInputs | MultiModalInputs
|
||||
if multi_modal_data := prompt.get("multi_modal_data"):
|
||||
inputs = self._process_multimodal(
|
||||
prompt_token_ids,
|
||||
multi_modal_data,
|
||||
mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
|
||||
tokenization_kwargs=None, # Tokenization already done in Step 2
|
||||
mm_uuids=prompt.get("multi_modal_uuids"),
|
||||
)
|
||||
else:
|
||||
inputs = token_inputs(prompt_token_ids)
|
||||
|
||||
if prompt_text := prompt.get("prompt"):
|
||||
inputs["prompt"] = prompt_text
|
||||
if cache_salt := prompt.get("cache_salt"):
|
||||
inputs["cache_salt"] = cache_salt
|
||||
|
||||
return inputs
|
||||
|
||||
def _process_embeds(
|
||||
self,
|
||||
prompt: EmbedsPrompt,
|
||||
) -> EmbedsInputs:
|
||||
if not self.model_config.enable_prompt_embeds:
|
||||
raise ValueError(
|
||||
"You must set `--enable-prompt-embeds` to input `prompt_embeds`."
|
||||
)
|
||||
|
||||
prompt_embeds = prompt["prompt_embeds"]
|
||||
|
||||
# prompt_embeds must be (seq_len, hidden_size), but if the user
|
||||
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
|
||||
# we can unambiguously process the intent by squeezing the batch
|
||||
# dimension.
|
||||
if prompt_embeds.ndim == 3:
|
||||
prompt_embeds = prompt_embeds.squeeze(dim=0)
|
||||
|
||||
if prompt_embeds.ndim != 2:
|
||||
raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")
|
||||
|
||||
# Tensors must be on CPU for serialization between processes
|
||||
# in the MsgpackEncoder. Casting to CPU here ensures that there is no
|
||||
# hidden device transfer in the critical path of generation.
|
||||
prompt_embeds = prompt_embeds.cpu()
|
||||
|
||||
return embeds_inputs(
|
||||
prompt_embeds=prompt_embeds,
|
||||
cache_salt=prompt.get("cache_salt"),
|
||||
)
|
||||
|
||||
def _process_singleton(
|
||||
self,
|
||||
prompt: SingletonTokPrompt,
|
||||
) -> SingletonInputs:
|
||||
if "prompt_embeds" in prompt:
|
||||
return self._process_embeds(prompt) # type: ignore[arg-type]
|
||||
|
||||
return self._process_tokens(prompt) # type: ignore[arg-type]
|
||||
|
||||
def _process_enc_dec(
|
||||
self,
|
||||
prompt: EncoderDecoderTokPrompt,
|
||||
) -> EncoderDecoderInputs:
|
||||
enc_prompt = prompt["encoder_prompt"]
|
||||
dec_prompt = prompt["decoder_prompt"]
|
||||
|
||||
return build_enc_dec_inputs(
|
||||
encoder_inputs=self._process_singleton(enc_prompt),
|
||||
decoder_inputs=(
|
||||
None if dec_prompt is None else self._process_singleton(dec_prompt)
|
||||
),
|
||||
decoder_start_token_id=self.get_dec_start_token_id(),
|
||||
)
|
||||
|
||||
def process_for_engine(
|
||||
self, prompt: TokPrompt, arrival_time: float
|
||||
) -> ProcessorInputs:
|
||||
engine_prompt: ProcessorInputs
|
||||
if "encoder_prompt" in prompt:
|
||||
engine_prompt = self._process_enc_dec(prompt) # type: ignore[arg-type]
|
||||
else:
|
||||
engine_prompt = self._process_singleton(prompt)
|
||||
|
||||
engine_prompt["arrival_time"] = arrival_time
|
||||
|
||||
return engine_prompt
|
||||
|
||||
# Top-level methods
|
||||
def render_cmpl(
|
||||
self,
|
||||
prompts: Sequence[DictPrompt | bytes],
|
||||
tok_params: TokenizeParams | None = None,
|
||||
*,
|
||||
prompt_extras: dict[str, Any] | None = None,
|
||||
):
|
||||
arrival_time = time.time()
|
||||
|
||||
if tok_params is None:
|
||||
tok_params = self.default_cmpl_tok_params
|
||||
|
||||
dict_prompts = self.render_prompts(prompts)
|
||||
tok_prompts = self.tokenize_prompts(dict_prompts, tok_params)
|
||||
|
||||
self._apply_prompt_extras(tok_prompts, prompt_extras)
|
||||
|
||||
return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]
|
||||
|
||||
async def render_cmpl_async(
|
||||
self,
|
||||
prompts: Sequence[DictPrompt | bytes],
|
||||
tok_params: TokenizeParams | None = None,
|
||||
*,
|
||||
prompt_extras: dict[str, Any] | None = None,
|
||||
):
|
||||
arrival_time = time.time()
|
||||
|
||||
if tok_params is None:
|
||||
tok_params = self.default_cmpl_tok_params
|
||||
|
||||
dict_prompts = await self.render_prompts_async(prompts)
|
||||
tok_prompts = await self.tokenize_prompts_async(dict_prompts, tok_params)
|
||||
|
||||
self._apply_prompt_extras(tok_prompts, prompt_extras)
|
||||
|
||||
return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]
|
||||
|
||||
def render_chat(
|
||||
self,
|
||||
conversations: Sequence[list["ChatCompletionMessageParam"]],
|
||||
chat_params: ChatParams,
|
||||
tok_params: TokenizeParams | None = None,
|
||||
*,
|
||||
prompt_extras: dict[str, Any] | None = None,
|
||||
):
|
||||
arrival_time = time.time()
|
||||
|
||||
if tok_params is None:
|
||||
tok_params = self.default_chat_tok_params
|
||||
|
||||
rendered = [
|
||||
self.render_messages(conversation, chat_params)
|
||||
for conversation in conversations
|
||||
]
|
||||
|
||||
out_conversations = list[list["ConversationMessage"]]()
|
||||
dict_prompts = list[DictPrompt]()
|
||||
for conv, prompt in rendered:
|
||||
out_conversations.append(conv)
|
||||
dict_prompts.append(prompt)
|
||||
|
||||
tok_prompts = self.tokenize_prompts(dict_prompts, tok_params)
|
||||
|
||||
self._apply_prompt_extras(tok_prompts, prompt_extras)
|
||||
|
||||
eng_prompts = [
|
||||
self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
|
||||
]
|
||||
|
||||
return out_conversations, eng_prompts
|
||||
|
||||
async def render_chat_async(
|
||||
self,
|
||||
conversations: Sequence[list["ChatCompletionMessageParam"]],
|
||||
chat_params: ChatParams,
|
||||
tok_params: TokenizeParams | None = None,
|
||||
*,
|
||||
prompt_extras: dict[str, Any] | None = None,
|
||||
):
|
||||
arrival_time = time.time()
|
||||
|
||||
if tok_params is None:
|
||||
tok_params = self.default_chat_tok_params
|
||||
|
||||
rendered = [
|
||||
self.render_messages_async(conversation, chat_params)
|
||||
for conversation in conversations
|
||||
]
|
||||
|
||||
out_conversations = list[list["ConversationMessage"]]()
|
||||
dict_prompts = list[DictPrompt]()
|
||||
for conv, prompt in await asyncio.gather(*rendered):
|
||||
out_conversations.append(conv)
|
||||
dict_prompts.append(prompt)
|
||||
|
||||
tok_prompts = await self.tokenize_prompts_async(dict_prompts, tok_params)
|
||||
|
||||
self._apply_prompt_extras(tok_prompts, prompt_extras)
|
||||
|
||||
eng_prompts = [
|
||||
self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
|
||||
]
|
||||
|
||||
return out_conversations, eng_prompts
|
||||
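For reference, a sketch (not part of this diff) of the four steps that `render_cmpl` composes internally, using only methods defined in this file; `renderer` is assumed to be an already-constructed subclass instance:

import time

tok_params = renderer.default_cmpl_tok_params  # or a custom TokenizeParams

# Step 1: raw inputs -> dict prompts (bytes payloads become EmbedsPrompt)
dict_prompts = renderer.render_prompts([{"prompt": "Hello"}])
# Step 2: tokenize prompts that carry neither token ids nor embeddings
tok_prompts = renderer.tokenize_prompts(dict_prompts, tok_params)
# Step 3: attach extra keys (none in this sketch)
renderer._apply_prompt_extras(tok_prompts, prompt_extras=None)
# Step 4: convert to engine inputs
engine_inputs = [
    renderer.process_for_engine(p, arrival_time=time.time()) for p in tok_prompts
]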
92
vllm/renderers/deepseek_v32.py
Normal file
@@ -0,0 +1,92 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ConversationMessage,
    parse_chat_messages,
    parse_chat_messages_async,
)
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer

from .base import BaseRenderer
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams

logger = init_logger(__name__)


class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
    @classmethod
    def from_config(  # type: ignore[override]
        cls,
        config: VllmConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> "DeepseekV32Renderer":
        model_config = config.model_config
        if model_config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = cached_get_tokenizer(
                tokenizer_cls=DeepseekV32Tokenizer,
                **tokenizer_kwargs,
            )

        return cls(config, tokenizer)

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            self.model_config,
            content_format="string",
        )

        prompt_raw = tokenizer.apply_chat_template(
            conversation=conversation,
            messages=messages,
            **params.get_apply_chat_template_kwargs(),
        )

        prompt = parse_dec_only_prompt(prompt_raw)
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return conversation, prompt

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            self.model_config,
            content_format="string",
        )

        prompt_raw = tokenizer.apply_chat_template(
            conversation=conversation,
            messages=messages,
            **params.get_apply_chat_template_kwargs(),
        )

        prompt = parse_dec_only_prompt(prompt_raw)
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return conversation, prompt
44
vllm/renderers/embed_utils.py
Normal file
@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from io import BytesIO
from typing import TYPE_CHECKING

import pybase64
import torch

from vllm.exceptions import VLLMValidationError

if TYPE_CHECKING:
    from vllm.config import ModelConfig


def safe_load_prompt_embeds(
    model_config: "ModelConfig",
    embed: bytes,
) -> torch.Tensor:
    if not model_config.enable_prompt_embeds:
        raise VLLMValidationError(
            "You must set `--enable-prompt-embeds` to input `prompt_embeds`.",
            parameter="prompt_embeds",
        )

    # Enable sparse tensor integrity checks to prevent out-of-bounds
    # writes from maliciously crafted tensors
    with torch.sparse.check_sparse_tensor_invariants():
        tensor = torch.load(
            BytesIO(pybase64.b64decode(embed, validate=True)),
            weights_only=True,
            map_location=torch.device("cpu"),
        )
    assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
        torch.float32,
        torch.bfloat16,
        torch.float16,
    )
    tensor = tensor.to_dense()

    if tensor.dim() > 2:
        tensor = tensor.squeeze(0)
        assert tensor.dim() == 2

    return tensor
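For context (not part of this diff): the `embed` payload above is a base64-encoded `torch.save` blob. A client-side sketch of producing one, with a hypothetical (seq_len, hidden_size) tensor:

from io import BytesIO

import pybase64
import torch

prompt_embeds = torch.randn(8, 4096, dtype=torch.float16)  # hypothetical shape

buf = BytesIO()
torch.save(prompt_embeds, buf)
encoded: bytes = pybase64.b64encode(buf.getvalue())
# `encoded` is what ends up in the `embed: bytes` argument of safe_load_prompt_embeds.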
92
vllm/renderers/grok2.py
Normal file
@@ -0,0 +1,92 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ConversationMessage,
    parse_chat_messages,
    parse_chat_messages_async,
)
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.grok2 import Grok2Tokenizer

from .base import BaseRenderer
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams

logger = init_logger(__name__)


class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
    @classmethod
    def from_config(  # type: ignore[override]
        cls,
        config: VllmConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> "Grok2Renderer":
        model_config = config.model_config
        if model_config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = cached_get_tokenizer(
                tokenizer_cls=Grok2Tokenizer,
                **tokenizer_kwargs,
            )

        return cls(config, tokenizer)

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            self.model_config,
            content_format="string",
        )

        prompt_raw = tokenizer.apply_chat_template(
            conversation=conversation,
            messages=messages,
            **params.get_apply_chat_template_kwargs(),
        )

        prompt = parse_dec_only_prompt(prompt_raw)
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return conversation, prompt

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            self.model_config,
            content_format="string",
        )

        prompt_raw = tokenizer.apply_chat_template(
            conversation=conversation,
            messages=messages,
            **params.get_apply_chat_template_kwargs(),
        )

        prompt = parse_dec_only_prompt(prompt_raw)
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return conversation, prompt
724
vllm/renderers/hf.py
Normal file
@@ -0,0 +1,724 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import inspect
|
||||
import itertools
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Set
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import jinja2
|
||||
import jinja2.ext
|
||||
import jinja2.meta
|
||||
import jinja2.nodes
|
||||
import jinja2.parser
|
||||
import jinja2.sandbox
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatTemplateContentFormat,
|
||||
ChatTemplateContentFormatOption,
|
||||
ChatTemplateResolutionError,
|
||||
ConversationMessage,
|
||||
load_chat_template,
|
||||
parse_chat_messages,
|
||||
parse_chat_messages_async,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer
|
||||
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.utils.func_utils import supports_kw
|
||||
|
||||
from .base import BaseRenderer
|
||||
from .inputs import DictPrompt
|
||||
from .inputs.preprocess import parse_dec_only_prompt
|
||||
from .params import ChatParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
|
||||
else:
|
||||
MultiModalDataDict = dict[str, Any]
|
||||
MultiModalUUIDDict = dict[str, Any]
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], str | None]()
|
||||
"""
|
||||
Used in `_try_get_processor_chat_template` to avoid calling
|
||||
`cached_get_processor` again if the processor fails to be loaded.
|
||||
|
||||
This is needed because `lru_cache` does not cache when an exception happens.
|
||||
"""
|
||||
|
||||
|
||||
def _try_get_processor_chat_template(
|
||||
tokenizer: HfTokenizer,
|
||||
*,
|
||||
trust_remote_code: bool,
|
||||
) -> str | None:
|
||||
cache_key = (tokenizer.name_or_path, trust_remote_code)
|
||||
if cache_key in _PROCESSOR_CHAT_TEMPLATES:
|
||||
return _PROCESSOR_CHAT_TEMPLATES[cache_key]
|
||||
|
||||
from transformers import (
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
ProcessorMixin,
|
||||
)
|
||||
|
||||
try:
|
||||
processor = cached_get_processor(
|
||||
tokenizer.name_or_path,
|
||||
processor_cls=(
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
ProcessorMixin,
|
||||
),
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
if (
|
||||
isinstance(processor, ProcessorMixin)
|
||||
and hasattr(processor, "chat_template")
|
||||
and (chat_template := processor.chat_template) is not None
|
||||
):
|
||||
_PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template
|
||||
return chat_template
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Failed to load AutoProcessor chat template for %s",
|
||||
tokenizer.name_or_path,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
_PROCESSOR_CHAT_TEMPLATES[cache_key] = None
|
||||
return None
|
||||
|
||||
|
||||
def resolve_chat_template(
|
||||
tokenizer: HfTokenizer,
|
||||
chat_template: str | None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
*,
|
||||
model_config: "ModelConfig",
|
||||
) -> str | None:
|
||||
# 1st priority: The given chat template
|
||||
if chat_template is not None:
|
||||
return chat_template
|
||||
|
||||
# 2nd priority: AutoProcessor chat template, unless tool calling is enabled
|
||||
if tools is None:
|
||||
chat_template = _try_get_processor_chat_template(
|
||||
tokenizer,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
)
|
||||
if chat_template is not None:
|
||||
return chat_template
|
||||
|
||||
# 3rd priority: AutoTokenizer chat template
|
||||
try:
|
||||
return tokenizer.get_chat_template(chat_template, tools=tools)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Failed to load AutoTokenizer chat template for %s",
|
||||
tokenizer.name_or_path,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# 4th priority: Predefined fallbacks
|
||||
path = get_chat_template_fallback_path(
|
||||
model_type=model_config.hf_config.model_type,
|
||||
tokenizer_name_or_path=tokenizer.name_or_path,
|
||||
)
|
||||
if path is not None:
|
||||
logger.info_once(
|
||||
"Loading chat template fallback for %s as there isn't one "
|
||||
"defined on HF Hub.",
|
||||
tokenizer.name_or_path,
|
||||
)
|
||||
chat_template = load_chat_template(path)
|
||||
else:
|
||||
logger.debug_once(
|
||||
"There is no chat template fallback for %s", tokenizer.name_or_path
|
||||
)
|
||||
|
||||
return chat_template
|
||||
|
||||
|
||||
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
|
||||
if isinstance(node, jinja2.nodes.Name):
|
||||
return node.ctx == "load" and node.name == varname
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
|
||||
if isinstance(node, jinja2.nodes.Getitem):
|
||||
return (
|
||||
_is_var_access(node.node, varname)
|
||||
and isinstance(node.arg, jinja2.nodes.Const)
|
||||
and node.arg.value == key
|
||||
)
|
||||
|
||||
if isinstance(node, jinja2.nodes.Getattr):
|
||||
return _is_var_access(node.node, varname) and node.attr == key
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _is_var_or_elems_access(
|
||||
node: jinja2.nodes.Node,
|
||||
varname: str,
|
||||
key: str | None = None,
|
||||
) -> bool:
|
||||
if isinstance(node, jinja2.nodes.Filter):
|
||||
return node.node is not None and _is_var_or_elems_access(
|
||||
node.node, varname, key
|
||||
)
|
||||
if isinstance(node, jinja2.nodes.Test):
|
||||
return _is_var_or_elems_access(node.node, varname, key)
|
||||
|
||||
if isinstance(node, jinja2.nodes.Getitem) and isinstance(
|
||||
node.arg, jinja2.nodes.Slice
|
||||
):
|
||||
return _is_var_or_elems_access(node.node, varname, key)
|
||||
|
||||
return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)
|
||||
|
||||
|
||||
def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
|
||||
# Global variable that is implicitly defined at the root
|
||||
yield root, varname
|
||||
|
||||
# Iterative BFS
|
||||
related_varnames = deque([varname])
|
||||
while related_varnames:
|
||||
related_varname = related_varnames.popleft()
|
||||
|
||||
for assign_ast in root.find_all(jinja2.nodes.Assign):
|
||||
lhs = assign_ast.target
|
||||
rhs = assign_ast.node
|
||||
|
||||
if _is_var_or_elems_access(rhs, related_varname):
|
||||
assert isinstance(lhs, jinja2.nodes.Name)
|
||||
yield assign_ast, lhs.name
|
||||
|
||||
# Avoid infinite looping for self-assignment
|
||||
if lhs.name != related_varname:
|
||||
related_varnames.append(lhs.name)
|
||||
|
||||
|
||||
# NOTE: The proper way to handle this is to build a CFG so that we can handle
|
||||
# the scope in which each variable is defined, but that is too complicated
|
||||
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
|
||||
messages_varnames = [
|
||||
varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
|
||||
]
|
||||
|
||||
# Search for {%- for message in messages -%} loops
|
||||
for loop_ast in root.find_all(jinja2.nodes.For):
|
||||
loop_iter = loop_ast.iter
|
||||
loop_target = loop_ast.target
|
||||
|
||||
for varname in messages_varnames:
|
||||
if _is_var_or_elems_access(loop_iter, varname):
|
||||
assert isinstance(loop_target, jinja2.nodes.Name)
|
||||
yield loop_ast, loop_target.name
|
||||
break
|
||||
|
||||
|
||||
def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
|
||||
message_varnames = [
|
||||
varname for _, varname in _iter_nodes_assign_messages_item(root)
|
||||
]
|
||||
|
||||
# Search for {%- for content in message['content'] -%} loops
|
||||
for loop_ast in root.find_all(jinja2.nodes.For):
|
||||
loop_iter = loop_ast.iter
|
||||
loop_target = loop_ast.target
|
||||
|
||||
for varname in message_varnames:
|
||||
if _is_var_or_elems_access(loop_iter, varname, "content"):
|
||||
assert isinstance(loop_target, jinja2.nodes.Name)
|
||||
yield loop_ast, loop_target.name
|
||||
break
|
||||
|
||||
|
||||
def _try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None:
|
||||
import transformers.utils.chat_template_utils as hf_chat_utils
|
||||
|
||||
try:
|
||||
jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
|
||||
return jinja_compiled.environment.parse(chat_template)
|
||||
except Exception:
|
||||
logger.exception("Error when compiling Jinja template")
|
||||
return None
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def _detect_content_format(
|
||||
chat_template: str,
|
||||
*,
|
||||
default: ChatTemplateContentFormat,
|
||||
) -> ChatTemplateContentFormat:
|
||||
jinja_ast = _try_extract_ast(chat_template)
|
||||
if jinja_ast is None:
|
||||
return default
|
||||
|
||||
try:
|
||||
next(_iter_nodes_assign_content_item(jinja_ast))
|
||||
except StopIteration:
|
||||
return "string"
|
||||
except Exception:
|
||||
logger.exception("Error when parsing AST of Jinja template")
|
||||
return default
|
||||
else:
|
||||
return "openai"
|
||||
|
||||
|
||||
def _resolve_chat_template_content_format(
|
||||
chat_template: str | None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
tokenizer: HfTokenizer,
|
||||
*,
|
||||
model_config: "ModelConfig",
|
||||
) -> ChatTemplateContentFormat:
|
||||
resolved_chat_template = resolve_chat_template(
|
||||
tokenizer,
|
||||
chat_template=chat_template,
|
||||
tools=tools,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
jinja_text = (
|
||||
resolved_chat_template
|
||||
if isinstance(resolved_chat_template, str)
|
||||
else load_chat_template(chat_template, is_literal=True)
|
||||
)
|
||||
|
||||
detected_format = (
|
||||
"string"
|
||||
if jinja_text is None
|
||||
else _detect_content_format(jinja_text, default="string")
|
||||
)
|
||||
|
||||
return detected_format
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _log_chat_template_content_format(
|
||||
chat_template: str | None, # For caching purposes
|
||||
given_format: ChatTemplateContentFormatOption,
|
||||
detected_format: ChatTemplateContentFormatOption,
|
||||
):
|
||||
logger.info(
|
||||
"Detected the chat template content format to be '%s'. "
|
||||
"You can set `--chat-template-content-format` to override this.",
|
||||
detected_format,
|
||||
)
|
||||
|
||||
if given_format != "auto" and given_format != detected_format:
|
||||
logger.warning(
|
||||
"You specified `--chat-template-content-format %s` "
|
||||
"which is different from the detected format '%s'. "
|
||||
"If our automatic detection is incorrect, please consider "
|
||||
"opening a GitHub issue so that we can improve it: "
|
||||
"https://github.com/vllm-project/vllm/issues/new/choose",
|
||||
given_format,
|
||||
detected_format,
|
||||
)
|
||||
|
||||
|
||||
def resolve_chat_template_content_format(
|
||||
chat_template: str | None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
given_format: ChatTemplateContentFormatOption,
|
||||
tokenizer: HfTokenizer,
|
||||
*,
|
||||
model_config: "ModelConfig",
|
||||
) -> ChatTemplateContentFormat:
|
||||
if given_format != "auto":
|
||||
return given_format
|
||||
|
||||
detected_format = _resolve_chat_template_content_format(
|
||||
chat_template,
|
||||
tools,
|
||||
tokenizer,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
_log_chat_template_content_format(
|
||||
chat_template,
|
||||
given_format=given_format,
|
||||
detected_format=detected_format,
|
||||
)
|
||||
|
||||
return detected_format
|
||||
|
||||
|
||||
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
|
||||
# only preserve the parse function used to resolve chat template kwargs
|
||||
class AssistantTracker(jinja2.ext.Extension):
|
||||
tags = {"generation"}
|
||||
|
||||
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.Node:
|
||||
lineno = next(parser.stream).lineno
|
||||
body = parser.parse_statements(("name:endgeneration",), drop_needle=True)
|
||||
call = self.call_method("_generation_support")
|
||||
call_block = jinja2.nodes.CallBlock(call, [], [], body)
|
||||
return call_block.set_lineno(lineno)
|
||||
|
||||
|
||||
def _resolve_chat_template_kwargs(chat_template: str) -> Set[str]:
|
||||
env = jinja2.sandbox.ImmutableSandboxedEnvironment(
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
extensions=[AssistantTracker, jinja2.ext.loopcontrols],
|
||||
)
|
||||
parsed_content = env.parse(chat_template)
|
||||
template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
|
||||
return template_vars
|
||||
|
||||
|
||||
_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _get_hf_base_chat_template_params() -> frozenset[str]:
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
# Get standard parameters from HuggingFace's base tokenizer class.
|
||||
# This dynamically extracts parameters from PreTrainedTokenizer's
|
||||
# apply_chat_template method, ensuring compatibility with tokenizers
|
||||
# that use **kwargs to receive standard parameters.
|
||||
|
||||
# Read signature from HF's base class - the single source of truth
|
||||
base_sig = inspect.signature(PreTrainedTokenizer.apply_chat_template)
|
||||
|
||||
# Exclude VAR_KEYWORD (**kwargs) and VAR_POSITIONAL (*args) placeholders
|
||||
return frozenset(
|
||||
p.name
|
||||
for p in base_sig.parameters.values()
|
||||
if p.kind
|
||||
not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL)
|
||||
)
|
||||
|
||||
|
||||
def resolve_chat_template_kwargs(
|
||||
tokenizer: HfTokenizer,
|
||||
chat_template: str,
|
||||
chat_template_kwargs: dict[str, Any],
|
||||
raise_on_unexpected: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
# We exclude chat_template from kwargs here, because
|
||||
# chat template has been already resolved at this stage
|
||||
unexpected_vars = {"chat_template", "tokenize"}
|
||||
if raise_on_unexpected and (
|
||||
unexpected_in_kwargs := unexpected_vars & chat_template_kwargs.keys()
|
||||
):
|
||||
raise ValueError(
|
||||
"Found unexpected chat template kwargs from request: "
|
||||
f"{unexpected_in_kwargs}"
|
||||
)
|
||||
|
||||
fn_kw = {
|
||||
k
|
||||
for k in chat_template_kwargs
|
||||
if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
|
||||
}
|
||||
template_vars = _cached_resolve_chat_template_kwargs(chat_template)
|
||||
|
||||
# Allow standard HF parameters even if tokenizer uses **kwargs to receive them
|
||||
hf_base_params = _get_hf_base_chat_template_params()
|
||||
|
||||
accept_vars = (fn_kw | template_vars | hf_base_params) - unexpected_vars
|
||||
return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars}
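A sketch (not part of this diff) of the template-variable discovery this filtering relies on; the template string and the `enable_thinking` flag are hypothetical:

import jinja2.meta
import jinja2.sandbox

env = jinja2.sandbox.ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
template = (
    "{% for m in messages %}{{ m['content'] }}{% endfor %}"
    "{% if enable_thinking %}<think>{% endif %}"
)
template_vars = jinja2.meta.find_undeclared_variables(env.parse(template))
# template_vars == {"messages", "enable_thinking"}, so a request kwarg named
# "enable_thinking" would be forwarded to apply_chat_template, while unrelated
# keys would be dropped.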
|
||||
|
||||
|
||||
def safe_apply_chat_template(
|
||||
model_config: "ModelConfig",
|
||||
tokenizer: HfTokenizer,
|
||||
conversation: list[ConversationMessage],
|
||||
*,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
chat_template: str | None = None,
|
||||
tokenize: bool = True,
|
||||
**kwargs,
|
||||
) -> str | list[int]:
|
||||
chat_template = resolve_chat_template(
|
||||
tokenizer,
|
||||
chat_template=chat_template,
|
||||
tools=tools,
|
||||
model_config=model_config,
|
||||
)
|
||||
if chat_template is None:
|
||||
raise ChatTemplateResolutionError(
|
||||
"As of transformers v4.44, default chat template is no longer "
|
||||
"allowed, so you must provide a chat template if the tokenizer "
|
||||
"does not define one."
|
||||
)
|
||||
|
||||
resolved_kwargs = resolve_chat_template_kwargs(
|
||||
tokenizer=tokenizer,
|
||||
chat_template=chat_template,
|
||||
chat_template_kwargs=kwargs,
|
||||
)
|
||||
|
||||
try:
|
||||
return tokenizer.apply_chat_template(
|
||||
conversation=conversation, # type: ignore[arg-type]
|
||||
tools=tools, # type: ignore[arg-type]
|
||||
chat_template=chat_template,
|
||||
tokenize=tokenize,
|
||||
**resolved_kwargs,
|
||||
)
|
||||
# External library exceptions can sometimes occur despite the framework's
|
||||
# internal exception management capabilities.
|
||||
except Exception as e:
|
||||
# Log and report any library-related exceptions for further
|
||||
# investigation.
|
||||
logger.exception(
|
||||
"An error occurred in `transformers` while applying chat template"
|
||||
)
|
||||
raise ValueError(str(e)) from e
|
||||
|
||||
|
||||
def rebuild_mm_uuids_from_mm_data(
|
||||
mm_uuids: "MultiModalUUIDDict",
|
||||
mm_data: "MultiModalDataDict",
|
||||
) -> "MultiModalUUIDDict":
|
||||
"""Rebuild mm_uuids after vision_chunk processing.
|
||||
|
||||
When videos are split into chunks, the original UUIDs need to be updated
|
||||
to reflect the new UUIDs generated for each chunk.
|
||||
|
||||
Args:
|
||||
mm_uuids: Original UUIDs dictionary
|
||||
mm_data: Processed multimodal data with vision_chunk items
|
||||
|
||||
Returns:
|
||||
Updated UUIDs dictionary with chunk UUIDs
|
||||
"""
|
||||
vision_chunks = mm_data.get("vision_chunk")
|
||||
if vision_chunks is None:
|
||||
return mm_uuids
|
||||
|
||||
assert all(isinstance(item, dict) for item in vision_chunks), (
|
||||
"Expected all vision_chunk items to be dicts"
|
||||
)
|
||||
vision_chunks = cast(list[dict[str, Any]], vision_chunks)
|
||||
vision_chunk_uuids = [
|
||||
uuid_val for item in vision_chunks if (uuid_val := item.get("uuid")) is not None
|
||||
]
|
||||
|
||||
if vision_chunk_uuids:
|
||||
mm_uuids = dict(mm_uuids)
|
||||
mm_uuids["vision_chunk"] = vision_chunk_uuids
|
||||
|
||||
return mm_uuids
|
||||
|
||||
|
||||
def build_video_prompts_from_mm_data(
|
||||
mm_data: "MultiModalDataDict",
|
||||
) -> list[str]:
|
||||
"""Build video prompts from vision_chunk data.
|
||||
|
||||
Collects prompts from video chunks and groups them by video_idx.
|
||||
|
||||
Args:
|
||||
mm_data: Processed multimodal data with vision_chunk items
|
||||
|
||||
Returns:
|
||||
List of video prompts, one per video.
|
||||
"""
|
||||
vision_chunks = mm_data.get("vision_chunk")
|
||||
if vision_chunks is None:
|
||||
return []
|
||||
|
||||
# Group chunks by video_idx
|
||||
video_prompts_dict: dict[int, list[str]] = defaultdict(list)
|
||||
|
||||
for item in vision_chunks:
|
||||
# vision_chunk items are always dicts (VisionChunkImage/VisionChunkVideo)
|
||||
assert isinstance(item, dict)
|
||||
if item.get("type") == "video_chunk":
|
||||
video_idx = item.get("video_idx", 0)
|
||||
prompt = item.get("prompt", "")
|
||||
video_prompts_dict[video_idx].append(prompt)
|
||||
|
||||
# Build prompts in video order
|
||||
video_prompts = [
|
||||
"".join(video_prompts_dict[video_idx])
|
||||
for video_idx in sorted(video_prompts_dict.keys())
|
||||
]
|
||||
|
||||
return video_prompts
|
||||
|
||||
|
||||
def replace_vision_chunk_video_placeholder(
|
||||
prompt_raw: str | list[int],
|
||||
mm_data: "MultiModalDataDict",
|
||||
video_placeholder: str | None,
|
||||
) -> str | list[int]:
|
||||
# get video placeholder, replace it with runtime video-chunk prompts
|
||||
if video_placeholder and isinstance(prompt_raw, str):
|
||||
video_prompts = build_video_prompts_from_mm_data(mm_data)
|
||||
|
||||
# replace in order
|
||||
prompt_raw_parts = prompt_raw.split(video_placeholder)
|
||||
if len(prompt_raw_parts) == len(video_prompts) + 1:
|
||||
prompt_raw = "".join(
|
||||
itertools.chain.from_iterable(zip(prompt_raw_parts, video_prompts))
|
||||
)
|
||||
prompt_raw += prompt_raw_parts[-1]
|
||||
else:
|
||||
logger.warning(
|
||||
"Number of video placeholders (%d) does not match "
|
||||
"number of videos (%d) in the request.",
|
||||
len(prompt_raw_parts) - 1,
|
||||
len(video_prompts),
|
||||
)
|
||||
return prompt_raw
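A self-contained sketch (not part of this diff) of the interleaving performed above, using a hypothetical "<|video|>" placeholder and two pre-rendered per-video chunk prompts:

import itertools

prompt_raw = "User: <|video|> and <|video|> please compare."
video_prompts = ["<chunks of video 0>", "<chunks of video 1>"]

parts = prompt_raw.split("<|video|>")
assert len(parts) == len(video_prompts) + 1
result = "".join(itertools.chain.from_iterable(zip(parts, video_prompts)))
result += parts[-1]
# result == "User: <chunks of video 0> and <chunks of video 1> please compare."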
|
||||
|
||||
|
||||
class HfRenderer(BaseRenderer[HfTokenizer]):
|
||||
@classmethod
|
||||
def from_config( # type: ignore[override]
|
||||
cls,
|
||||
config: VllmConfig,
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> "HfRenderer":
|
||||
model_config = config.model_config
|
||||
if model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = cast(
|
||||
HfTokenizer,
|
||||
cached_get_tokenizer(
|
||||
tokenizer_cls=CachedHfTokenizer, # type: ignore[type-abstract]
|
||||
**tokenizer_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
return cls(config, tokenizer)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: VllmConfig,
|
||||
tokenizer: HfTokenizer | None,
|
||||
) -> None:
|
||||
super().__init__(config, tokenizer)
|
||||
|
||||
self.use_unified_vision_chunk = getattr(
|
||||
config.model_config.hf_config, "use_unified_vision_chunk", False
|
||||
)
|
||||
|
||||
def render_messages(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], DictPrompt]:
|
||||
model_config = self.model_config
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
messages,
|
||||
model_config,
|
||||
content_format=resolve_chat_template_content_format(
|
||||
chat_template=params.chat_template,
|
||||
tools=params.chat_template_kwargs.get("tools"),
|
||||
given_format=params.chat_template_content_format,
|
||||
tokenizer=tokenizer,
|
||||
model_config=model_config,
|
||||
),
|
||||
)
|
||||
|
||||
prompt_raw = safe_apply_chat_template(
|
||||
model_config,
|
||||
tokenizer,
|
||||
conversation,
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
)
|
||||
|
||||
# NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5
|
||||
# model which uses unified vision chunks for both images and videos.
|
||||
if (
|
||||
self.use_unified_vision_chunk
|
||||
and mm_uuids is not None
|
||||
and mm_data is not None
|
||||
):
|
||||
mm_uuids = rebuild_mm_uuids_from_mm_data(mm_uuids, mm_data)
|
||||
|
||||
# get video placeholder, replace it with runtime video-chunk prompts
|
||||
video_placeholder = getattr(
|
||||
model_config.hf_config, "video_placeholder", None
|
||||
)
|
||||
prompt_raw = replace_vision_chunk_video_placeholder(
|
||||
prompt_raw,
|
||||
mm_data,
|
||||
video_placeholder,
|
||||
)
|
||||
|
||||
prompt = parse_dec_only_prompt(prompt_raw)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
prompt["multi_modal_uuids"] = mm_uuids
|
||||
|
||||
return conversation, prompt
|
||||
|
||||
async def render_messages_async(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], DictPrompt]:
|
||||
model_config = self.model_config
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
|
||||
messages,
|
||||
model_config,
|
||||
content_format=resolve_chat_template_content_format(
|
||||
chat_template=params.chat_template,
|
||||
tools=params.chat_template_kwargs.get("tools"),
|
||||
given_format=params.chat_template_content_format,
|
||||
tokenizer=tokenizer,
|
||||
model_config=model_config,
|
||||
),
|
||||
)
|
||||
|
||||
prompt_raw = safe_apply_chat_template(
|
||||
model_config,
|
||||
tokenizer,
|
||||
conversation,
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
)
|
||||
|
||||
# NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5
|
||||
# model which uses unified vision chunks for both images and videos.
|
||||
if (
|
||||
self.use_unified_vision_chunk
|
||||
and mm_uuids is not None
|
||||
and mm_data is not None
|
||||
):
|
||||
mm_uuids = rebuild_mm_uuids_from_mm_data(mm_uuids, mm_data)
# get video placeholder, replace it with runtime video-chunk prompts
|
||||
video_placeholder = getattr(
|
||||
model_config.hf_config, "video_placeholder", None
|
||||
)
|
||||
prompt_raw = replace_vision_chunk_video_placeholder(
|
||||
prompt_raw,
|
||||
mm_data,
|
||||
video_placeholder,
|
||||
)
|
||||
|
||||
prompt = parse_dec_only_prompt(prompt_raw)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
prompt["multi_modal_uuids"] = mm_uuids
|
||||
|
||||
return conversation, prompt
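A rough usage sketch for this renderer (illustrative only; `vllm_config` and the exact `tokenizer_kwargs` accepted by `cached_get_tokenizer` are assumptions, not taken from this commit):

renderer = HfRenderer.from_config(
    vllm_config,  # assumed to be an already-built VllmConfig
    tokenizer_kwargs={"tokenizer_name": vllm_config.model_config.tokenizer},
)
messages = [{"role": "user", "content": "Hello!"}]
conversation, prompt = renderer.render_messages(messages, ChatParams())
# `prompt` is a DictPrompt, e.g. {"prompt": "<rendered chat text>", ...},
# with "multi_modal_data"/"multi_modal_uuids" attached when present.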
|
||||
33
vllm/renderers/inputs/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from .preprocess import (
|
||||
DecoderDictPrompt,
|
||||
DecoderOnlyDictPrompt,
|
||||
DictPrompt,
|
||||
EncoderDecoderDictPrompt,
|
||||
EncoderDictPrompt,
|
||||
SingletonDictPrompt,
|
||||
)
|
||||
from .tokenize import (
|
||||
DecoderOnlyTokPrompt,
|
||||
DecoderTokPrompt,
|
||||
EncoderDecoderTokPrompt,
|
||||
EncoderTokPrompt,
|
||||
SingletonTokPrompt,
|
||||
TokPrompt,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DecoderOnlyDictPrompt",
|
||||
"EncoderDictPrompt",
|
||||
"DecoderDictPrompt",
|
||||
"EncoderDecoderDictPrompt",
|
||||
"SingletonDictPrompt",
|
||||
"DictPrompt",
|
||||
"DecoderOnlyTokPrompt",
|
||||
"EncoderTokPrompt",
|
||||
"DecoderTokPrompt",
|
||||
"EncoderDecoderTokPrompt",
|
||||
"SingletonTokPrompt",
|
||||
"TokPrompt",
|
||||
]
|
||||
258
vllm/renderers/inputs/preprocess.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""
|
||||
Schemas and utilities for preprocessing inputs.
|
||||
"""
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypedDict, overload
|
||||
|
||||
from vllm.inputs import (
|
||||
EmbedsPrompt,
|
||||
ExplicitEncoderDecoderPrompt,
|
||||
ProcessorInputs,
|
||||
PromptType,
|
||||
SingletonPrompt,
|
||||
TextPrompt,
|
||||
TokensPrompt,
|
||||
)
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
|
||||
|
||||
@overload
|
||||
def prompt_to_seq(
|
||||
prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
|
||||
) -> Sequence[SingletonPrompt]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def prompt_to_seq( # type: ignore[misc]
|
||||
prompt_or_prompts: ExplicitEncoderDecoderPrompt
|
||||
| Sequence[ExplicitEncoderDecoderPrompt],
|
||||
) -> Sequence[ExplicitEncoderDecoderPrompt]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def prompt_to_seq( # type: ignore[misc]
|
||||
prompt_or_prompts: PromptType | Sequence[PromptType],
|
||||
) -> Sequence[PromptType]: ...
|
||||
|
||||
|
||||
def prompt_to_seq(
|
||||
prompt_or_prompts: PromptType | bytes | Sequence[PromptType | bytes],
|
||||
) -> Sequence[PromptType]:
|
||||
if isinstance(prompt_or_prompts, (dict, str, bytes)) or (
|
||||
len(prompt_or_prompts) > 0 and is_list_of(prompt_or_prompts, int)
|
||||
):
|
||||
return [prompt_or_prompts] # type: ignore[list-item]
|
||||
|
||||
return prompt_or_prompts # type: ignore[return-value]
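To make the branching above concrete, a few illustrative calls (a sketch; results follow directly from the code):

prompt_to_seq("hello")              # -> ["hello"]              (single string wrapped)
prompt_to_seq([1, 2, 3])            # -> [[1, 2, 3]]            (single token prompt wrapped)
prompt_to_seq({"prompt": "hello"})  # -> [{"prompt": "hello"}]  (single dict wrapped)
prompt_to_seq(["hi", "there"])      # -> ["hi", "there"]        (already a sequence)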
|
||||
|
||||
|
||||
def conversation_to_seq(
|
||||
conversation_or_conversations: list["ChatCompletionMessageParam"]
|
||||
| Sequence[list["ChatCompletionMessageParam"]],
|
||||
) -> Sequence[list["ChatCompletionMessageParam"]]:
|
||||
if len(conversation_or_conversations) > 0 and is_list_of(
|
||||
conversation_or_conversations, dict
|
||||
):
|
||||
return [conversation_or_conversations] # type: ignore[list-item]
|
||||
|
||||
return conversation_or_conversations # type: ignore[return-value]
|
||||
|
||||
|
||||
DecoderOnlyDictPrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
|
||||
"""
|
||||
A [`DecoderOnlyPrompt`][vllm.inputs.data.DecoderOnlyPrompt]
|
||||
that has been standardized into a dictionary.
|
||||
"""
|
||||
|
||||
|
||||
EncoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt
|
||||
"""
|
||||
An [`EncoderPrompt`][vllm.inputs.data.EncoderPrompt]
|
||||
that has been standardized into a dictionary.
|
||||
"""
|
||||
|
||||
|
||||
DecoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt
|
||||
"""
|
||||
A [`DecoderPrompt`][vllm.inputs.data.DecoderPrompt]
|
||||
that has been standardized into a dictionary.
|
||||
"""
|
||||
|
||||
|
||||
class EncoderDecoderDictPrompt(TypedDict):
|
||||
"""
|
||||
An [`EncoderDecoderPrompt`][vllm.inputs.data.EncoderDecoderPrompt]
|
||||
that has been standardized into a dictionary.
|
||||
"""
|
||||
|
||||
encoder_prompt: EncoderDictPrompt
|
||||
|
||||
decoder_prompt: DecoderDictPrompt | None
|
||||
|
||||
|
||||
SingletonDictPrompt: TypeAlias = (
|
||||
DecoderOnlyDictPrompt | EncoderDictPrompt | DecoderDictPrompt
|
||||
)
|
||||
"""
|
||||
A [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt]
|
||||
that has been standardized into a dictionary.
|
||||
"""
|
||||
|
||||
|
||||
DictPrompt: TypeAlias = DecoderOnlyDictPrompt | EncoderDecoderDictPrompt
|
||||
"""
|
||||
A [`PromptType`][vllm.inputs.data.PromptType]
|
||||
that has been standardized into a dictionary.
|
||||
"""
|
||||
|
||||
|
||||
def parse_dec_only_prompt(prompt: PromptType | object) -> DecoderOnlyDictPrompt:
|
||||
"""
|
||||
Parse a prompt for a decoder-only model and normalize it to a dictionary.
|
||||
"""
|
||||
if isinstance(prompt, str):
|
||||
return TextPrompt(prompt=prompt)
|
||||
|
||||
if isinstance(prompt, list):
|
||||
if not is_list_of(prompt, int):
|
||||
raise TypeError("Token prompt should be a list of integers")
|
||||
|
||||
return TokensPrompt(prompt_token_ids=prompt)
|
||||
|
||||
if isinstance(prompt, dict):
|
||||
if "encoder_prompt" in prompt:
|
||||
raise TypeError("Cannot pass encoder-decoder prompt to decoder-only models")
|
||||
|
||||
if (
|
||||
"prompt" in prompt
|
||||
or "prompt_token_ids" in prompt
|
||||
or "prompt_embeds" in prompt
|
||||
):
|
||||
return prompt # type: ignore[return-value]
|
||||
|
||||
raise TypeError("Prompt dictionary must contain text, tokens, or embeddings")
|
||||
|
||||
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
|
||||
|
||||
|
||||
def _parse_enc_prompt(prompt: PromptType | object) -> EncoderDictPrompt:
|
||||
if isinstance(prompt, str):
|
||||
return TextPrompt(prompt=prompt)
|
||||
|
||||
if isinstance(prompt, list):
|
||||
if not is_list_of(prompt, int):
|
||||
raise TypeError("Token prompt should be a list of integers")
|
||||
|
||||
return TokensPrompt(prompt_token_ids=prompt)
|
||||
|
||||
if isinstance(prompt, dict):
|
||||
if "prompt_embeds" in prompt:
|
||||
raise TypeError("Cannot pass embeddings prompt to encoder-decoder models")
|
||||
|
||||
if "prompt" in prompt or "prompt_token_ids" in prompt:
|
||||
return prompt # type: ignore[return-value]
|
||||
|
||||
raise TypeError("Prompt dictionary must contain text or tokens")
|
||||
|
||||
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
|
||||
|
||||
|
||||
def _parse_dec_prompt(prompt: PromptType | object) -> DecoderDictPrompt:
|
||||
if isinstance(prompt, str):
|
||||
return TextPrompt(prompt=prompt)
|
||||
|
||||
if isinstance(prompt, list):
|
||||
if not is_list_of(prompt, int):
|
||||
raise TypeError("Token prompt should be a list of integers")
|
||||
|
||||
return TokensPrompt(prompt_token_ids=prompt)
|
||||
|
||||
if isinstance(prompt, dict):
|
||||
if "prompt_embeds" in prompt:
|
||||
raise TypeError("Cannot pass embeddings prompt to encoder-decoder models")
|
||||
|
||||
if (
|
||||
"multi_modal_data" in prompt
|
||||
or "mm_processor_kwargs" in prompt
|
||||
or "multi_modal_uuids" in prompt
|
||||
):
|
||||
raise TypeError("Cannot pass multi-modal inputs to decoder prompt")
|
||||
|
||||
if "prompt" in prompt or "prompt_token_ids" in prompt:
|
||||
return prompt # type: ignore[return-value]
|
||||
|
||||
raise TypeError("Prompt dictionary must contain text or tokens")
|
||||
|
||||
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
|
||||
|
||||
|
||||
def parse_enc_dec_prompt(prompt: PromptType | object) -> EncoderDecoderDictPrompt:
|
||||
"""
|
||||
Parse a prompt for an encoder-decoder model and normalize it to a dictionary.
|
||||
"""
|
||||
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
|
||||
enc_prompt = prompt["encoder_prompt"] # type: ignore[typeddict-item]
|
||||
dec_prompt = prompt["decoder_prompt"] # type: ignore[typeddict-item]
|
||||
else:
|
||||
enc_prompt = prompt
|
||||
dec_prompt = None
|
||||
|
||||
return EncoderDecoderDictPrompt(
|
||||
encoder_prompt=_parse_enc_prompt(enc_prompt),
|
||||
decoder_prompt=None if dec_prompt is None else _parse_dec_prompt(dec_prompt),
|
||||
)
|
||||
|
||||
|
||||
def parse_model_prompt(model_config: "ModelConfig", prompt: object):
|
||||
if model_config.is_encoder_decoder:
|
||||
return parse_enc_dec_prompt(prompt)
|
||||
|
||||
return parse_dec_only_prompt(prompt)
|
||||
|
||||
|
||||
class PromptComponents(NamedTuple):
|
||||
text: str | None = None
|
||||
token_ids: list[int] | None = None
|
||||
embeds: "torch.Tensor | None" = None
|
||||
|
||||
|
||||
def extract_target_prompt(model_config: "ModelConfig", prompt: object):
|
||||
return (
|
||||
parse_enc_dec_prompt(prompt)["encoder_prompt"]
|
||||
if model_config.is_encoder_decoder
|
||||
else parse_dec_only_prompt(prompt)
|
||||
)
|
||||
|
||||
|
||||
def extract_prompt_components(
|
||||
model_config: "ModelConfig",
|
||||
prompt: PromptType | ProcessorInputs,
|
||||
) -> PromptComponents:
|
||||
target_prompt = extract_target_prompt(model_config, prompt)
|
||||
|
||||
return PromptComponents(
|
||||
text=target_prompt.get("prompt"),
|
||||
token_ids=target_prompt.get("prompt_token_ids"),
|
||||
embeds=target_prompt.get("prompt_embeds"),
|
||||
)
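For instance, for a decoder-only model (sketch; `model_config` is an assumed non-encoder-decoder `ModelConfig`):

extract_prompt_components(model_config, {"prompt": "Hi"})
# -> PromptComponents(text="Hi", token_ids=None, embeds=None)
extract_prompt_components(model_config, {"prompt_token_ids": [1, 2, 3]})
# -> PromptComponents(text=None, token_ids=[1, 2, 3], embeds=None)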
|
||||
|
||||
|
||||
def extract_prompt_len(
|
||||
model_config: "ModelConfig", prompt: PromptType | ProcessorInputs
|
||||
):
|
||||
target_prompt = extract_target_prompt(model_config, prompt)
|
||||
|
||||
return length_from_prompt_token_ids_or_embeds(
|
||||
target_prompt.get("prompt_token_ids"),
|
||||
target_prompt.get("prompt_embeds"),
|
||||
)
|
||||
57
vllm/renderers/inputs/tokenize.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
Schemas and utilities for tokenization inputs.
|
||||
"""
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TypeAlias, TypedDict
|
||||
|
||||
from vllm.inputs import EmbedsPrompt, TokensPrompt
|
||||
|
||||
DecoderOnlyTokPrompt: TypeAlias = TokensPrompt | EmbedsPrompt
|
||||
"""
|
||||
A [`DecoderOnlyDictPrompt`][vllm.renderers.inputs.preprocess.DecoderOnlyDictPrompt]
|
||||
that has been tokenized.
|
||||
"""
|
||||
|
||||
|
||||
EncoderTokPrompt: TypeAlias = TokensPrompt
|
||||
"""
|
||||
An [`EncoderDictPrompt`][vllm.renderers.inputs.preprocess.EncoderDictPrompt]
|
||||
that has been tokenized.
|
||||
"""
|
||||
|
||||
|
||||
DecoderTokPrompt: TypeAlias = TokensPrompt
|
||||
"""
|
||||
A [`DecoderDictPrompt`][vllm.renderers.inputs.preprocess.DecoderDictPrompt]
|
||||
that has been tokenized.
|
||||
"""
|
||||
|
||||
|
||||
class EncoderDecoderTokPrompt(TypedDict):
|
||||
"""
|
||||
An
|
||||
[`EncoderDecoderDictPrompt`][vllm.renderers.inputs.preprocess.EncoderDecoderDictPrompt]
|
||||
that has been tokenized.
|
||||
"""
|
||||
|
||||
encoder_prompt: EncoderTokPrompt
|
||||
|
||||
decoder_prompt: DecoderTokPrompt | None
|
||||
|
||||
|
||||
SingletonTokPrompt: TypeAlias = (
|
||||
DecoderOnlyTokPrompt | EncoderTokPrompt | DecoderTokPrompt
|
||||
)
|
||||
"""
|
||||
A [`SingletonDictPrompt`][vllm.renderers.inputs.preprocess.SingletonDictPrompt]
|
||||
that has been tokenized.
|
||||
"""
|
||||
|
||||
|
||||
TokPrompt: TypeAlias = DecoderOnlyTokPrompt | EncoderDecoderTokPrompt
|
||||
"""
|
||||
A [`DictPrompt`][vllm.renderers.inputs.preprocess.DictPrompt]
|
||||
that has been tokenized.
|
||||
"""
|
||||
133
vllm/renderers/mistral.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ConversationMessage,
|
||||
parse_chat_messages,
|
||||
parse_chat_messages_async,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils.async_utils import make_async
|
||||
|
||||
from .base import BaseRenderer
|
||||
from .inputs import DictPrompt
|
||||
from .inputs.preprocess import parse_dec_only_prompt
|
||||
from .params import ChatParams
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def safe_apply_chat_template(
|
||||
tokenizer: MistralTokenizer,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
**kwargs,
|
||||
) -> str | list[int]:
|
||||
from mistral_common.exceptions import MistralCommonException
|
||||
|
||||
try:
|
||||
return tokenizer.apply_chat_template(messages, **kwargs)
|
||||
# mistral-common uses assert statements to stop processing of input
|
||||
# if input does not comply with the expected format.
|
||||
# We convert those assertion errors to ValueErrors so they can be
|
||||
# properly caught in the preprocessing_input step
|
||||
except (AssertionError, MistralCommonException) as e:
|
||||
raise ValueError(str(e)) from e
|
||||
|
||||
# External library exceptions can sometimes occur despite the framework's
|
||||
# internal exception management capabilities.
|
||||
except Exception as e:
|
||||
# Log and report any library-related exceptions for further
|
||||
# investigation.
|
||||
logger.exception(
|
||||
"An error occurred in `mistral_common` while applying chat template"
|
||||
)
|
||||
raise ValueError(str(e)) from e
|
||||
|
||||
|
||||
class MistralRenderer(BaseRenderer[MistralTokenizer]):
|
||||
@classmethod
|
||||
def from_config( # type: ignore[override]
|
||||
cls,
|
||||
config: VllmConfig,
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> "MistralRenderer":
|
||||
model_config = config.model_config
|
||||
if model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = cached_get_tokenizer(
|
||||
tokenizer_cls=MistralTokenizer,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
|
||||
return cls(config, tokenizer)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: VllmConfig,
|
||||
tokenizer: MistralTokenizer | None,
|
||||
) -> None:
|
||||
super().__init__(config, tokenizer)
|
||||
|
||||
self._apply_chat_template_executor = ThreadPoolExecutor(max_workers=1)
|
||||
self._apply_chat_template_async = make_async(
|
||||
safe_apply_chat_template, executor=self._apply_chat_template_executor
|
||||
)
|
||||
|
||||
def render_messages(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], DictPrompt]:
|
||||
tokenizer = self.get_tokenizer()
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
messages,
|
||||
self.model_config,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
prompt_raw = safe_apply_chat_template(
|
||||
tokenizer,
|
||||
messages,
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
)
|
||||
|
||||
prompt = parse_dec_only_prompt(prompt_raw)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
prompt["multi_modal_uuids"] = mm_uuids
|
||||
|
||||
return conversation, prompt
|
||||
|
||||
async def render_messages_async(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], DictPrompt]:
|
||||
tokenizer = self.get_tokenizer()
|
||||
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
|
||||
messages,
|
||||
self.model_config,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
prompt_raw = await self._apply_chat_template_async(
|
||||
tokenizer,
|
||||
messages,
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
)
|
||||
|
||||
prompt = parse_dec_only_prompt(prompt_raw)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
prompt["multi_modal_uuids"] = mm_uuids
|
||||
|
||||
return conversation, prompt
|
||||
383
vllm/renderers/params.py
Normal file
@@ -0,0 +1,383 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
else:
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
|
||||
ChatTemplateContentFormatOption = object
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
_S = TypeVar("_S", list[int], "torch.Tensor")
|
||||
|
||||
|
||||
def merge_kwargs(
|
||||
defaults: dict[str, Any] | None,
|
||||
overrides: dict[str, Any] | None,
|
||||
/,
|
||||
*,
|
||||
unset_values: tuple[object, ...] = (None, "auto"),
|
||||
) -> dict[str, Any]:
|
||||
if defaults is None:
|
||||
defaults = {}
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
|
||||
return defaults | {k: v for k, v in overrides.items() if v not in unset_values}
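For example (values equal to `None` or `"auto"` are treated as unset and do not override the defaults):

merge_kwargs({"temperature": 0.7, "top_p": 1.0}, {"top_p": 0.9, "temperature": None})
# -> {"temperature": 0.7, "top_p": 0.9}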
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ChatParams:
|
||||
"""Configuration to control how to parse chat messages."""
|
||||
|
||||
chat_template: str | None = None
|
||||
"""The chat template to apply."""
|
||||
|
||||
chat_template_content_format: "ChatTemplateContentFormatOption" = "auto"
|
||||
"""The format of the chat template."""
|
||||
|
||||
chat_template_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
"""The kwargs to pass to the chat template."""
|
||||
|
||||
def with_defaults(self, default_chat_template_kwargs: dict[str, Any] | None):
|
||||
if not default_chat_template_kwargs:
|
||||
return self
|
||||
|
||||
return ChatParams(
|
||||
chat_template=self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
chat_template_kwargs=merge_kwargs(
|
||||
default_chat_template_kwargs,
|
||||
self.chat_template_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
def get_apply_chat_template_kwargs(self) -> dict[str, Any]:
|
||||
"""The arguments to pass to `tokenizer.apply_chat_template`."""
|
||||
return merge_kwargs(
|
||||
self.chat_template_kwargs,
|
||||
dict(chat_template=self.chat_template, return_dict=False),
|
||||
)
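As an illustration of how server-level defaults combine with per-request kwargs (the kwarg names below are hypothetical):

params = ChatParams(chat_template_kwargs={"enable_thinking": True})
params = params.with_defaults({"enable_thinking": False, "tools": None})
# Request-level values win over defaults:
# params.chat_template_kwargs == {"enable_thinking": True, "tools": None}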
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TokenizeParams:
|
||||
"""Configuration to control how prompts are tokenized."""
|
||||
|
||||
max_total_tokens: int | None
|
||||
"""
|
||||
Maximum allowed number of input + output tokens.
|
||||
|
||||
Usually, this refers to the model's context length.
|
||||
"""
|
||||
|
||||
max_output_tokens: int = 0
|
||||
"""Maximum requested number of output tokens."""
|
||||
|
||||
pad_prompt_tokens: int | None = None
|
||||
"""
|
||||
Number of tokens to pad to:
|
||||
- `None` means no padding.
|
||||
- `-1` maps to `max_input_tokens`.
|
||||
"""
|
||||
|
||||
truncate_prompt_tokens: int | None = None
|
||||
"""
|
||||
Number of tokens to keep:
|
||||
- `None` means no truncation.
|
||||
- `-1` maps to `max_input_tokens`.
|
||||
"""
|
||||
|
||||
do_lower_case: bool = False
|
||||
"""Whether to normalize text to lower case before tokenization."""
|
||||
|
||||
add_special_tokens: bool = True
|
||||
"""Whether to add special tokens."""
|
||||
|
||||
needs_detokenization: bool = False
|
||||
"""
|
||||
Whether the tokenized prompt needs to contain the original text.
|
||||
|
||||
Not to be confused with `SamplingParams.detokenize` which deals
|
||||
with the output generated by the model.
|
||||
"""
|
||||
|
||||
max_total_tokens_param: str = "max_total_tokens"
|
||||
"""Override this to edit the message for validation errors."""
|
||||
|
||||
max_output_tokens_param: str = "max_output_tokens"
|
||||
"""Override this to edit the message for validation errors."""
|
||||
|
||||
truncate_prompt_tokens_param: str = "truncate_prompt_tokens"
|
||||
"""Override this to edit the message for validation errors."""
|
||||
|
||||
@property
|
||||
def max_input_tokens(self) -> int | None:
|
||||
"""Maximum allowed number of input tokens."""
|
||||
if self.max_total_tokens is None:
|
||||
return None
|
||||
|
||||
return self.max_total_tokens - self.max_output_tokens
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
max_total_tokens = self.max_total_tokens
|
||||
max_output_tokens = self.max_output_tokens
|
||||
max_input_tokens = self.max_input_tokens
|
||||
truncate_prompt_tokens = self.truncate_prompt_tokens
|
||||
|
||||
if (
|
||||
max_output_tokens is not None
|
||||
and max_total_tokens is not None
|
||||
and max_output_tokens > max_total_tokens
|
||||
):
|
||||
raise VLLMValidationError(
|
||||
f"{self.max_output_tokens_param}={max_output_tokens}"
|
||||
f"cannot be greater than "
|
||||
f"{self.max_total_tokens_param}={max_total_tokens=}. "
|
||||
f"Please request fewer output tokens.",
|
||||
parameter=self.max_output_tokens_param,
|
||||
value=max_output_tokens,
|
||||
)
|
||||
|
||||
if (
|
||||
max_input_tokens is not None
|
||||
and truncate_prompt_tokens is not None
|
||||
and truncate_prompt_tokens > max_input_tokens
|
||||
):
|
||||
raise VLLMValidationError(
|
||||
f"{self.truncate_prompt_tokens_param}={truncate_prompt_tokens} "
|
||||
f"cannot be greater than {self.max_total_tokens_param} - "
|
||||
f"{self.max_output_tokens_param} = {max_input_tokens}. "
|
||||
f"Please request a smaller truncation size.",
|
||||
parameter=self.truncate_prompt_tokens_param,
|
||||
value=truncate_prompt_tokens,
|
||||
)
|
||||
|
||||
def with_kwargs(self, **tokenization_kwargs: Any):
|
||||
max_length = tokenization_kwargs.pop("max_length", self.max_input_tokens)
|
||||
pad_prompt_tokens = tokenization_kwargs.pop(
|
||||
"pad_prompt_tokens", self.pad_prompt_tokens
|
||||
)
|
||||
truncate_prompt_tokens = tokenization_kwargs.pop(
|
||||
"truncate_prompt_tokens", self.truncate_prompt_tokens
|
||||
)
|
||||
do_lower_case = tokenization_kwargs.pop("do_lower_case", self.do_lower_case)
|
||||
add_special_tokens = tokenization_kwargs.pop(
|
||||
"add_special_tokens", self.add_special_tokens
|
||||
)
|
||||
needs_detokenization = tokenization_kwargs.pop(
|
||||
"needs_detokenization", self.needs_detokenization
|
||||
)
|
||||
|
||||
# https://huggingface.co/docs/transformers/en/pad_truncation
|
||||
if padding := tokenization_kwargs.pop("padding", None):
|
||||
if padding == "max_length":
|
||||
pad_prompt_tokens = max_length
|
||||
elif padding in (False, "do_not_pad"):
|
||||
pad_prompt_tokens = None
|
||||
else:
|
||||
# To emit the below warning
|
||||
tokenization_kwargs["padding"] = padding
|
||||
|
||||
if truncation := tokenization_kwargs.pop("truncation", None):
|
||||
if truncation in (True, "longest_first"):
|
||||
truncate_prompt_tokens = max_length
|
||||
elif truncation in (False, "do_not_truncate"):
|
||||
truncate_prompt_tokens = None
|
||||
else:
|
||||
# To emit the below warning
|
||||
tokenization_kwargs["truncation"] = truncation
|
||||
|
||||
if tokenization_kwargs:
|
||||
logger.warning(
|
||||
"The following tokenization arguments are not supported "
|
||||
"by vLLM Renderer and will be ignored: %s",
|
||||
tokenization_kwargs,
|
||||
)
|
||||
|
||||
max_total_tokens = self.max_total_tokens
|
||||
|
||||
return TokenizeParams(
|
||||
max_total_tokens=max_total_tokens,
|
||||
max_output_tokens=(
|
||||
0
|
||||
if max_total_tokens is None or max_length is None
|
||||
else max_total_tokens - max_length
|
||||
),
|
||||
pad_prompt_tokens=pad_prompt_tokens,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
do_lower_case=do_lower_case,
|
||||
add_special_tokens=add_special_tokens,
|
||||
needs_detokenization=needs_detokenization,
|
||||
)
|
||||
|
||||
def get_encode_kwargs(self) -> dict[str, Any]:
|
||||
"""The arguments to pass to `tokenizer.encode`."""
|
||||
max_length = self.truncate_prompt_tokens
|
||||
if max_length is not None and max_length < 0:
|
||||
max_length = self.max_input_tokens
|
||||
elif max_length is None and self.max_input_tokens is not None:
|
||||
# This prevents tokenization from taking up more resources than necessary
|
||||
# while still failing `self._token_len_check` as expected by users
|
||||
max_length = self.max_input_tokens + 1
|
||||
|
||||
return dict(
|
||||
truncation=max_length is not None,
|
||||
max_length=max_length,
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
)
|
||||
|
||||
def _text_len_check(self, tokenizer: TokenizerLike | None, text: str) -> str:
|
||||
"""Apply length checks to prompt text if necessary."""
|
||||
max_input_tokens = self.max_input_tokens
|
||||
if max_input_tokens is None:
|
||||
return text
|
||||
|
||||
if self.truncate_prompt_tokens is None and tokenizer is not None:
|
||||
max_input_chars = max_input_tokens * tokenizer.max_chars_per_token
|
||||
|
||||
if len(text) > max_input_chars:
|
||||
# To save resources, fail the request outright without even
|
||||
# attempting tokenization
|
||||
raise VLLMValidationError(
|
||||
f"You passed {len(text)} input characters "
|
||||
f"and requested {self.max_output_tokens} output tokens. "
|
||||
f"However, the model's context length is only "
|
||||
f"{self.max_total_tokens} tokens, resulting in a maximum "
|
||||
f"input length of {max_input_tokens} tokens "
|
||||
f"(at most {max_input_chars} characters). "
|
||||
f"Please reduce the length of the input prompt.",
|
||||
parameter="input_text",
|
||||
value=len(text),
|
||||
)
|
||||
|
||||
return text
|
||||
|
||||
def _text_lowercase(self, tokenizer: TokenizerLike | None, text: str) -> str:
|
||||
"""Apply lowercase to prompt text if necessary."""
|
||||
return text.lower() if self.do_lower_case else text
|
||||
|
||||
def _validate_text(self, tokenizer: TokenizerLike | None, text: str) -> str:
|
||||
"""Apply all validators to prompt text."""
|
||||
for validator in (
|
||||
self._text_len_check,
|
||||
self._text_lowercase,
|
||||
):
|
||||
text = validator(tokenizer, text)
|
||||
|
||||
return text
|
||||
|
||||
def apply_pre_tokenization(
|
||||
self,
|
||||
tokenizer: TokenizerLike | None,
|
||||
prompt: TextPrompt,
|
||||
) -> TextPrompt:
|
||||
"""
|
||||
Ensure that the prompt meets the requirements set out by this config.
|
||||
If that is not possible, raise a `VLLMValidationError`.
|
||||
|
||||
This method is run before tokenization occurs.
|
||||
"""
|
||||
prompt["prompt"] = self._validate_text(tokenizer, prompt["prompt"])
|
||||
|
||||
return prompt
|
||||
|
||||
def _token_padding(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
|
||||
"""Apply padding to prompt tokens if necessary."""
|
||||
pad_length = self.pad_prompt_tokens
|
||||
if pad_length is not None and pad_length < 0:
|
||||
pad_length = self.max_input_tokens
|
||||
|
||||
if pad_length is None or pad_length <= len(tokens):
|
||||
return tokens
|
||||
|
||||
if tokenizer is None:
|
||||
raise ValueError("Cannot pad tokens when `skip_tokenizer_init=True`")
|
||||
if not isinstance(tokens, list):
|
||||
raise ValueError("Cannot pad tokens for embedding inputs")
|
||||
|
||||
return tokens + [tokenizer.pad_token_id] * (pad_length - len(tokens))
|
||||
|
||||
def _token_truncation(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
|
||||
"""Apply truncation to prompt tokens if necessary."""
|
||||
max_length = self.truncate_prompt_tokens
|
||||
if max_length is not None and max_length < 0:
|
||||
max_length = self.max_input_tokens
|
||||
|
||||
if max_length is None or max_length >= len(tokens):
|
||||
return tokens
|
||||
if max_length == 0:
|
||||
return tokens[:0]
|
||||
|
||||
if getattr(tokenizer, "truncation_side", "left") == "left":
|
||||
return tokens[-max_length:]
|
||||
|
||||
return tokens[:max_length]
|
||||
|
||||
def _token_len_check(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
|
||||
"""Apply length checks to prompt tokens if necessary."""
|
||||
max_input_tokens = self.max_input_tokens
|
||||
if max_input_tokens is None:
|
||||
return tokens
|
||||
|
||||
if len(tokens) > max_input_tokens:
|
||||
raise VLLMValidationError(
|
||||
f"You passed {len(tokens)} input tokens "
|
||||
f"and requested {self.max_output_tokens} output tokens. "
|
||||
f"However, the model's context length is only "
|
||||
f"{self.max_total_tokens} tokens, resulting in a maximum "
|
||||
f"input length of {max_input_tokens} tokens. "
|
||||
f"Please reduce the length of the input prompt.",
|
||||
parameter="input_tokens",
|
||||
value=len(tokens),
|
||||
)
|
||||
|
||||
return tokens
|
||||
|
||||
def _validate_tokens(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
|
||||
"""Apply all validators to a token sequence."""
|
||||
for validator in (
|
||||
self._token_padding,
|
||||
self._token_truncation,
|
||||
self._token_len_check,
|
||||
):
|
||||
tokens = validator(tokenizer, tokens)
|
||||
|
||||
return tokens
|
||||
|
||||
def apply_post_tokenization(
|
||||
self,
|
||||
tokenizer: TokenizerLike | None,
|
||||
prompt: TokensPrompt | EmbedsPrompt,
|
||||
) -> TokensPrompt | EmbedsPrompt:
|
||||
"""
|
||||
Ensure that the prompt meets the requirements set out by this config.
|
||||
If that is not possible, raise a `VLLMValidationError`.
|
||||
|
||||
This method is run after tokenization occurs.
|
||||
"""
|
||||
if "prompt_token_ids" in prompt:
|
||||
prompt["prompt_token_ids"] = self._validate_tokens( # type: ignore[typeddict-unknown-key]
|
||||
tokenizer,
|
||||
prompt["prompt_token_ids"], # type: ignore[typeddict-item]
|
||||
)
|
||||
if "prompt_embeds" in prompt:
|
||||
prompt["prompt_embeds"] = self._validate_tokens( # type: ignore[typeddict-unknown-key]
|
||||
tokenizer,
|
||||
prompt["prompt_embeds"], # type: ignore[typeddict-item]
|
||||
)
|
||||
|
||||
return prompt
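A minimal sketch of how these hooks compose around tokenization (illustrative only; `tokenizer` is assumed to be a vLLM `TokenizerLike`, i.e. a wrapper exposing `encode`, `pad_token_id` and `max_chars_per_token` as used above):

params = TokenizeParams(max_total_tokens=4096, max_output_tokens=512)
# params.max_input_tokens == 3584

text_prompt = params.apply_pre_tokenization(tokenizer, {"prompt": "Hello world"})
token_ids = tokenizer.encode(text_prompt["prompt"], **params.get_encode_kwargs())
tok_prompt = params.apply_post_tokenization(
    tokenizer, {"prompt_token_ids": token_ids}
)
# tok_prompt["prompt_token_ids"] is now padded/truncated and length-checked.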
|
||||
92
vllm/renderers/registry.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers.registry import tokenizer_args_from_config
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
from .base import BaseRenderer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
_VLLM_RENDERERS = {
|
||||
"deepseek_v32": ("deepseek_v32", "DeepseekV32Renderer"),
|
||||
"hf": ("hf", "HfRenderer"),
|
||||
"grok2": ("grok2", "Grok2Renderer"),
|
||||
"mistral": ("mistral", "MistralRenderer"),
|
||||
"terratorch": ("terratorch", "TerratorchRenderer"),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class RendererRegistry:
|
||||
# Renderer mode -> (renderer module, renderer class)
|
||||
renderers: dict[str, tuple[str, str]] = field(default_factory=dict)
|
||||
|
||||
def register(self, renderer_mode: str, module: str, class_name: str) -> None:
|
||||
if renderer_mode in self.renderers:
|
||||
logger.warning(
|
||||
"%s.%s is already registered for renderer_mode=%r. "
|
||||
"It is overwritten by the new one.",
|
||||
module,
|
||||
class_name,
|
||||
renderer_mode,
|
||||
)
|
||||
|
||||
self.renderers[renderer_mode] = (module, class_name)
|
||||
|
||||
return None
|
||||
|
||||
def load_renderer_cls(self, renderer_mode: str) -> type[BaseRenderer]:
|
||||
if renderer_mode not in self.renderers:
|
||||
raise ValueError(f"No renderer registered for {renderer_mode=!r}.")
|
||||
|
||||
module, class_name = self.renderers[renderer_mode]
|
||||
logger.debug_once(f"Loading {class_name} for {renderer_mode=!r}")
|
||||
|
||||
return resolve_obj_by_qualname(f"{module}.{class_name}")
|
||||
|
||||
def load_renderer(
|
||||
self,
|
||||
renderer_mode: str,
|
||||
config: "VllmConfig",
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> BaseRenderer:
|
||||
renderer_cls = self.load_renderer_cls(renderer_mode)
|
||||
return renderer_cls.from_config(config, tokenizer_kwargs)
|
||||
|
||||
|
||||
RENDERER_REGISTRY = RendererRegistry(
|
||||
{
|
||||
mode: (f"vllm.renderers.{mod_relname}", cls_name)
|
||||
for mode, (mod_relname, cls_name) in _VLLM_RENDERERS.items()
|
||||
}
|
||||
)
|
||||
"""The global `RendererRegistry` instance."""
|
||||
|
||||
|
||||
def renderer_from_config(config: "VllmConfig", **kwargs):
|
||||
model_config = config.model_config
|
||||
tokenizer_mode, tokenizer_name, args, kwargs = tokenizer_args_from_config(
|
||||
model_config, **kwargs
|
||||
)
|
||||
|
||||
if (
|
||||
model_config.tokenizer_mode == "auto"
|
||||
and model_config.model_impl == "terratorch"
|
||||
):
|
||||
renderer_mode = "terratorch"
|
||||
else:
|
||||
renderer_mode = tokenizer_mode
|
||||
|
||||
return RENDERER_REGISTRY.load_renderer(
|
||||
renderer_mode,
|
||||
config,
|
||||
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
|
||||
)
|
||||
75
vllm/renderers/terratorch.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ConversationMessage,
|
||||
parse_chat_messages,
|
||||
parse_chat_messages_async,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .base import BaseRenderer
|
||||
from .inputs import DictPrompt
|
||||
from .inputs.preprocess import parse_dec_only_prompt
|
||||
from .params import ChatParams
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TerratorchRenderer(BaseRenderer):
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: VllmConfig, # type: ignore[override]
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> "TerratorchRenderer":
|
||||
model_config = config.model_config
|
||||
if not model_config.skip_tokenizer_init:
|
||||
raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`")
|
||||
|
||||
return cls(config, None)
|
||||
|
||||
def render_messages(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], DictPrompt]:
|
||||
model_config = self.model_config
|
||||
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
messages,
|
||||
model_config,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
prompt = parse_dec_only_prompt([1]) # Dummy token IDs
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
prompt["multi_modal_uuids"] = mm_uuids
|
||||
|
||||
return conversation, prompt
|
||||
|
||||
async def render_messages_async(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], DictPrompt]:
|
||||
model_config = self.model_config
|
||||
|
||||
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
|
||||
messages,
|
||||
model_config,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
prompt = parse_dec_only_prompt([1]) # Dummy token IDs
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
prompt["multi_modal_uuids"] = mm_uuids
|
||||
|
||||
return conversation, prompt
|
||||