# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import json
from abc import ABC, abstractmethod
from collections import Counter, defaultdict, deque
from collections.abc import Awaitable, Iterable
from functools import cached_property, lru_cache, partial
from pathlib import Path
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
                    cast)

import jinja2
import jinja2.ext
import jinja2.meta
import jinja2.nodes
import jinja2.parser
import jinja2.sandbox
import transformers.utils.chat_template_utils as hf_chat_utils
# yapf conflicts with isort for this block
# yapf: disable
from openai.types.chat import (ChatCompletionAssistantMessageParam,
                               ChatCompletionContentPartImageParam,
                               ChatCompletionContentPartInputAudioParam)
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
                               ChatCompletionContentPartTextParam)
from openai.types.chat import (
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
from openai.types.chat import (ChatCompletionMessageToolCallParam,
                               ChatCompletionToolMessageParam)
from openai.types.chat.chat_completion_content_part_input_audio_param import (
    InputAudio)
from openai.types.responses import ResponseInputImageParam
from openai_harmony import Message as OpenAIHarmonyMessage
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
# yapf: enable
from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
                          ProcessorMixin)
# pydantic needs the TypedDict from typing_extensions
from typing_extensions import Required, TypeAlias, TypedDict

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                             MultiModalUUIDDict)
from vllm.multimodal.utils import MediaConnector
# yapf: disable
from vllm.transformers_utils.chat_templates import (
    get_chat_template_fallback_path)
# yapf: enable
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid, supports_kw

logger = init_logger(__name__)

MODALITY_PLACEHOLDERS_MAP = {
    "image": "<##IMAGE##>",
    "audio": "<##AUDIO##>",
    "video": "<##VIDEO##>",
}


class AudioURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part."""


class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
    image_embeds: Optional[Union[str, dict[str, str]]]
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""
    uuid: Optional[str]
    """
    User-provided UUID of a media item. The user must guarantee that it is
    properly generated and unique across different media items.
    """


class VideoURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    video_url: Required[VideoURL]

    type: Required[Literal["video_url"]]
    """The type of the content part."""
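# Illustrative request fragment for the video content part above (a sketch,
# not an exhaustive schema; the audio variant is analogous with "audio_url"):
# {
#     "type": "video_url",
#     "video_url": {"url": "https://example.com/video.mp4"},
# }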
""" class ChatCompletionContentPartVideoParam(TypedDict, total=False): video_url: Required[VideoURL] type: Required[Literal["video_url"]] """The type of the content part.""" class PILImage(BaseModel): """ A PIL.Image.Image object. """ image_pil: Image.Image model_config = ConfigDict(arbitrary_types_allowed=True) class CustomChatCompletionContentPILImageParam(TypedDict, total=False): """A simpler version of the param that only accepts a PIL image. Example: { "image_pil": ImageAsset('cherry_blossom').pil_image } """ image_pil: Optional[PILImage] uuid: Optional[str] """ User-provided UUID of a media. User must guarantee that it is properly generated and unique for different medias. """ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): """A simpler version of the param that only accepts a plain image_url. This is supported by OpenAI API, although it is not documented. Example: { "image_url": "https://example.com/image.jpg" } """ image_url: Optional[str] uuid: Optional[str] """ User-provided UUID of a media. User must guarantee that it is properly generated and unique for different medias. """ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): """A simpler version of the param that only accepts a plain audio_url. Example: { "audio_url": "https://example.com/audio.mp3" } """ audio_url: Optional[str] class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): """A simpler version of the param that only accepts a plain audio_url. Example: { "video_url": "https://example.com/video.mp4" } """ video_url: Optional[str] uuid: Optional[str] """ User-provided UUID of a media. User must guarantee that it is properly generated and unique for different medias. """ class CustomThinkCompletionContentParam(TypedDict, total=False): """A Think Completion Content Param that accepts a plain text and a boolean. Example: { "thinking": "I am thinking about the answer", "closed": True, "type": "thinking" } """ thinking: Required[str] """The thinking content.""" closed: bool """Whether the thinking is closed.""" type: Required[Literal["thinking"]] """The thinking type.""" ChatCompletionContentPartParam: TypeAlias = Union[ OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, ChatCompletionContentPartInputAudioParam, ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam, CustomChatCompletionContentPILImageParam, CustomChatCompletionContentSimpleImageParam, ChatCompletionContentPartImageEmbedsParam, CustomChatCompletionContentSimpleAudioParam, CustomChatCompletionContentSimpleVideoParam, str, CustomThinkCompletionContentParam, ] class CustomChatCompletionMessageParam(TypedDict, total=False): """Enables custom roles in the Chat Completion API.""" role: Required[str] """The role of the message's author.""" content: Union[str, list[ChatCompletionContentPartParam]] """The contents of the message.""" name: str """An optional name for the participant. Provides the model information to differentiate between participants of the same role. 
""" tool_call_id: Optional[str] """Tool call that this message is responding to.""" tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] """The tool calls generated by the model, such as function calls.""" ChatCompletionMessageParam = Union[ OpenAIChatCompletionMessageParam, CustomChatCompletionMessageParam, OpenAIHarmonyMessage, ] # TODO: Make fields ReadOnly once mypy supports it class ConversationMessage(TypedDict, total=False): role: Required[str] """The role of the message's author.""" content: Union[Optional[str], list[dict[str, str]]] """The contents of the message""" tool_call_id: Optional[str] """Tool call that this message is responding to.""" name: Optional[str] """The name of the function to call""" tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] """The tool calls generated by the model, such as function calls.""" # Passed in by user ChatTemplateContentFormatOption = Literal["auto", "string", "openai"] # Used internally _ChatTemplateContentFormat = Literal["string", "openai"] def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: if isinstance(node, jinja2.nodes.Name): return node.ctx == "load" and node.name == varname return False def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: if isinstance(node, jinja2.nodes.Getitem): return (_is_var_access(node.node, varname) and isinstance(node.arg, jinja2.nodes.Const) and node.arg.value == key) if isinstance(node, jinja2.nodes.Getattr): return _is_var_access(node.node, varname) and node.attr == key return False def _is_var_or_elems_access( node: jinja2.nodes.Node, varname: str, key: Optional[str] = None, ) -> bool: if isinstance(node, jinja2.nodes.Filter): return node.node is not None and _is_var_or_elems_access( node.node, varname, key) if isinstance(node, jinja2.nodes.Test): return _is_var_or_elems_access(node.node, varname, key) if isinstance(node, jinja2.nodes.Getitem) and isinstance( node.arg, jinja2.nodes.Slice): return _is_var_or_elems_access(node.node, varname, key) # yapf: disable return ( _is_attr_access(node, varname, key) if key else _is_var_access(node, varname) ) # yapf: enable def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): # Global variable that is implicitly defined at the root yield root, varname # Iterative BFS related_varnames = deque([varname]) while related_varnames: related_varname = related_varnames.popleft() for assign_ast in root.find_all(jinja2.nodes.Assign): lhs = assign_ast.target rhs = assign_ast.node if _is_var_or_elems_access(rhs, related_varname): assert isinstance(lhs, jinja2.nodes.Name) yield assign_ast, lhs.name # Avoid infinite looping for self-assignment if lhs.name != related_varname: related_varnames.append(lhs.name) # NOTE: The proper way to handle this is to build a CFG so that we can handle # the scope in which each variable is defined, but that is too complicated def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node): messages_varnames = [ varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages") ] # Search for {%- for message in messages -%} loops for loop_ast in root.find_all(jinja2.nodes.For): loop_iter = loop_ast.iter loop_target = loop_ast.target for varname in messages_varnames: if _is_var_or_elems_access(loop_iter, varname): assert isinstance(loop_target, jinja2.nodes.Name) yield loop_ast, loop_target.name break def _iter_nodes_assign_content_item(root: jinja2.nodes.Node): message_varnames = [ varname for _, varname in _iter_nodes_assign_messages_item(root) ] # 
def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
    message_varnames = [
        varname for _, varname in _iter_nodes_assign_messages_item(root)
    ]

    # Search for {%- for content in message['content'] -%} loops
    for loop_ast in root.find_all(jinja2.nodes.For):
        loop_iter = loop_ast.iter
        loop_target = loop_ast.target

        for varname in message_varnames:
            if _is_var_or_elems_access(loop_iter, varname, "content"):
                assert isinstance(loop_target, jinja2.nodes.Name)
                yield loop_ast, loop_target.name
                break


def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]:
    try:
        jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
        return jinja_compiled.environment.parse(chat_template)
    except Exception:
        logger.exception("Error when compiling Jinja template")
        return None


@lru_cache(maxsize=32)
def _detect_content_format(
    chat_template: str,
    *,
    default: _ChatTemplateContentFormat,
) -> _ChatTemplateContentFormat:
    jinja_ast = _try_extract_ast(chat_template)
    if jinja_ast is None:
        return default

    try:
        next(_iter_nodes_assign_content_item(jinja_ast))
    except StopIteration:
        return "string"
    except Exception:
        logger.exception("Error when parsing AST of Jinja template")
        return default
    else:
        return "openai"


def resolve_mistral_chat_template(
    chat_template: Optional[str],
    **kwargs: Any,
) -> Optional[str]:
    if chat_template is not None:
        logger.warning_once(
            "'chat_template' cannot be overridden for mistral tokenizer.")
    if "add_generation_prompt" in kwargs:
        logger.warning_once(
            "'add_generation_prompt' is not supported for mistral tokenizer, "
            "so it will be ignored.")
    if "continue_final_message" in kwargs:
        logger.warning_once(
            "'continue_final_message' is not supported for mistral tokenizer, "
            "so it will be ignored.")
    return None


_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], Optional[str]]()
"""
Used in `_try_get_processor_chat_template` to avoid calling
`cached_get_processor` again if the processor fails to be loaded.

This is needed because `lru_cache` does not cache when an exception happens.
"""
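# Illustrative cache shape (a sketch): entries are keyed by
# (tokenizer.name_or_path, trust_remote_code), and a failed load is recorded
# as None so it is not retried, e.g.:
#     _PROCESSOR_CHAT_TEMPLATES[("org/model", False)] = None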
""" def _try_get_processor_chat_template( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], model_config: ModelConfig, ) -> Optional[str]: cache_key = (tokenizer.name_or_path, model_config.trust_remote_code) if cache_key in _PROCESSOR_CHAT_TEMPLATES: return _PROCESSOR_CHAT_TEMPLATES[cache_key] try: processor = cached_get_processor( tokenizer.name_or_path, processor_cls=( PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin, ), trust_remote_code=model_config.trust_remote_code, ) if ( isinstance(processor, ProcessorMixin) and hasattr(processor, "chat_template") and (chat_template := processor.chat_template) is not None ): _PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template return chat_template except Exception: logger.debug( "Failed to load AutoProcessor chat template for %s", tokenizer.name_or_path, exc_info=True, ) _PROCESSOR_CHAT_TEMPLATES[cache_key] = None return None def resolve_hf_chat_template( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], chat_template: Optional[str], tools: Optional[list[dict[str, Any]]], *, model_config: ModelConfig, ) -> Optional[str]: # 1st priority: The given chat template if chat_template is not None: return chat_template # 2nd priority: AutoProcessor chat template, unless tool calling is enabled if tools is None: chat_template = _try_get_processor_chat_template(tokenizer, model_config) if chat_template is not None: return chat_template # 3rd priority: AutoTokenizer chat template try: return tokenizer.get_chat_template(chat_template, tools=tools) except Exception: logger.debug( "Failed to load AutoTokenizer chat template for %s", tokenizer.name_or_path, exc_info=True, ) # 4th priority: Predefined fallbacks path = get_chat_template_fallback_path( model_type=model_config.hf_config.model_type, tokenizer_name_or_path=model_config.tokenizer, ) if path is not None: logger.info( "Loading chat template fallback for %s as there isn't one " "defined on HF Hub.", tokenizer.name_or_path, ) chat_template = load_chat_template(path) else: logger.debug( "There is no chat template fallback for %s", tokenizer.name_or_path ) return chat_template def _resolve_chat_template_content_format( chat_template: Optional[str], tools: Optional[list[dict[str, Any]]], tokenizer: AnyTokenizer, *, model_config: ModelConfig, ) -> _ChatTemplateContentFormat: if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): hf_chat_template = resolve_hf_chat_template( tokenizer, chat_template=chat_template, tools=tools, model_config=model_config, ) else: hf_chat_template = None jinja_text = ( hf_chat_template if isinstance(hf_chat_template, str) else load_chat_template(chat_template, is_literal=True) ) detected_format = ( "string" if jinja_text is None else _detect_content_format(jinja_text, default="string") ) return detected_format @lru_cache def _log_chat_template_content_format( chat_template: Optional[str], given_format: ChatTemplateContentFormatOption, detected_format: ChatTemplateContentFormatOption, ): logger.info( "Detected the chat template content format to be '%s'. " "You can set `--chat-template-content-format` to override this.", detected_format, ) if given_format != "auto" and given_format != detected_format: logger.warning( "You specified `--chat-template-content-format %s` " "which is different from the detected format '%s'. 
" "If our automatic detection is incorrect, please consider " "opening a GitHub issue so that we can improve it: " "https://github.com/vllm-project/vllm/issues/new/choose", given_format, detected_format, ) def resolve_chat_template_content_format( chat_template: Optional[str], tools: Optional[list[dict[str, Any]]], given_format: ChatTemplateContentFormatOption, tokenizer: AnyTokenizer, *, model_config: ModelConfig, ) -> _ChatTemplateContentFormat: if given_format != "auto": return given_format detected_format = _resolve_chat_template_content_format( chat_template, tools, tokenizer, model_config=model_config, ) _log_chat_template_content_format( chat_template, given_format=given_format, detected_format=detected_format, ) return detected_format ModalityStr = Literal["image", "audio", "video", "image_embeds"] _T = TypeVar("_T") class BaseMultiModalItemTracker(ABC, Generic[_T]): """ Tracks multi-modal items in a given request and ensures that the number of multi-modal items in a given request does not exceed the configured maximum per prompt. """ def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer): super().__init__() self._model_config = model_config self._tokenizer = tokenizer self._items_by_modality = defaultdict[str, list[Optional[_T]]](list) self._uuids_by_modality = defaultdict[str, list[Optional[str]]](list) @property def model_config(self) -> ModelConfig: return self._model_config @cached_property def model_cls(self) -> type[SupportsMultiModal]: from vllm.model_executor.model_loader import get_model_cls model_cls = get_model_cls(self.model_config) return cast(type[SupportsMultiModal], model_cls) @property def allowed_local_media_path(self): return self._model_config.allowed_local_media_path @property def allowed_media_domains(self): return self._model_config.allowed_media_domains @property def mm_registry(self): return MULTIMODAL_REGISTRY @cached_property def mm_processor(self): return self.mm_registry.create_processor(self.model_config) def add( self, modality: ModalityStr, item: Optional[_T], uuid: Optional[str] = None, ) -> Optional[str]: """ Add a multi-modal item to the current prompt and returns the placeholder string to use, if any. An optional uuid can be added which serves as a unique identifier of the media. 
""" input_modality = modality.replace("_embeds", "") num_items = len(self._items_by_modality[modality]) + 1 self.mm_processor.validate_num_items(input_modality, num_items) self._items_by_modality[modality].append(item) self._uuids_by_modality[modality].append(uuid) return self.model_cls.get_placeholder_str(modality, num_items) def all_mm_uuids(self) -> Optional[MultiModalUUIDDict]: if not self._items_by_modality: return None mm_uuids = {} uuids_by_modality = dict(self._uuids_by_modality) if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality: raise ValueError( "Mixing raw image and embedding inputs is not allowed" ) if "image_embeds" in uuids_by_modality: image_embeds_uuids = uuids_by_modality["image_embeds"] if len(image_embeds_uuids) > 1: raise ValueError( "Only one message can have {'type': 'image_embeds'}" ) mm_uuids["image"] = uuids_by_modality["image_embeds"] if "image" in uuids_by_modality: mm_uuids["image"] = uuids_by_modality["image"] # UUIDs of images if "audio" in uuids_by_modality: mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios if "video" in uuids_by_modality: mm_uuids["video"] = uuids_by_modality["video"] # UUIDs of videos return mm_uuids @abstractmethod def create_parser(self) -> "BaseMultiModalContentParser": raise NotImplementedError class MultiModalItemTracker(BaseMultiModalItemTracker[object]): def all_mm_data(self) -> Optional[MultiModalDataDict]: if not self._items_by_modality: return None mm_inputs = {} items_by_modality = dict(self._items_by_modality) if "image" in items_by_modality and "image_embeds" in items_by_modality: raise ValueError( "Mixing raw image and embedding inputs is not allowed" ) if "image_embeds" in items_by_modality: image_embeds_lst = items_by_modality["image_embeds"] if len(image_embeds_lst) > 1: raise ValueError( "Only one message can have {'type': 'image_embeds'}" ) mm_inputs["image"] = image_embeds_lst[0] if "image" in items_by_modality: mm_inputs["image"] = items_by_modality["image"] # A list of images if "audio" in items_by_modality: mm_inputs["audio"] = items_by_modality["audio"] # A list of audios if "video" in items_by_modality: mm_inputs["video"] = items_by_modality["video"] # A list of videos return mm_inputs def create_parser(self) -> "BaseMultiModalContentParser": return MultiModalContentParser(self) class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): async def all_mm_data(self) -> Optional[MultiModalDataDict]: if not self._items_by_modality: return None mm_inputs = {} items_by_modality = {} for modality, items in self._items_by_modality.items(): coros = [] for item in items: if item is not None: coros.append(item) else: coros.append(asyncio.sleep(0)) items_by_modality[modality] = await asyncio.gather(*coros) if "image" in items_by_modality and "image_embeds" in items_by_modality: raise ValueError( "Mixing raw image and embedding inputs is not allowed" ) if "image_embeds" in items_by_modality: image_embeds_lst = items_by_modality["image_embeds"] if len(image_embeds_lst) > 1: raise ValueError( "Only one message can have {'type': 'image_embeds'}" ) mm_inputs["image"] = image_embeds_lst[0] if "image" in items_by_modality: mm_inputs["image"] = items_by_modality["image"] # A list of images if "audio" in items_by_modality: mm_inputs["audio"] = items_by_modality["audio"] # A list of audios if "video" in items_by_modality: mm_inputs["video"] = items_by_modality["video"] # A list of videos return mm_inputs def create_parser(self) -> "BaseMultiModalContentParser": return 
class BaseMultiModalContentParser(ABC):

    def __init__(self) -> None:
        super().__init__()

        # stores model placeholders list with corresponding
        # general MM placeholder:
        # {
        #     "<##IMAGE##>": ["<image>", "<image>", "<image>"],
        #     "<##AUDIO##>": ["