# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import json
import time
from typing import Annotated, Any, ClassVar, Literal

import torch
from openai.types.chat.chat_completion_audio import (
    ChatCompletionAudio as OpenAIChatCompletionAudio,
)
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
from pydantic import Field, model_validator

from vllm.config import ModelConfig
from vllm.config.utils import replace
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.engine.protocol import (
    AnyResponseFormat,
    DeltaMessage,
    FunctionCall,
    FunctionDefinition,
    LegacyStructuralTagResponseFormat,
    OpenAIBaseModel,
    StreamOptions,
    StructuralTagResponseFormat,
    ToolCall,
    UsageInfo,
)
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import (
    BeamSearchParams,
    RepetitionDetectionParams,
    RequestOutputKind,
    SamplingParams,
    StructuredOutputsParams,
)
from vllm.utils import random_uuid

logger = init_logger(__name__)

_LONG_INFO = torch.iinfo(torch.long)


class ChatMessage(OpenAIBaseModel):
    role: str
    content: str | None = None
    refusal: str | None = None
    # per OpenAI spec, annotations is a list
    annotations: list[OpenAIAnnotation] | None = None
    audio: OpenAIChatCompletionAudio | None = None
    function_call: FunctionCall | None = None
    tool_calls: list[ToolCall] = Field(default_factory=list)

    # vLLM-specific fields that are not in OpenAI spec
    reasoning: str | None = None


class ChatCompletionLogProb(OpenAIBaseModel):
    token: str
    logprob: float = -9999.0
    bytes: list[int] | None = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    # Workaround: redefine the field name cache so that it's not
    # shared with the superclass.
    field_names: ClassVar[set[str] | None] = None
    top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(OpenAIBaseModel):
    content: list[ChatCompletionLogProbsContent] | None = None


class ChatCompletionResponseChoice(OpenAIBaseModel):
    index: int
    message: ChatMessage
    logprobs: ChatCompletionLogProbs | None = None
    # per OpenAI spec this is the default
    finish_reason: str | None = "stop"
    # not part of the OpenAI spec but included in vLLM for legacy reasons
    stop_reason: int | str | None = None
    # not part of the OpenAI spec but is useful for tracing the tokens
    # in agent scenarios
    token_ids: list[int] | None = None


class ChatCompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseChoice]
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
    system_fingerprint: str | None = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None
    prompt_token_ids: list[int] | None = None
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None, description="KVTransfer parameters."
    )
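
# Hedged example (not executed; all values are invented): the minimal set of
# fields needed to construct a non-streaming response by hand.
#
#     ChatCompletionResponse(
#         model="my-model",
#         choices=[
#             ChatCompletionResponseChoice(
#                 index=0,
#                 message=ChatMessage(role="assistant", content="Hello!"),
#             )
#         ],
#         usage=UsageInfo(prompt_tokens=3, completion_tokens=2, total_tokens=5),
#     )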


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    delta: DeltaMessage
    logprobs: ChatCompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = None
    # not part of the OpenAI spec but for tracing the tokens
    token_ids: list[int] | None = None


class ChatCompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
    # not part of the OpenAI spec but for tracing the tokens
    prompt_token_ids: list[int] | None = None


class ChatCompletionToolsParam(OpenAIBaseModel):
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


class ChatCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: list[ChatCompletionMessageParam]
    model: str | None = None
    frequency_penalty: float | None = 0.0
    logit_bias: dict[str, float] | None = None
    logprobs: bool | None = False
    top_logprobs: int | None = 0
    max_tokens: int | None = Field(
        default=None,
        deprecated="max_tokens is deprecated in favor of "
        "the max_completion_tokens field",
    )
    max_completion_tokens: int | None = None
    n: int | None = 1
    presence_penalty: float | None = 0.0
    response_format: AnyResponseFormat | None = None
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: str | list[str] | None = []
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    temperature: float | None = None
    top_p: float | None = None
    tools: list[ChatCompletionToolsParam] | None = None
    tool_choice: (
        Literal["none"]
        | Literal["auto"]
        | Literal["required"]
        | ChatCompletionNamedToolChoiceParam
        | None
    ) = "none"
    reasoning_effort: Literal["low", "medium", "high"] | None = None
    include_reasoning: bool = True
    parallel_tool_calls: bool | None = True  # NOTE this will be ignored by vLLM
    user: str | None = None

    # --8<-- [start:chat-completion-sampling-params]
    use_beam_search: bool = False
    top_k: int | None = None
    min_p: float | None = None
    repetition_penalty: float | None = None
    length_penalty: float = 1.0
    stop_token_ids: list[int] | None = []
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Annotated[int, Field(ge=-1, le=_LONG_INFO.max)] | None = (
        None
    )
    prompt_logprobs: int | None = None
    allowed_token_ids: list[int] | None = None
    bad_words: list[str] = Field(default_factory=list)
    # --8<-- [end:chat-completion-sampling-params]

    # --8<-- [start:chat-completion-extra-params]
    echo: bool = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."
        ),
    )
    add_generation_prompt: bool = Field(
        default=True,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
    )
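
    # NOTE: `add_generation_prompt` (above) and `continue_final_message`
    # (below) are mutually exclusive; see `check_generation_prompt`. A hedged
    # sketch of a "prefill"-style request payload (message content invented):
    #
    #     {"add_generation_prompt": false, "continue_final_message": true,
    #      "messages": [..., {"role": "assistant", "content": "The answer is"}]}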
    continue_final_message: bool = Field(
        default=False,
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
    )
    documents: list[dict[str, str]] | None = Field(
        default=None,
        description=(
            "A list of dicts representing documents that will be accessible to "
            "the model if it is performing RAG (retrieval-augmented generation)."
            " If the template does not support RAG, this argument will have no "
            "effect. We recommend that each document should be a dict containing "
            '"title" and "text" keys.'
        ),
    )
    chat_template: str | None = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, a default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
    )
    chat_template_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    structured_outputs: StructuredOutputsParams | None = Field(
        default=None,
        description="Additional kwargs for structured outputs",
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "throughout the inference process and returned in the response."
        ),
    )
    return_tokens_as_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            "as strings of the form 'token_id:{token_id}' so that tokens "
            "that are not JSON-encodable can be identified."
        ),
    )
    return_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
            "need to map generated text back to input tokens."
        ),
    )
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker from guessing prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bits)."
        ),
    )
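
    # A hedged sketch of generating a suitable `cache_salt` client-side:
    #     import secrets
    #     salt = secrets.token_urlsafe(32)  # 43 url-safe base64 chars, 256 bits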
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )
    vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field(
        default=None,
        description=(
            "Additional request parameters with (list of) string or "
            "numeric values, used by custom extensions."
        ),
    )
    repetition_detection: RepetitionDetectionParams | None = Field(
        default=None,
        description="Parameters for detecting repetitive N-gram patterns "
        "in output tokens. If such repetition is detected, generation will "
        "be ended early. LLMs can sometimes generate repetitive, unhelpful "
        "token patterns, stopping only when they hit the maximum output length "
        "(e.g. 'abcdabcdabcd...' or '\\emoji \\emoji \\emoji ...'). This feature "
        "can detect such behavior and terminate early, saving time and tokens.",
    )
    # --8<-- [end:chat-completion-extra-params]

    def build_chat_params(
        self,
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
    ) -> ChatParams:
        return ChatParams(
            chat_template=self.chat_template or default_template,
            chat_template_content_format=default_template_content_format,
            chat_template_kwargs=merge_kwargs(
                self.chat_template_kwargs,
                dict(
                    add_generation_prompt=self.add_generation_prompt,
                    continue_final_message=self.continue_final_message,
                    documents=self.documents,
                    reasoning_effort=self.reasoning_effort,
                ),
            ),
        )

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        if self.max_completion_tokens is not None:
            max_output_tokens: int | None = self.max_completion_tokens
            max_output_tokens_param = "max_completion_tokens"
        else:
            max_output_tokens = self.max_tokens
            max_output_tokens_param = "max_tokens"

        return TokenizeParams(
            max_total_tokens=model_config.max_model_len,
            max_output_tokens=max_output_tokens or 0,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            add_special_tokens=self.add_special_tokens,
            needs_detokenization=bool(self.echo and not self.return_token_ids),
            max_total_tokens_param="max_model_len",
            max_output_tokens_param=max_output_tokens_param,
        )

    # Default sampling parameters for chat completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_beam_search_params(
        self, max_tokens: int, default_sampling_params: dict
    ) -> BeamSearchParams:
        n = self.n if self.n is not None else 1
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )

    def to_sampling_params(
        self,
        max_tokens: int,
        default_sampling_params: dict,
    ) -> SamplingParams:
        # Default parameters: an explicit value on the request wins, then the
        # server-configured `default_sampling_params`, and finally the
        # class-level `_DEFAULT_SAMPLING_PARAMS` fallback.
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
            )
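
        # NOTE: with `echo=True` and no explicit `prompt_logprobs`, the code
        # below falls back to `top_logprobs` so the echoed prompt also carries
        # logprob information.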
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.top_logprobs

        response_format = self.response_format
        if response_format is not None:
            structured_outputs_kwargs = dict[str, Any]()

            # Set structured output params for response format
            if response_format.type == "json_object":
                structured_outputs_kwargs["json_object"] = True
            elif response_format.type == "json_schema":
                json_schema = response_format.json_schema
                assert json_schema is not None
                structured_outputs_kwargs["json"] = json_schema.json_schema
            elif response_format.type == "structural_tag":
                structural_tag = response_format
                assert structural_tag is not None and isinstance(
                    structural_tag,
                    (
                        LegacyStructuralTagResponseFormat,
                        StructuralTagResponseFormat,
                    ),
                )
                s_tag_obj = structural_tag.model_dump(by_alias=True)
                structured_outputs_kwargs["structural_tag"] = json.dumps(s_tag_obj)

            # If structured outputs wasn't already enabled,
            # we must enable it for these features to work
            if len(structured_outputs_kwargs) > 0:
                self.structured_outputs = (
                    StructuredOutputsParams(**structured_outputs_kwargs)
                    if self.structured_outputs is None
                    else replace(self.structured_outputs, **structured_outputs_kwargs)
                )

        extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params

        return SamplingParams.from_optional(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=prompt_logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            structured_outputs=self.structured_outputs,
            logit_bias=self.logit_bias,
            bad_words=self.bad_words,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
            skip_clone=True,  # Created fresh per request, safe to skip clone
            repetition_detection=self.repetition_detection,
        )

    @model_validator(mode="before")
    @classmethod
    def validate_response_format(cls, data):
        response_format = data.get("response_format")
        if response_format is None:
            return data

        rf_type = (
            response_format.get("type")
            if isinstance(response_format, dict)
            else getattr(response_format, "type", None)
        )
        if rf_type == "json_schema":
            json_schema = (
                response_format.get("json_schema")
                if isinstance(response_format, dict)
                else getattr(response_format, "json_schema", None)
            )
            if json_schema is None:
                raise VLLMValidationError(
                    "When response_format type is 'json_schema', the "
                    "'json_schema' field must be provided.",
                    parameter="response_format",
                )
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        if data.get("stream_options") and not data.get("stream"):
            raise VLLMValidationError(
                "Stream options can only be defined when `stream=True`.",
                parameter="stream_options",
            )

        return data
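
    # Illustrative request fragment (a hedged sketch; schema contents elided)
    # accepted by `validate_response_format` above -- omitting the nested
    # "json_schema" object raises VLLMValidationError:
    #     {"response_format": {"type": "json_schema", "json_schema": {...}}}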

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
                raise VLLMValidationError(
                    "`prompt_logprobs` are not available when `stream=True`.",
                    parameter="prompt_logprobs",
                )

            if prompt_logprobs < 0 and prompt_logprobs != -1:
                raise VLLMValidationError(
                    "`prompt_logprobs` must be a non-negative value or -1.",
                    parameter="prompt_logprobs",
                    value=prompt_logprobs,
                )

        if (top_logprobs := data.get("top_logprobs")) is not None:
            if top_logprobs < 0 and top_logprobs != -1:
                raise VLLMValidationError(
                    "`top_logprobs` must be a non-negative value or -1.",
                    parameter="top_logprobs",
                    value=top_logprobs,
                )

            if (top_logprobs == -1 or top_logprobs > 0) and not data.get("logprobs"):
                raise VLLMValidationError(
                    "when using `top_logprobs`, `logprobs` must be set to true.",
                    parameter="top_logprobs",
                )

        return data

    @model_validator(mode="before")
    @classmethod
    def check_structured_outputs_count(cls, data):
        if isinstance(data, ValueError):
            raise data

        if data.get("structured_outputs", None) is None:
            return data

        structured_outputs_kwargs = data["structured_outputs"]
        # structured_outputs may arrive as a dict (from JSON/raw kwargs) or
        # as a StructuredOutputsParams dataclass instance.
        is_dataclass = isinstance(structured_outputs_kwargs, StructuredOutputsParams)
        count = sum(
            (
                getattr(structured_outputs_kwargs, k, None)
                if is_dataclass
                else structured_outputs_kwargs.get(k)
            )
            is not None
            for k in ("json", "regex", "choice")
        )

        # you can only use one kind of constraint for structured outputs
        if count > 1:
            raise ValueError(
                "You can only use one kind of constraint for structured "
                "outputs ('json', 'regex' or 'choice')."
            )

        # you can only either use structured outputs or tools, not both
        # (NOTE: `count >= 1` -- `count > 1` would be dead code, since that
        # case already raised above)
        if count >= 1 and data.get("tool_choice", "none") not in (
            "none",
            "auto",
            "required",
        ):
            raise ValueError(
                "You can only either use constraints for structured outputs "
                "or tools, not both."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_usage(cls, data):
        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and data.get("tools"):
            data["tool_choice"] = "auto"

        # if "tool_choice" is "none" -- no validation is needed for tools
        if "tool_choice" in data and data["tool_choice"] == "none":
            return data

        # if "tool_choice" is specified -- validation
        if "tool_choice" in data and data["tool_choice"] is not None:
            # ensure that if "tool choice" is specified, tools are present
            if "tools" not in data or data["tools"] is None:
                raise ValueError("When using `tool_choice`, `tools` must be set.")

            # make sure that tool choice is either a named tool
            # OR that it's set to "auto" or "required"
            if data["tool_choice"] not in ["auto", "required"] and not isinstance(
                data["tool_choice"], dict
            ):
                raise ValueError(
                    f"Invalid value for `tool_choice`: {data['tool_choice']}! "
                    'Only named tools, "none", "auto" or "required" '
                    "are supported."
                )

            # if tool_choice is "required" but the "tools" list is empty,
            # override the data to behave like "none" to align with
            # OpenAI's behavior.
            if (
                data["tool_choice"] == "required"
                and isinstance(data["tools"], list)
                and len(data["tools"]) == 0
            ):
                data["tool_choice"] = "none"
                del data["tools"]
                return data
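
            # At this point `tool_choice` is "auto", "required" (with a
            # non-empty tool list), or a named-tool dict; only the dict form
            # needs the structural checks below.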
            # ensure that if "tool_choice" is specified as an object,
            # it matches a valid tool
            correct_usage_message = (
                'Correct usage: `{"type": "function",'
                ' "function": {"name": "my_function"}}`'
            )
            if isinstance(data["tool_choice"], dict):
                valid_tool = False
                function = data["tool_choice"].get("function")
                if not isinstance(function, dict):
                    raise ValueError(
                        f"Invalid value for `function`: `{function}` in "
                        f"`tool_choice`! {correct_usage_message}"
                    )
                if "name" not in function:
                    raise ValueError(
                        f"Expected field `name` in `function` in "
                        f"`tool_choice`! {correct_usage_message}"
                    )
                function_name = function["name"]
                if not isinstance(function_name, str) or len(function_name) == 0:
                    raise ValueError(
                        f"Invalid `name` in `function`: `{function_name}`"
                        f" in `tool_choice`! {correct_usage_message}"
                    )
                for tool in data["tools"]:
                    if tool["function"]["name"] == function_name:
                        valid_tool = True
                        break
                if not valid_tool:
                    raise ValueError(
                        "The tool specified in `tool_choice` does not match any"
                        " of the specified `tools`"
                    )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        if data.get("continue_final_message") and data.get("add_generation_prompt"):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        if data.get("cache_salt") is not None and (
            not isinstance(data["cache_salt"], str) or not data["cache_salt"]
        ):
            raise ValueError(
                "Parameter 'cache_salt' must be a non-empty string if provided."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_system_message_content_type(cls, data):
        """Warn if system messages contain non-text content.

        According to the OpenAI API spec, system messages can only be of type
        'text'. We log a warning instead of rejecting to avoid breaking users
        who intentionally send multimodal system messages.

        See: https://platform.openai.com/docs/api-reference/chat/create#chat_create-messages-system_message
        """
        if not isinstance(data, dict):
            return data

        messages = data.get("messages", [])
        for msg in messages:
            # Check if this is a system message
            if isinstance(msg, dict) and msg.get("role") == "system":
                content = msg.get("content")
                # If content is a list (multimodal format)
                if isinstance(content, list):
                    for part in content:
                        if isinstance(part, dict):
                            part_type = part.get("type")
                            # Infer type when 'type' field is not explicit
                            if part_type is None:
                                if "image_url" in part or "image_pil" in part:
                                    part_type = "image_url"
                                elif "image_embeds" in part:
                                    part_type = "image_embeds"
                                elif "audio_url" in part:
                                    part_type = "audio_url"
                                elif "input_audio" in part:
                                    part_type = "input_audio"
                                elif "audio_embeds" in part:
                                    part_type = "audio_embeds"
                                elif "video_url" in part:
                                    part_type = "video_url"
                            # Warn about non-text content in system messages
                            if part_type and part_type != "text":
                                logger.warning_once(
                                    "System messages should only contain text "
                                    "content according to the OpenAI API "
                                    "spec. Found content type: '%s'.",
                                    part_type,
                                )
        return data
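

# Minimal end-to-end sketch (hedged; all values invented, shown as a comment
# so nothing runs on import): turning a request into engine sampling params.
#
#     request = ChatCompletionRequest(
#         messages=[{"role": "user", "content": "Hi"}],
#         model="my-model",
#         temperature=0.7,
#         max_completion_tokens=32,
#     )
#     params = request.to_sampling_params(
#         max_tokens=32,
#         default_sampling_params={},
#     )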