Upgrade to vllm 0.17.0 corex v4.1 overlay

This commit is contained in:
2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -8,6 +8,8 @@ from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.anthropic.protocol import (
AnthropicCountTokensRequest,
AnthropicCountTokensResponse,
AnthropicError,
AnthropicErrorResponse,
AnthropicMessagesRequest,
@@ -31,6 +33,18 @@ def messages(request: Request) -> AnthropicServingMessages:
return request.app.state.anthropic_serving_messages
def translate_error_response(response: ErrorResponse) -> JSONResponse:
anthropic_error = AnthropicErrorResponse(
error=AnthropicError(
type=response.error.type,
message=response.error.message,
)
)
return JSONResponse(
status_code=response.error.code, content=anthropic_error.model_dump()
)
@router.post(
"/v1/messages",
dependencies=[Depends(validate_json_request)],
@@ -44,17 +58,6 @@ def messages(request: Request) -> AnthropicServingMessages:
@with_cancellation
@load_aware_call
async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
def translate_error_response(response: ErrorResponse) -> JSONResponse:
anthropic_error = AnthropicErrorResponse(
error=AnthropicError(
type=response.error.type,
message=response.error.message,
)
)
return JSONResponse(
status_code=response.error.code, content=anthropic_error.model_dump()
)
handler = messages(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
@@ -88,5 +91,46 @@ async def create_messages(request: AnthropicMessagesRequest, raw_request: Reques
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post(
"/v1/messages/count_tokens",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"model": AnthropicCountTokensResponse},
HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
},
)
@load_aware_call
@with_cancellation
async def count_tokens(request: AnthropicCountTokensRequest, raw_request: Request):
handler = messages(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
error = base_server.create_error_response(
message="The model does not support Messages API"
)
return translate_error_response(error)
try:
response = await handler.count_tokens(request, raw_request)
except Exception as e:
logger.exception("Error in count_tokens: %s", e)
return JSONResponse(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
content=AnthropicErrorResponse(
error=AnthropicError(
type="internal_error",
message=str(e),
)
).model_dump(),
)
if isinstance(response, ErrorResponse):
return translate_error_response(response)
return JSONResponse(content=response.model_dump(exclude_none=True))
def attach_router(app: FastAPI):
app.include_router(router)
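
For reference, a minimal sketch of exercising the new /v1/messages/count_tokens route added above; the base URL, port, and model name are placeholders, and the `requests` dependency is assumed only for illustration:

import requests

# Placeholder server address and model name.
resp = requests.post(
    "http://localhost:8000/v1/messages/count_tokens",
    json={
        "model": "my-model",
        "messages": [{"role": "user", "content": "Hello, world"}],
    },
)
# Expected shape: {"input_tokens": <int>, "context_management": {"original_input_tokens": <int>}}
print(resp.json())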

View File

@@ -34,7 +34,7 @@ class AnthropicUsage(BaseModel):
class AnthropicContentBlock(BaseModel):
"""Content block in message"""
type: Literal["text", "image", "tool_use", "tool_result"]
type: Literal["text", "image", "tool_use", "tool_result", "thinking"]
text: str | None = None
# For image content
source: dict[str, Any] | None = None
@@ -45,6 +45,9 @@ class AnthropicContentBlock(BaseModel):
input: dict[str, Any] | None = None
content: str | list[dict[str, Any]] | None = None
is_error: bool | None = None
# For thinking content
thinking: str | None = None
signature: str | None = None
class AnthropicMessage(BaseModel):
@@ -74,7 +77,7 @@ class AnthropicTool(BaseModel):
class AnthropicToolChoice(BaseModel):
"""Tool Choice definition"""
type: Literal["auto", "any", "tool"]
type: Literal["auto", "any", "tool", "none"]
name: str | None = None
@model_validator(mode="after")
@@ -118,9 +121,14 @@ class AnthropicMessagesRequest(BaseModel):
class AnthropicDelta(BaseModel):
"""Delta for streaming responses"""
type: Literal["text_delta", "input_json_delta"] | None = None
type: (
Literal["text_delta", "input_json_delta", "thinking_delta", "signature_delta"]
| None
) = None
text: str | None = None
thinking: str | None = None
partial_json: str | None = None
signature: str | None = None
# Message delta
stop_reason: (
@@ -167,3 +175,33 @@ class AnthropicMessagesResponse(BaseModel):
def model_post_init(self, __context):
if not self.id:
self.id = f"msg_{int(time.time() * 1000)}"
class AnthropicContextManagement(BaseModel):
"""Context management information for token counting."""
original_input_tokens: int
class AnthropicCountTokensRequest(BaseModel):
"""Anthropic messages.count_tokens request"""
model: str
messages: list[AnthropicMessage]
system: str | list[AnthropicContentBlock] | None = None
tool_choice: AnthropicToolChoice | None = None
tools: list[AnthropicTool] | None = None
@field_validator("model")
@classmethod
def validate_model(cls, v):
if not v:
raise ValueError("Model is required")
return v
class AnthropicCountTokensResponse(BaseModel):
"""Anthropic messages.count_tokens response"""
input_tokens: int
context_management: AnthropicContextManagement | None = None
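
A small sketch of the new protocol fields in isolation; values are illustrative, and AnthropicMessage's role/content fields are assumed from how the serving code reads them:

from vllm.entrypoints.anthropic.protocol import (
    AnthropicContentBlock,
    AnthropicCountTokensRequest,
    AnthropicMessage,
    AnthropicToolChoice,
)

# New "thinking" content block variant with its signature field.
thinking_block = AnthropicContentBlock(
    type="thinking", thinking="Consider the request...", signature="sig-123"
)

# tool_choice now also accepts "none".
no_tools = AnthropicToolChoice(type="none")

# Token-counting request carries only model/messages plus optional system/tools.
req = AnthropicCountTokensRequest(
    model="my-model",  # placeholder
    messages=[AnthropicMessage(role="user", content="Hello")],  # fields assumed
)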

View File

@@ -8,6 +8,7 @@
import json
import logging
import time
import uuid
from collections.abc import AsyncGenerator
from typing import Any
@@ -16,6 +17,9 @@ from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.protocol import (
AnthropicContentBlock,
AnthropicContextManagement,
AnthropicCountTokensRequest,
AnthropicCountTokensResponse,
AnthropicDelta,
AnthropicError,
AnthropicMessagesRequest,
@@ -85,94 +89,225 @@ class AnthropicServingMessages(OpenAIServingChat):
"tool_calls": "tool_use",
}
@staticmethod
def _convert_image_source_to_url(source: dict[str, Any]) -> str:
"""Convert an Anthropic image source to an OpenAI-compatible URL.
Anthropic supports two image source types:
- base64: {"type": "base64", "media_type": "image/jpeg", "data": "..."}
- url: {"type": "url", "url": "https://..."}
For base64 sources, this constructs a proper data URI that
downstream processors (e.g. vLLM's media connector) can handle.
"""
source_type = source.get("type")
if source_type == "url":
return source.get("url", "")
# Default to base64 processing if type is "base64"
# or missing, ensuring a proper data URI is always
# constructed for non-URL sources.
media_type = source.get("media_type", "image/jpeg")
data = source.get("data", "")
return f"data:{media_type};base64,{data}"
@classmethod
def _convert_anthropic_to_openai_request(
self, anthropic_request: AnthropicMessagesRequest
cls, anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest
) -> ChatCompletionRequest:
"""Convert Anthropic message format to OpenAI format"""
openai_messages = []
openai_messages: list[dict[str, Any]] = []
# Add system message if provided
if anthropic_request.system:
if isinstance(anthropic_request.system, str):
openai_messages.append(
{"role": "system", "content": anthropic_request.system}
)
else:
system_prompt = ""
for block in anthropic_request.system:
if block.type == "text" and block.text:
system_prompt += block.text
openai_messages.append({"role": "system", "content": system_prompt})
cls._convert_system_message(anthropic_request, openai_messages)
cls._convert_messages(anthropic_request.messages, openai_messages)
req = cls._build_base_request(anthropic_request, openai_messages)
cls._handle_streaming_options(req, anthropic_request)
cls._convert_tool_choice(anthropic_request, req)
cls._convert_tools(anthropic_request, req)
return req
for msg in anthropic_request.messages:
@classmethod
def _convert_system_message(
cls,
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
openai_messages: list[dict[str, Any]],
) -> None:
"""Convert Anthropic system message to OpenAI format"""
if not anthropic_request.system:
return
if isinstance(anthropic_request.system, str):
openai_messages.append(
{"role": "system", "content": anthropic_request.system}
)
else:
system_prompt = ""
for block in anthropic_request.system:
if block.type == "text" and block.text:
system_prompt += block.text
openai_messages.append({"role": "system", "content": system_prompt})
@classmethod
def _convert_messages(
cls, messages: list, openai_messages: list[dict[str, Any]]
) -> None:
"""Convert Anthropic messages to OpenAI format"""
for msg in messages:
openai_msg: dict[str, Any] = {"role": msg.role} # type: ignore
if isinstance(msg.content, str):
openai_msg["content"] = msg.content
else:
# Handle complex content blocks
content_parts: list[dict[str, Any]] = []
tool_calls: list[dict[str, Any]] = []
for block in msg.content:
if block.type == "text" and block.text:
content_parts.append({"type": "text", "text": block.text})
elif block.type == "image" and block.source:
content_parts.append(
{
"type": "image_url",
"image_url": {"url": block.source.get("data", "")},
}
)
elif block.type == "tool_use":
# Convert tool use to function call format
tool_call = {
"id": block.id or f"call_{int(time.time())}",
"type": "function",
"function": {
"name": block.name or "",
"arguments": json.dumps(block.input or {}),
},
}
tool_calls.append(tool_call)
elif block.type == "tool_result":
if msg.role == "user":
openai_messages.append(
{
"role": "tool",
"tool_call_id": block.tool_use_id or "",
"content": str(block.content)
if block.content
else "",
}
)
else:
# Assistant tool result becomes regular text
tool_result_text = (
str(block.content) if block.content else ""
)
content_parts.append(
{
"type": "text",
"text": f"Tool result: {tool_result_text}",
}
)
# Add tool calls to the message if any
if tool_calls:
openai_msg["tool_calls"] = tool_calls # type: ignore
# Add content parts if any
if content_parts:
if len(content_parts) == 1 and content_parts[0]["type"] == "text":
openai_msg["content"] = content_parts[0]["text"]
else:
openai_msg["content"] = content_parts # type: ignore
elif not tool_calls:
continue
cls._convert_message_content(msg, openai_msg, openai_messages)
openai_messages.append(openai_msg)
req = ChatCompletionRequest(
@classmethod
def _convert_message_content(
cls,
msg,
openai_msg: dict[str, Any],
openai_messages: list[dict[str, Any]],
) -> None:
"""Convert complex message content blocks"""
content_parts: list[dict[str, Any]] = []
tool_calls: list[dict[str, Any]] = []
reasoning_parts: list[str] = []
for block in msg.content:
cls._convert_block(
block,
msg.role,
content_parts,
tool_calls,
reasoning_parts,
openai_messages,
)
if reasoning_parts:
openai_msg["reasoning"] = "".join(reasoning_parts)
if tool_calls:
openai_msg["tool_calls"] = tool_calls # type: ignore
if content_parts:
if len(content_parts) == 1 and content_parts[0]["type"] == "text":
openai_msg["content"] = content_parts[0]["text"]
else:
openai_msg["content"] = content_parts # type: ignore
elif not tool_calls and not reasoning_parts:
return
@classmethod
def _convert_block(
cls,
block,
role: str,
content_parts: list[dict[str, Any]],
tool_calls: list[dict[str, Any]],
reasoning_parts: list[str],
openai_messages: list[dict[str, Any]],
) -> None:
"""Convert individual content block"""
if block.type == "text" and block.text:
content_parts.append({"type": "text", "text": block.text})
elif block.type == "image" and block.source:
image_url = cls._convert_image_source_to_url(block.source)
content_parts.append({"type": "image_url", "image_url": {"url": image_url}})
elif block.type == "thinking" and block.thinking is not None:
reasoning_parts.append(block.thinking)
elif block.type == "tool_use":
cls._convert_tool_use_block(block, tool_calls)
elif block.type == "tool_result":
cls._convert_tool_result_block(block, role, openai_messages, content_parts)
@classmethod
def _convert_tool_use_block(cls, block, tool_calls: list[dict[str, Any]]) -> None:
"""Convert tool_use block to OpenAI function call format"""
tool_call = {
"id": block.id or f"call_{int(time.time())}",
"type": "function",
"function": {
"name": block.name or "",
"arguments": json.dumps(block.input or {}),
},
}
tool_calls.append(tool_call)
@classmethod
def _convert_tool_result_block(
cls,
block,
role: str,
openai_messages: list[dict[str, Any]],
content_parts: list[dict[str, Any]],
) -> None:
"""Convert tool_result block to OpenAI format"""
if role == "user":
cls._convert_user_tool_result(block, openai_messages)
else:
tool_result_text = str(block.content) if block.content else ""
content_parts.append(
{"type": "text", "text": f"Tool result: {tool_result_text}"}
)
@classmethod
def _convert_user_tool_result(
cls, block, openai_messages: list[dict[str, Any]]
) -> None:
"""Convert user tool_result with text and image support"""
tool_text = ""
tool_image_urls: list[str] = []
if isinstance(block.content, str):
tool_text = block.content
elif isinstance(block.content, list):
text_parts: list[str] = []
for item in block.content:
if not isinstance(item, dict):
continue
item_type = item.get("type")
if item_type == "text":
text_parts.append(item.get("text", ""))
elif item_type == "image":
source = item.get("source", {})
url = cls._convert_image_source_to_url(source)
if url:
tool_image_urls.append(url)
tool_text = "\n".join(text_parts)
openai_messages.append(
{
"role": "tool",
"tool_call_id": block.tool_use_id or "",
"content": tool_text or "",
}
)
if tool_image_urls:
openai_messages.append(
{
"role": "user",
"content": [ # type: ignore[dict-item]
{"type": "image_url", "image_url": {"url": img}}
for img in tool_image_urls
],
}
)
@classmethod
def _build_base_request(
cls,
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
openai_messages: list[dict[str, Any]],
) -> ChatCompletionRequest:
"""Build base ChatCompletionRequest"""
if isinstance(anthropic_request, AnthropicCountTokensRequest):
return ChatCompletionRequest(
model=anthropic_request.model,
messages=openai_messages,
)
return ChatCompletionRequest(
model=anthropic_request.model,
messages=openai_messages,
max_tokens=anthropic_request.max_tokens,
@@ -183,19 +318,40 @@ class AnthropicServingMessages(OpenAIServingChat):
top_k=anthropic_request.top_k,
)
@classmethod
def _handle_streaming_options(
cls,
req: ChatCompletionRequest,
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
) -> None:
"""Handle streaming configuration"""
if isinstance(anthropic_request, AnthropicCountTokensRequest):
return
if anthropic_request.stream:
req.stream = anthropic_request.stream
req.stream_options = StreamOptions.validate(
req.stream_options = StreamOptions.model_validate(
{"include_usage": True, "continuous_usage_stats": True}
)
@classmethod
def _convert_tool_choice(
cls,
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
req: ChatCompletionRequest,
) -> None:
"""Convert Anthropic tool_choice to OpenAI format"""
if anthropic_request.tool_choice is None:
req.tool_choice = None
elif anthropic_request.tool_choice.type == "auto":
return
tool_choice_type = anthropic_request.tool_choice.type
if tool_choice_type == "auto":
req.tool_choice = "auto"
elif anthropic_request.tool_choice.type == "any":
elif tool_choice_type == "any":
req.tool_choice = "required"
elif anthropic_request.tool_choice.type == "tool":
elif tool_choice_type == "none":
req.tool_choice = "none"
elif tool_choice_type == "tool":
req.tool_choice = ChatCompletionNamedToolChoiceParam.model_validate(
{
"type": "function",
@@ -203,9 +359,17 @@ class AnthropicServingMessages(OpenAIServingChat):
}
)
tools = []
@classmethod
def _convert_tools(
cls,
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
req: ChatCompletionRequest,
) -> None:
"""Convert Anthropic tools to OpenAI format"""
if anthropic_request.tools is None:
return req
return
tools = []
for tool in anthropic_request.tools:
tools.append(
ChatCompletionToolsParam.model_validate(
@@ -219,10 +383,10 @@ class AnthropicServingMessages(OpenAIServingChat):
}
)
)
if req.tool_choice is None:
req.tool_choice = "auto"
req.tools = tools
return req
async def create_messages(
self,
@@ -263,23 +427,32 @@ class AnthropicServingMessages(OpenAIServingChat):
output_tokens=generator.usage.completion_tokens,
),
)
if generator.choices[0].finish_reason == "stop":
choice = generator.choices[0]
if choice.finish_reason == "stop":
result.stop_reason = "end_turn"
elif generator.choices[0].finish_reason == "length":
elif choice.finish_reason == "length":
result.stop_reason = "max_tokens"
elif generator.choices[0].finish_reason == "tool_calls":
elif choice.finish_reason == "tool_calls":
result.stop_reason = "tool_use"
content: list[AnthropicContentBlock] = [
AnthropicContentBlock(
type="text",
text=generator.choices[0].message.content
if generator.choices[0].message.content
else "",
content: list[AnthropicContentBlock] = []
if choice.message.reasoning:
content.append(
AnthropicContentBlock(
type="thinking",
thinking=choice.message.reasoning,
signature=uuid.uuid4().hex,
)
)
if choice.message.content:
content.append(
AnthropicContentBlock(
type="text",
text=choice.message.content,
)
)
]
for tool_call in generator.choices[0].message.tool_calls:
for tool_call in choice.message.tool_calls:
anthropic_tool_call = AnthropicContentBlock(
type="tool_use",
id=tool_call.id,
@@ -297,10 +470,85 @@ class AnthropicServingMessages(OpenAIServingChat):
generator: AsyncGenerator[str, None],
) -> AsyncGenerator[str, None]:
try:
class _ActiveBlockState:
def __init__(self) -> None:
self.content_block_index = 0
self.block_type: str | None = None
self.block_index: int | None = None
self.block_signature: str | None = None
self.signature_emitted: bool = False
self.tool_use_id: str | None = None
def reset(self) -> None:
self.block_type = None
self.block_index = None
self.block_signature = None
self.signature_emitted = False
self.tool_use_id = None
def start(self, block: AnthropicContentBlock) -> None:
self.block_type = block.type
self.block_index = self.content_block_index
if block.type == "thinking":
self.block_signature = uuid.uuid4().hex
self.signature_emitted = False
self.tool_use_id = None
elif block.type == "tool_use":
self.block_signature = None
self.signature_emitted = True
self.tool_use_id = block.id
else:
self.block_signature = None
self.signature_emitted = True
self.tool_use_id = None
first_item = True
finish_reason = None
content_block_index = 0
content_block_started = False
state = _ActiveBlockState()
# Map from tool call index to tool_use_id
tool_index_to_id: dict[int, str] = {}
def stop_active_block():
events: list[str] = []
if state.block_type is None:
return events
if (
state.block_type == "thinking"
and state.block_signature is not None
and not state.signature_emitted
):
chunk = AnthropicStreamEvent(
index=state.block_index,
type="content_block_delta",
delta=AnthropicDelta(
type="signature_delta",
signature=state.block_signature,
),
)
data = chunk.model_dump_json(exclude_unset=True)
events.append(wrap_data_with_event(data, "content_block_delta"))
state.signature_emitted = True
stop_chunk = AnthropicStreamEvent(
index=state.block_index,
type="content_block_stop",
)
data = stop_chunk.model_dump_json(exclude_unset=True)
events.append(wrap_data_with_event(data, "content_block_stop"))
state.reset()
state.content_block_index += 1
return events
def start_block(block: AnthropicContentBlock):
chunk = AnthropicStreamEvent(
index=state.content_block_index,
type="content_block_start",
content_block=block,
)
data = chunk.model_dump_json(exclude_unset=True)
event = wrap_data_with_event(data, "content_block_start")
state.start(block)
return event
async for item in generator:
if item.startswith("data:"):
@@ -326,6 +574,8 @@ class AnthropicServingMessages(OpenAIServingChat):
id=origin_chunk.id,
content=[],
model=origin_chunk.model,
stop_reason=None,
stop_sequence=None,
usage=AnthropicUsage(
input_tokens=origin_chunk.usage.prompt_tokens
if origin_chunk.usage
@@ -341,13 +591,8 @@ class AnthropicServingMessages(OpenAIServingChat):
# last chunk including usage info
if len(origin_chunk.choices) == 0:
if content_block_started:
stop_chunk = AnthropicStreamEvent(
index=content_block_index,
type="content_block_stop",
)
data = stop_chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_stop")
for event in stop_active_block():
yield event
stop_reason = self.stop_reason_map.get(
finish_reason or "stop"
)
@@ -369,96 +614,139 @@ class AnthropicServingMessages(OpenAIServingChat):
if origin_chunk.choices[0].finish_reason is not None:
finish_reason = origin_chunk.choices[0].finish_reason
continue
# content
if origin_chunk.choices[0].delta.content is not None:
if not content_block_started:
chunk = AnthropicStreamEvent(
index=content_block_index,
type="content_block_start",
content_block=AnthropicContentBlock(
type="text", text=""
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_start")
content_block_started = True
if origin_chunk.choices[0].delta.content == "":
continue
chunk = AnthropicStreamEvent(
index=content_block_index,
type="content_block_delta",
delta=AnthropicDelta(
type="text_delta",
text=origin_chunk.choices[0].delta.content,
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_delta")
continue
# tool calls
elif len(origin_chunk.choices[0].delta.tool_calls) > 0:
tool_call = origin_chunk.choices[0].delta.tool_calls[0]
if tool_call.id is not None:
if content_block_started:
stop_chunk = AnthropicStreamEvent(
index=content_block_index,
type="content_block_stop",
)
data = stop_chunk.model_dump_json(
exclude_unset=True
)
yield wrap_data_with_event(
data, "content_block_stop"
)
content_block_started = False
content_block_index += 1
chunk = AnthropicStreamEvent(
index=content_block_index,
type="content_block_start",
content_block=AnthropicContentBlock(
type="tool_use",
id=tool_call.id,
name=tool_call.function.name
if tool_call.function
else None,
input={},
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_start")
content_block_started = True
if tool_call.function and tool_call.function.arguments:
chunk = AnthropicStreamEvent(
index=content_block_index,
type="content_block_delta",
delta=AnthropicDelta(
type="input_json_delta",
partial_json=tool_call.function.arguments,
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(
data, "content_block_delta"
)
# continue
# thinking / text content
reasoning_delta = origin_chunk.choices[0].delta.reasoning
if reasoning_delta is not None:
if reasoning_delta == "":
pass
else:
if state.block_type != "thinking":
for event in stop_active_block():
yield event
start_event = start_block(
AnthropicContentBlock(
type="thinking", thinking=""
)
)
yield start_event
chunk = AnthropicStreamEvent(
index=content_block_index,
index=(
state.block_index
if state.block_index is not None
else state.content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
type="input_json_delta",
partial_json=tool_call.function.arguments
if tool_call.function
else None,
type="thinking_delta",
thinking=reasoning_delta,
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_delta")
if origin_chunk.choices[0].delta.content is not None:
if origin_chunk.choices[0].delta.content == "":
pass
else:
if state.block_type != "text":
for event in stop_active_block():
yield event
start_event = start_block(
AnthropicContentBlock(type="text", text="")
)
yield start_event
chunk = AnthropicStreamEvent(
index=(
state.block_index
if state.block_index is not None
else state.content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
type="text_delta",
text=origin_chunk.choices[0].delta.content,
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_delta")
# tool calls - process all tool calls in the delta
if len(origin_chunk.choices[0].delta.tool_calls) > 0:
for tool_call in origin_chunk.choices[0].delta.tool_calls:
if tool_call.id is not None:
# Update mapping for incremental updates
tool_index_to_id[tool_call.index] = tool_call.id
# Only create new block if different tool call
# AND has a name
tool_name = (
tool_call.function.name
if tool_call.function
else None
)
if (
state.tool_use_id != tool_call.id
and tool_name is not None
):
for event in stop_active_block():
yield event
start_event = start_block(
AnthropicContentBlock(
type="tool_use",
id=tool_call.id,
name=tool_name,
input={},
)
)
yield start_event
# Handle initial arguments if present
if (
tool_call.function
and tool_call.function.arguments
and state.tool_use_id == tool_call.id
):
chunk = AnthropicStreamEvent(
index=(
state.block_index
if state.block_index is not None
else state.content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
type="input_json_delta",
partial_json=tool_call.function.arguments,
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(
data, "content_block_delta"
)
else:
# Incremental update - use index to find tool_use_id
tool_use_id = tool_index_to_id.get(tool_call.index)
if (
tool_use_id is not None
and tool_call.function
and tool_call.function.arguments
and state.tool_use_id == tool_use_id
):
chunk = AnthropicStreamEvent(
index=(
state.block_index
if state.block_index is not None
else state.content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
type="input_json_delta",
partial_json=tool_call.function.arguments,
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(
data, "content_block_delta"
)
continue
else:
error_response = AnthropicStreamEvent(
@@ -481,3 +769,31 @@ class AnthropicServingMessages(OpenAIServingChat):
data = error_response.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "error")
yield "data: [DONE]\n\n"
async def count_tokens(
self,
request: AnthropicCountTokensRequest,
raw_request: Request | None = None,
) -> AnthropicCountTokensResponse | ErrorResponse:
"""Implements Anthropic's messages.count_tokens endpoint."""
chat_req = self._convert_anthropic_to_openai_request(request)
result = await self.render_chat_request(chat_req)
if isinstance(result, ErrorResponse):
return result
_, engine_prompts = result
input_tokens = sum( # type: ignore
len(prompt["prompt_token_ids"]) # type: ignore[typeddict-item, misc]
for prompt in engine_prompts
if "prompt_token_ids" in prompt
)
response = AnthropicCountTokensResponse(
input_tokens=input_tokens,
context_management=AnthropicContextManagement(
original_input_tokens=input_tokens
),
)
return response
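
The _convert_image_source_to_url helper added earlier in this file normalizes Anthropic image sources into OpenAI-style URLs; a quick illustration with made-up values (imports omitted because the module path is not shown in this diff):

# URL sources pass through unchanged.
AnthropicServingMessages._convert_image_source_to_url(
    {"type": "url", "url": "https://example.com/cat.png"}
)
# -> "https://example.com/cat.png"

# base64 sources become data URIs built from media_type and data.
AnthropicServingMessages._convert_image_source_to_url(
    {"type": "base64", "media_type": "image/png", "data": "iVBORw0K..."}
)
# -> "data:image/png;base64,iVBORw0K..."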

View File

@@ -7,6 +7,7 @@ import warnings
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from collections.abc import Awaitable, Callable, Iterable
from dataclasses import dataclass
from functools import cached_property, lru_cache, partial
from itertools import accumulate
from pathlib import Path
@@ -1024,6 +1025,13 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
self._add_placeholder("video", placeholder)
@dataclass
class ChatTemplateConfig:
chat_template: str | None = None
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
trust_request_chat_template: bool = False
def validate_chat_template(chat_template: Path | str | None):
"""Raises if the provided chat template appears invalid."""
if chat_template is None:
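
A minimal sketch of the new ChatTemplateConfig dataclass; the template string is a placeholder:

from vllm.entrypoints.chat_utils import ChatTemplateConfig

cfg = ChatTemplateConfig(
    chat_template="{% for m in messages %}{{ m['content'] }}{% endfor %}",
    chat_template_content_format="auto",
    trust_request_chat_template=False,
)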

View File

@@ -1,12 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand
from vllm.entrypoints.cli.benchmark.mm_processor import (
BenchmarkMMProcessorSubcommand,
)
from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand
from vllm.entrypoints.cli.benchmark.startup import BenchmarkStartupSubcommand
from vllm.entrypoints.cli.benchmark.sweep import BenchmarkSweepSubcommand
from vllm.entrypoints.cli.benchmark.throughput import BenchmarkThroughputSubcommand
# Keep this package init import-free.
#
# The `vllm` console script imports `vllm.entrypoints.cli.main`, which causes
# Python to import this package before loading the `main` submodule.
# Eagerly importing benchmark subcommands here makes every `vllm serve ...`
# startup depend on optional benchmark-only modules.
#
# Benchmark subcommands are loaded on demand in
# `vllm.entrypoints.cli.benchmark.main`.
__all__: list[str] = [
"BenchmarkLatencySubcommand",
"BenchmarkMMProcessorSubcommand",
"BenchmarkServingSubcommand",
"BenchmarkStartupSubcommand",
"BenchmarkSweepSubcommand",
"BenchmarkThroughputSubcommand",
]

View File

@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import importlib
import logging
import typing
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
@@ -15,30 +13,6 @@ if typing.TYPE_CHECKING:
else:
FlexibleArgumentParser = argparse.ArgumentParser
logger = logging.getLogger(__name__)
def _load_benchmark_subcommands() -> None:
modules = [
"vllm.entrypoints.cli.benchmark.latency",
"vllm.entrypoints.cli.benchmark.mm_processor",
"vllm.entrypoints.cli.benchmark.serve",
"vllm.entrypoints.cli.benchmark.startup",
"vllm.entrypoints.cli.benchmark.sweep",
"vllm.entrypoints.cli.benchmark.throughput",
]
for module_name in modules:
try:
importlib.import_module(module_name)
except ModuleNotFoundError as e:
logger.warning(
"Skipping benchmark subcommand module %s because an optional "
"dependency could not be imported: %r",
module_name,
e,
)
class BenchmarkSubcommand(CLISubcommand):
"""The `bench` subcommand for the vLLM CLI."""
@@ -64,8 +38,6 @@ class BenchmarkSubcommand(CLISubcommand):
)
bench_subparsers = bench_parser.add_subparsers(required=True, dest="bench_type")
_load_benchmark_subcommands()
for cmd_cls in BenchmarkSubcommandBase.__subclasses__():
cmd_subparser = bench_subparsers.add_parser(
cmd_cls.name,

View File

@@ -220,6 +220,12 @@ def run_multi_api_server(args: argparse.Namespace):
num_api_servers: int = args.api_server_count
assert num_api_servers > 0
if num_api_servers > 1 and getattr(args, "use_gpu_for_pooling_score", False):
# TODO(wentao): remove this once well tested
raise ValueError(
"--use-gpu-for-pooling-score cannot be used with api_server_count > 1 now"
)
if num_api_servers > 1:
setup_multiprocess_prometheus()
@@ -246,8 +252,12 @@ def run_multi_api_server(args: argparse.Namespace):
api_server_manager: APIServerProcessManager | None = None
from vllm.v1.engine.utils import get_engine_zmq_addresses
addresses = get_engine_zmq_addresses(vllm_config, num_api_servers)
with launch_core_engines(
vllm_config, executor_class, log_stats, num_api_servers
vllm_config, executor_class, log_stats, addresses, num_api_servers
) as (local_engine_manager, coordinator, addresses):
# Construct common args for the APIServerProcessManager up-front.
api_server_manager_kwargs = dict(

View File

@@ -101,11 +101,15 @@ class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer):
sampling_params = self._sampling_params_from_proto(
request.sampling_params, stream=request.stream
)
tokenization_kwargs = self._tokenization_kwargs_from_proto(
request.sampling_params
)
async for output in self.async_llm.generate(
prompt=prompt,
sampling_params=sampling_params,
request_id=request_id,
tokenization_kwargs=tokenization_kwargs,
):
# Convert vLLM output to protobuf
# For streaming, always send chunks
@@ -308,9 +312,6 @@ class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer):
seed=params.seed if params.HasField("seed") else None,
include_stop_str_in_output=params.include_stop_str_in_output,
logit_bias=dict(params.logit_bias) if params.logit_bias else None,
truncate_prompt_tokens=params.truncate_prompt_tokens
if params.HasField("truncate_prompt_tokens")
else None,
structured_outputs=structured_outputs,
# detokenize must be True if stop strings are used
detokenize=bool(stop),
@@ -319,6 +320,14 @@ class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer):
else RequestOutputKind.FINAL_ONLY,
)
@staticmethod
def _tokenization_kwargs_from_proto(
params: vllm_engine_pb2.SamplingParams,
) -> dict[str, int] | None:
if params.HasField("truncate_prompt_tokens"):
return {"truncate_prompt_tokens": params.truncate_prompt_tokens}
return None
@staticmethod
def _chunk_response(output: RequestOutput) -> vllm_engine_pb2.GenerateResponse:
"""

View File

@@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import warnings
from collections.abc import Callable, Iterable, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any
import cloudpickle
@@ -41,8 +41,11 @@ from vllm.distributed.weight_transfer.base import (
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateConfig,
ChatTemplateContentFormatOption,
load_chat_template,
)
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
from vllm.entrypoints.pooling.score.utils import (
ScoreData,
ScoreMultiModalParam,
@@ -146,6 +149,7 @@ class LLM:
a tag name, or a commit id.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id.
chat_template: The chat template to apply.
seed: The seed to initialize the random number generator for sampling.
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
reserve for the model weights, activations, and KV cache. Higher
@@ -233,6 +237,7 @@ class LLM:
quantization: QuantizationMethods | None = None,
revision: str | None = None,
tokenizer_revision: str | None = None,
chat_template: Path | str | None = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: float = 4,
@@ -385,9 +390,16 @@ class LLM:
self.model_config = self.llm_engine.model_config
self.renderer = self.llm_engine.renderer
self.chat_template = load_chat_template(chat_template)
self.io_processor = self.llm_engine.io_processor
self.input_processor = self.llm_engine.input_processor
self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template)
self.init_pooling_io_processors = init_pooling_io_processors(
supported_tasks=supported_tasks,
model_config=self.model_config,
renderer=self.renderer,
chat_template_config=self.chat_template_config,
)
# Cache for __repr__ to avoid repeated collective_rpc calls
self._cached_repr: str | None = None
@@ -1030,7 +1042,6 @@ class LLM:
prompts: PromptType | Sequence[PromptType] | DataPrompt,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
*,
truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None,
pooling_task: PoolingTask | None = None,
@@ -1088,21 +1099,7 @@ class LLM:
"pooling model."
)
if truncate_prompt_tokens is not None:
warnings.warn(
"The `truncate_prompt_tokens` parameter in `LLM.encode()` "
"is deprecated and will be removed in v0.16. "
"Please pass it via `tokenization_kwargs` instead.",
DeprecationWarning,
stacklevel=2,
)
tokenization_kwargs = merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=truncate_prompt_tokens),
)
if use_io_processor := (isinstance(prompts, dict) and "data" in prompts):
if isinstance(prompts, dict) and "data" in prompts:
if self.io_processor is None:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
@@ -1136,6 +1133,31 @@ class LLM:
for p in params_seq:
if p.task is None:
p.task = "plugin"
outputs = self._run_completion(
prompts=prompts_seq,
params=params_seq,
output_type=PoolingRequestOutput,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
# get the post-processed model outputs
assert self.io_processor is not None
processed_outputs = self.io_processor.post_process(outputs)
return [
PoolingRequestOutput[Any](
request_id="",
outputs=processed_outputs,
num_cached_tokens=getattr(
processed_outputs, "num_cached_tokens", 0
),
prompt_token_ids=[],
finished=True,
)
]
else:
if pooling_params is None:
# Use default pooling params.
@@ -1153,39 +1175,42 @@ class LLM:
)
raise ValueError(msg)
outputs = self._run_completion(
prompts=prompts_seq,
params=params_seq,
output_type=PoolingRequestOutput,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
if use_io_processor:
# get the post-processed model outputs
assert self.io_processor is not None
processed_outputs = self.io_processor.post_process(outputs)
return [
PoolingRequestOutput[Any](
request_id="",
outputs=processed_outputs,
num_cached_tokens=getattr(
processed_outputs, "num_cached_tokens", 0
),
prompt_token_ids=[],
finished=True,
if pooling_task in self.init_pooling_io_processors:
io_processor = self.init_pooling_io_processors[pooling_task]
processor_inputs = io_processor.pre_process_offline(
prompts_seq, tokenization_kwargs
)
]
seq_lora_requests = self._lora_request_to_seq(
lora_request, len(prompts_seq)
)
seq_priority = self._priority_to_seq(None, len(prompts))
self._render_and_add_requests(
prompts=processor_inputs,
params=params_seq,
lora_requests=seq_lora_requests,
priorities=seq_priority,
)
outputs = self._run_engine(
use_tqdm=use_tqdm, output_type=PoolingRequestOutput
)
outputs = io_processor.post_process(outputs)
else:
outputs = self._run_completion(
prompts=prompts_seq,
params=params_seq,
output_type=PoolingRequestOutput,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
return outputs
def embed(
self,
prompts: PromptType | Sequence[PromptType],
*,
truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
lora_request: list[LoRARequest] | LoRARequest | None = None,
@@ -1221,12 +1246,6 @@ class LLM:
"Try converting the model using `--convert embed`."
)
if truncate_prompt_tokens is not None:
tokenization_kwargs = merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=truncate_prompt_tokens),
)
items = self.encode(
prompts,
use_tqdm=use_tqdm,
@@ -1294,7 +1313,6 @@ class LLM:
/,
*,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
@@ -1319,13 +1337,11 @@ class LLM:
A list of `PoolingRequestOutput` objects containing the
pooled hidden states in the same order as the input prompts.
"""
return self.encode(
prompts,
use_tqdm=use_tqdm,
lora_request=lora_request,
pooling_params=pooling_params,
truncate_prompt_tokens=truncate_prompt_tokens,
pooling_task="token_classify",
tokenization_kwargs=tokenization_kwargs,
)
@@ -1771,23 +1787,15 @@ class LLM:
seq_prompts = prompt_to_seq(prompts)
seq_params = self._params_to_seq(params, len(seq_prompts))
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts))
seq_tok_kwargs = [
merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
)
for param in seq_params
]
seq_priority = self._priority_to_seq(priority, len(prompts))
return self._render_and_add_requests(
prompts=(
self._preprocess_cmpl_one(prompt, tok_kwargs)
for prompt, tok_kwargs in zip(
maybe_tqdm(
seq_prompts, use_tqdm=use_tqdm, desc="Rendering prompts"
),
seq_tok_kwargs,
self._preprocess_cmpl_one(prompt, tokenization_kwargs)
for prompt in maybe_tqdm(
seq_prompts,
use_tqdm=use_tqdm,
desc="Rendering prompts",
)
),
params=seq_params,
@@ -1841,13 +1849,6 @@ class LLM:
seq_convs = conversation_to_seq(messages)
seq_params = self._params_to_seq(params, len(seq_convs))
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_convs))
seq_tok_kwargs = [
merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
)
for param in seq_params
]
return self._render_and_run_requests(
prompts=(
@@ -1859,16 +1860,13 @@ class LLM:
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
tokenization_kwargs=tok_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_processor_kwargs=mm_processor_kwargs,
)
for conversation, tok_kwargs in zip(
maybe_tqdm(
seq_convs,
use_tqdm=use_tqdm,
desc="Rendering conversations",
),
seq_tok_kwargs,
for conversation in maybe_tqdm(
seq_convs,
use_tqdm=use_tqdm,
desc="Rendering conversations",
)
),
params=seq_params,
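
The LLM constructor now takes a chat_template argument (a template string or a file path) that is resolved with load_chat_template and stored on a ChatTemplateConfig; a hypothetical usage sketch with a placeholder model name and template path:

from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",            # illustrative model name
    chat_template="/path/to/chat_template.jinja",  # illustrative template path
)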

View File

@@ -18,6 +18,20 @@ class RequestLogger:
def __init__(self, *, max_log_len: int | None) -> None:
self.max_log_len = max_log_len
if not logger.isEnabledFor(logging.INFO):
logger.warning_once(
"`--enable-log-requests` is set but "
"the minimum log level is higher than INFO. "
"No request information will be logged."
)
elif not logger.isEnabledFor(logging.DEBUG):
logger.info_once(
"`--enable-log-requests` is set but "
"the minimum log level is higher than DEBUG. "
"Only limited information will be logged to minimize overhead. "
"To view more details, set `VLLM_LOGGING_LEVEL=DEBUG`."
)
def log_inputs(
self,
request_id: str,

View File

@@ -38,6 +38,7 @@ from vllm.logprobs import Logprob
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import (
BeamSearchParams,
RepetitionDetectionParams,
RequestOutputKind,
SamplingParams,
StructuredOutputsParams,
@@ -336,6 +337,16 @@ class ChatCompletionRequest(OpenAIBaseModel):
),
)
repetition_detection: RepetitionDetectionParams | None = Field(
default=None,
description="Parameters for detecting repetitive N-gram patterns "
"in output tokens. If such repetition is detected, generation will "
"be ended early. LLMs can sometimes generate repetitive, unhelpful "
"token patterns, stopping only when they hit the maximum output length "
"(e.g. 'abcdabcdabcd...' or '\emoji \emoji \emoji ...'). This feature "
"can detect such behavior and terminate early, saving time and tokens.",
)
# --8<-- [end:chat-completion-extra-params]
def build_chat_params(
@@ -490,7 +501,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA
if self.stream
else RequestOutputKind.FINAL_ONLY,
@@ -500,8 +510,37 @@ class ChatCompletionRequest(OpenAIBaseModel):
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
skip_clone=True, # Created fresh per request, safe to skip clone
repetition_detection=self.repetition_detection,
)
@model_validator(mode="before")
@classmethod
def validate_response_format(cls, data):
response_format = data.get("response_format")
if response_format is None:
return data
rf_type = (
response_format.get("type")
if isinstance(response_format, dict)
else getattr(response_format, "type", None)
)
if rf_type == "json_schema":
json_schema = (
response_format.get("json_schema")
if isinstance(response_format, dict)
else getattr(response_format, "json_schema", None)
)
if json_schema is None:
raise VLLMValidationError(
"When response_format type is 'json_schema', the "
"'json_schema' field must be provided.",
parameter="response_format",
)
return data
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
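
The repetition_detection field described above ends generation early when the output keeps repeating the same N-gram. A standalone illustration of the idea only (not vLLM's actual detector; window size and repeat threshold are arbitrary):

def repeats_ngram(tokens: list[int], n: int = 4, min_repeats: int = 8) -> bool:
    """Return True if the last n-gram is repeated min_repeats times back to back."""
    if len(tokens) < n * min_repeats:
        return False
    tail = tokens[-n * min_repeats:]
    first = tail[:n]
    return all(tail[i : i + n] == first for i in range(0, len(tail), n))

print(repeats_ngram([1, 2, 3, 4] * 8))   # True: an 'abcdabcd...'-style loop
print(repeats_ngram(list(range(32))))    # False: no repetition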

View File

@@ -1249,13 +1249,23 @@ class OpenAIServingChat(OpenAIServing):
)
# get the expected call based on partial JSON
# parsing which "autocompletes" the JSON
expected_call = json.dumps(
tool_parser.prev_tool_call_arr[index].get(
"arguments", {}
),
ensure_ascii=False,
# parsing which "autocompletes" the JSON.
# Tool parsers (e.g. Qwen3Coder) store
# arguments as a JSON string in
# prev_tool_call_arr. Calling json.dumps()
# on an already-serialized string would
# double-serialize it (e.g. '{"k":1}' becomes
# '"{\\"k\\":1}"'), which then causes the
# replace() below to fail and append the
# entire double-serialized string as a
# spurious final delta.
args = tool_parser.prev_tool_call_arr[index].get(
"arguments", {}
)
if isinstance(args, str):
expected_call = args
else:
expected_call = json.dumps(args, ensure_ascii=False)
# get what we've streamed so far for arguments
# for the current tool
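
A quick standalone demonstration of the double-serialization pitfall the comment above describes:

import json

args = '{"k": 1}'        # arguments already stored as a JSON string
print(json.dumps(args))  # prints "{\"k\": 1}"  -- serialized a second time
print(args)              # prints {"k": 1}      -- correct: use the string as-is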

View File

@@ -143,7 +143,8 @@ class BaseFrontendArgs:
templates and other tokenizer configuration."""
enable_log_outputs: bool = False
"""If set to True, log model outputs (generations).
Requires --enable-log-requests."""
Requires `--enable-log-requests`. As with `--enable-log-requests`,
information is only logged at INFO level at maximum."""
enable_log_deltas: bool = True
"""If set to False, output deltas will not be logged. Relevant only if
--enable-log-outputs is set.
@@ -277,6 +278,10 @@ class FrontendArgs(BaseFrontendArgs):
Enable offline FastAPI documentation for air-gapped environments.
Uses vendored static assets bundled with vLLM.
"""
use_gpu_for_pooling_score: bool = False
"""If set, run pooling score MaxSim on GPU in the API server process.
Can significantly improve late-interaction scoring performance.
https://github.com/vllm-project/vllm/pull/35330"""
@classmethod
def _customize_cli_kwargs(

View File

@@ -26,6 +26,7 @@ from vllm.logprobs import Logprob
from vllm.renderers import TokenizeParams
from vllm.sampling_params import (
BeamSearchParams,
RepetitionDetectionParams,
RequestOutputKind,
SamplingParams,
StructuredOutputsParams,
@@ -166,6 +167,16 @@ class CompletionRequest(OpenAIBaseModel):
),
)
repetition_detection: RepetitionDetectionParams | None = Field(
default=None,
description="Parameters for detecting repetitive N-gram patterns "
"in output tokens. If such repetition is detected, generation will "
"be ended early. LLMs can sometimes generate repetitive, unhelpful "
"token patterns, stopping only when they hit the maximum output length "
"(e.g. 'abcdabcdabcd...' or '\emoji \emoji \emoji ...'). This feature "
"can detect such behavior and terminate early, saving time and tokens.",
)
# --8<-- [end:completion-extra-params]
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
@@ -259,7 +270,7 @@ class CompletionRequest(OpenAIBaseModel):
structured_outputs_kwargs["json"] = json_schema.json_schema
elif response_format.type == "structural_tag":
structural_tag = response_format
assert structural_tag is not None and isinstance(
assert isinstance(
structural_tag,
(
LegacyStructuralTagResponseFormat,
@@ -302,7 +313,6 @@ class CompletionRequest(OpenAIBaseModel):
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA
if self.stream
else RequestOutputKind.FINAL_ONLY,
@@ -311,8 +321,37 @@ class CompletionRequest(OpenAIBaseModel):
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
skip_clone=True, # Created fresh per request, safe to skip clone
repetition_detection=self.repetition_detection,
)
@model_validator(mode="before")
@classmethod
def validate_response_format(cls, data):
response_format = data.get("response_format")
if response_format is None:
return data
rf_type = (
response_format.get("type")
if isinstance(response_format, dict)
else getattr(response_format, "type", None)
)
if rf_type == "json_schema":
json_schema = (
response_format.get("json_schema")
if isinstance(response_format, dict)
else getattr(response_format, "json_schema", None)
)
if json_schema is None:
raise VLLMValidationError(
"When response_format type is 'json_schema', the "
"'json_schema' field must be provided.",
parameter="response_format",
)
return data
@model_validator(mode="before")
@classmethod
def check_structured_outputs_count(cls, data):

View File

@@ -62,11 +62,6 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionResponse,
TranslationRequest,
)
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
EmbeddingChatRequest,
@@ -161,7 +156,6 @@ CompletionLikeRequest: TypeAlias = (
| TokenizeCompletionRequest
| DetokenizeRequest
| EmbeddingCompletionRequest
| ClassificationCompletionRequest
| RerankRequest
| ScoreRequest
| PoolingCompletionRequest
@@ -171,7 +165,6 @@ ChatLikeRequest: TypeAlias = (
ChatCompletionRequest
| TokenizeChatRequest
| EmbeddingChatRequest
| ClassificationChatRequest
| PoolingChatRequest
)
@@ -194,12 +187,10 @@ AnyResponse: TypeAlias = (
| TranscriptionResponse
| TokenizeResponse
| PoolingResponse
| ClassificationResponse
| ScoreResponse
| GenerateResponse
)
RequestT = TypeVar("RequestT", bound=AnyRequest)
@@ -223,8 +214,8 @@ class ServeContext(Generic[RequestT]):
class OpenAIServing:
request_id_prefix: ClassVar[str] = """
A short string prepended to every requests ID (e.g. "embd", "classify")
so you can easily tell “this ID came from Embedding vs Classification.”
A short string prepended to every request's ID (e.g. "embd")
so you can easily tell “this ID came from Embedding.”
"""
def __init__(
@@ -456,7 +447,7 @@ class OpenAIServing:
) -> ErrorResponse | None:
"""
Default preprocessing hook. Subclasses may override
to prepare `ctx` (classification, embedding, etc.).
to prepare `ctx` (embedding, etc.).
"""
return None
@@ -817,7 +808,7 @@ class OpenAIServing:
token_num = len(input_ids)
max_model_len = self.model_config.max_model_len
# Note: EmbeddingRequest, ClassificationRequest,
# Note: EmbeddingRequest,
# and ScoreRequest doesn't have max_tokens
if isinstance(
request,
@@ -828,8 +819,6 @@ class OpenAIServing:
ScoreTextRequest,
ScoreQueriesDocumentsRequest,
RerankRequest,
ClassificationCompletionRequest,
ClassificationChatRequest,
),
):
# Note: input length can be up to the entire model context length
@@ -839,8 +828,6 @@ class OpenAIServing:
ScoreDataRequest: "score",
ScoreTextRequest: "score",
ScoreQueriesDocumentsRequest: "score",
ClassificationCompletionRequest: "classification",
ClassificationChatRequest: "classification",
}
operation = operations.get(type(request), "embedding generation")
raise VLLMValidationError(

View File

@@ -328,8 +328,9 @@ class ResponsesRequest(OpenAIBaseModel):
# Also check text.format for OpenAI-style json_schema
if self.text is not None and self.text.format is not None:
if structured_outputs is not None:
raise ValueError(
"Cannot specify both structured_outputs and text.format"
raise VLLMValidationError(
"Cannot specify both structured_outputs and text.format",
parameter="structured_outputs",
)
response_format = self.text.format
if (
@@ -378,14 +379,19 @@ class ResponsesRequest(OpenAIBaseModel):
)
@model_validator(mode="before")
@classmethod
def validate_background(cls, data):
if not data.get("background"):
return data
if not data.get("store", True):
raise ValueError("background can only be used when `store` is true")
raise VLLMValidationError(
"background can only be used when `store` is true",
parameter="background",
)
return data
@model_validator(mode="before")
@classmethod
def validate_prompt(cls, data):
if data.get("prompt") is not None:
raise VLLMValidationError(
@@ -394,16 +400,19 @@ class ResponsesRequest(OpenAIBaseModel):
return data
@model_validator(mode="before")
@classmethod
def check_cache_salt_support(cls, data):
if data.get("cache_salt") is not None and (
not isinstance(data["cache_salt"], str) or not data["cache_salt"]
):
raise ValueError(
"Parameter 'cache_salt' must be a non-empty string if provided."
raise VLLMValidationError(
"Parameter 'cache_salt' must be a non-empty string if provided.",
parameter="cache_salt",
)
return data
@model_validator(mode="before")
@classmethod
def function_call_parsing(cls, data):
"""Parse function_call dictionaries into ResponseFunctionToolCall objects.
This ensures Pydantic can properly resolve union types in the input field.

View File

@@ -85,6 +85,8 @@ from vllm.entrypoints.openai.responses.protocol import (
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseInputOutputMessage,
ResponseReasoningPartAddedEvent,
ResponseReasoningPartDoneEvent,
ResponsesRequest,
ResponsesResponse,
ResponseUsage,
@@ -1339,6 +1341,19 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
yield _increment_sequence_number_and_return(
ResponseReasoningPartAddedEvent(
type="response.reasoning_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=ResponseReasoningTextContent(
text="",
type="reasoning_text",
),
)
)
else:
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
@@ -1354,22 +1369,21 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
yield _increment_sequence_number_and_return(
ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=ResponseOutputText(
type="output_text",
text="",
annotations=[],
logprobs=[],
),
yield _increment_sequence_number_and_return(
ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=ResponseOutputText(
type="output_text",
text="",
annotations=[],
logprobs=[],
),
)
)
)
current_content_index += 1
first_delta_sent = True
# todo(kebe7jun) tool call support
@@ -1397,6 +1411,19 @@ class OpenAIServingResponses(OpenAIServing):
text=reason_content,
)
)
yield _increment_sequence_number_and_return(
ResponseReasoningPartDoneEvent(
type="response.reasoning_part.done",
sequence_number=-1,
item_id=current_item_id,
output_index=current_output_index,
content_index=current_content_index,
part=ResponseReasoningTextContent(
text=reason_content,
type="reasoning_text",
),
)
)
current_content_index = 0
reasoning_item = ResponseReasoningItem(
type="reasoning",
@@ -1418,6 +1445,8 @@ class OpenAIServingResponses(OpenAIServing):
item=reasoning_item,
)
)
current_output_index += 1
current_item_id = str(uuid.uuid4())
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
@@ -1432,8 +1461,6 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
current_output_index += 1
current_item_id = str(uuid.uuid4())
yield _increment_sequence_number_and_return(
ResponseContentPartAddedEvent(
type="response.content_part.added",
@@ -1449,7 +1476,6 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
current_content_index += 1
# reset previous delta messages
previous_delta_messages = []
@@ -1485,7 +1511,6 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
current_content_index += 1
previous_delta_messages.append(delta_message)
if previous_delta_messages:
@@ -1505,7 +1530,19 @@ class OpenAIServingResponses(OpenAIServing):
text=reason_content,
)
)
current_content_index += 1
yield _increment_sequence_number_and_return(
ResponseReasoningPartDoneEvent(
type="response.reasoning_part.done",
sequence_number=-1,
item_id=current_item_id,
output_index=current_output_index,
content_index=current_content_index,
part=ResponseReasoningTextContent(
text=reason_content,
type="reasoning_text",
),
)
)
reasoning_item = ResponseReasoningItem(
type="reasoning",
content=[
@@ -1543,7 +1580,6 @@ class OpenAIServingResponses(OpenAIServing):
item_id=current_item_id,
)
)
current_content_index += 1
part = ResponseOutputText(
text=final_content,
type="output_text",
@@ -1559,7 +1595,6 @@ class OpenAIServingResponses(OpenAIServing):
part=part,
)
)
current_content_index += 1
item = ResponseOutputMessage(
type="message",
role="assistant",

View File

@@ -11,6 +11,7 @@ from typing import Final, Literal, TypeAlias, TypeVar, cast
import numpy as np
from fastapi import Request
from soundfile import LibsndfileError
from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
@@ -57,6 +58,14 @@ try:
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
# Public libsndfile error codes exposed via `soundfile.LibsndfileError.code`, soundfile
# being librosa's main backend. Used to validate if an audio loading error is due to a
# server error vs a client error (invalid audio file).
# 1 = unrecognised format (file is not a supported audio container)
# 3 = malformed file (corrupt or structurally invalid audio)
# 4 = unsupported encoding (codec not supported by this libsndfile build)
_BAD_SF_CODES = {1, 3, 4}
SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
SpeechToTextResponseVerbose: TypeAlias = (
TranscriptionResponseVerbose | TranslationResponseVerbose
@@ -315,9 +324,15 @@ class OpenAISpeechToText(OpenAIServing):
)
with io.BytesIO(audio_data) as bytes_:
# NOTE resample to model SR here for efficiency. This is also a
# pre-requisite for chunking, as it assumes Whisper SR.
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
try:
# NOTE resample to model SR here for efficiency. This is also a
# pre-requisite for chunking, as it assumes Whisper SR.
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
except LibsndfileError as exc:
# Distinguish client errors (invalid audio) from server errors
if exc.code in _BAD_SF_CODES:
raise ValueError("Invalid or unsupported audio file.") from exc
raise
duration = librosa.get_duration(y=y, sr=sr)
do_split_audio = (

View File

@@ -1,12 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
warnings.warn(
"The 'vllm.entrypoints.openai.translations' module has been renamed to "
"'vllm.entrypoints.openai.speech_to_text'. Please update your imports. "
"This backward-compatible alias will be removed in version 0.17+.",
DeprecationWarning,
stacklevel=2,
)

View File

@@ -1,14 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
warnings.warn(
"'vllm.entrypoints.openai.translations.api_router' has been moved to "
"'vllm.entrypoints.openai.speech_to_text.api_router'. Please update your "
"imports. This backward-compatible alias will be removed in version 0.17+.",
DeprecationWarning,
stacklevel=2,
)
from vllm.entrypoints.openai.speech_to_text.api_router import * # noqa: F401,F403,E402

View File

@@ -1,14 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
warnings.warn(
"'vllm.entrypoints.openai.translations.protocol' has been moved to "
"'vllm.entrypoints.openai.speech_to_text.protocol'. Please update your "
"imports. This backward-compatible alias will be removed in version 0.17+.",
DeprecationWarning,
stacklevel=2,
)
from vllm.entrypoints.openai.speech_to_text.protocol import * # noqa: F401,F403,E402

View File

@@ -1,14 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
warnings.warn(
"'vllm.entrypoints.openai.translations.serving' has been moved to "
"'vllm.entrypoints.openai.speech_to_text.serving'. Please update your "
"imports. This backward-compatible alias will be removed in version 0.17+.",
DeprecationWarning,
stacklevel=2,
)
from vllm.entrypoints.openai.speech_to_text.serving import * # noqa: F401,F403,E402

View File

@@ -1,15 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
warnings.warn(
"'vllm.entrypoints.openai.translations.speech_to_text' has been moved to "
"'vllm.entrypoints.openai.speech_to_text.speech_to_text'. Please update "
"your imports. This backward-compatible alias will be removed in version "
"0.17+.",
DeprecationWarning,
stacklevel=2,
)
from vllm.entrypoints.openai.speech_to_text.speech_to_text import * # noqa: F401,F403,E402

View File

@@ -115,6 +115,7 @@ def init_pooling_state(
request_logger=request_logger,
score_template=resolved_chat_template,
log_error_stack=args.log_error_stack,
use_gpu_for_pooling_score=getattr(args, "use_gpu_for_pooling_score", False),
)
if any(t in supported_tasks for t in ("embed", "score", "token_embed"))
else None

View File

@@ -0,0 +1,189 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Final
from vllm import PoolingRequestOutput, PromptType
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateConfig,
ChatTemplateContentFormatOption,
ConversationMessage,
)
from vllm.entrypoints.openai.engine.serving import RendererChatRequest, RendererRequest
from vllm.inputs import ProcessorInputs, SingletonPrompt
from vllm.renderers import BaseRenderer, merge_kwargs
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
from vllm.utils.mistral import is_mistral_tokenizer
class PoolingIOProcessor:
def __init__(
self,
model_config: ModelConfig,
renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig,
):
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
self.model_config = model_config
self.renderer = renderer
self.chat_template = chat_template_config.chat_template
self.chat_template_content_format: Final = (
chat_template_config.chat_template_content_format
)
self.trust_request_chat_template = (
chat_template_config.trust_request_chat_template
)
def pre_process_online(self, *args, **kwargs):
raise NotImplementedError
async def pre_process_online_async(self, *args, **kwargs):
return self.pre_process_online(*args, **kwargs)
def pre_process_offline(self, *args, **kwargs):
raise NotImplementedError
async def pre_process_offline_async(self, *args, **kwargs):
return self.pre_process_offline(*args, **kwargs)
def post_process(
self, outputs: list[PoolingRequestOutput]
) -> list[PoolingRequestOutput]:
return outputs
async def post_process_async(
self, outputs: list[PoolingRequestOutput]
) -> list[PoolingRequestOutput]:
return self.post_process(outputs)
def create_pooling_params(self, request):
return request.to_pooling_params()
def _preprocess_completion_online(
self,
request: RendererRequest,
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None,
) -> list[TokPrompt]:
renderer = self.renderer
model_config = self.model_config
prompts = list[SingletonPrompt | bytes]()
if prompt_embeds is not None: # embeds take higher priority
prompts.extend(prompt_to_seq(prompt_embeds))
if prompt_input is not None:
prompts.extend(prompt_to_seq(prompt_input))
parsed_prompts = [
(
prompt
if isinstance(prompt, bytes)
else parse_model_prompt(model_config, prompt)
)
for prompt in prompts
]
tok_params = request.build_tok_params(model_config)
return renderer.render_cmpl(
parsed_prompts,
tok_params,
prompt_extras={
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
},
)
def _preprocess_chat_online(
self,
request: RendererChatRequest,
messages: list[ChatCompletionMessageParam],
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[TokPrompt]]:
renderer = self.renderer
default_template_kwargs = merge_kwargs(
default_template_kwargs,
dict(
tools=tool_dicts,
tokenize=is_mistral_tokenizer(renderer.tokenizer),
),
)
tok_params = request.build_tok_params(self.model_config)
chat_params = request.build_chat_params(
default_template, default_template_content_format
).with_defaults(default_template_kwargs)
(conversation,), (engine_prompt,) = renderer.render_chat(
[messages],
chat_params,
tok_params,
prompt_extras={
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
},
)
return conversation, [engine_prompt]
def _preprocess_completion_offline(
self,
prompts: PromptType | Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None,
) -> Sequence[ProcessorInputs]:
renderer = self.renderer
model_config = self.model_config
prompts = prompt_to_seq(prompts)
parsed_prompts = [
(
prompt
if isinstance(prompt, bytes)
else parse_model_prompt(model_config, prompt)
)
for prompt in prompts
]
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
return renderer.render_cmpl(
parsed_prompts,
tok_params,
)
def _validate_chat_template(
self,
request_chat_template: str | None,
chat_template_kwargs: dict[str, Any] | None,
trust_request_chat_template: bool,
):
if not trust_request_chat_template and (
request_chat_template is not None
or (
chat_template_kwargs
and chat_template_kwargs.get("chat_template") is not None
)
):
raise ValueError(
"Chat template is passed with request, but "
"--trust-request-chat-template is not set. "
"Refused request with untrusted chat template."
)
return None
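A quick illustration of the guard above, assuming a deployment started without --trust-request-chat-template; the check reads only its arguments, so it can be exercised unbound:
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor

try:
    PoolingIOProcessor._validate_chat_template(
        None,  # self is not consulted by this check
        request_chat_template="{{ messages }}",
        chat_template_kwargs=None,
        trust_request_chat_template=False,
    )
except ValueError as exc:
    # The request is refused because the chat template is untrusted.
    print(exc)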

View File

@@ -190,10 +190,6 @@ class EmbedRequestMixin(EncodingRequestMixin):
description="Whether to use activation for the pooler outputs. "
"`None` uses the pooler's default, which is `True` in most cases.",
)
normalize: bool | None = Field(
default=None,
description="Deprecated; please pass `use_activation` instead",
)
# --8<-- [end:embed-extra-params]

View File

@@ -0,0 +1,378 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections.abc import AsyncGenerator, Mapping
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import ClassVar, Generic, TypeVar
from fastapi import Request
from pydantic import ConfigDict
from starlette.datastructures import Headers
from starlette.responses import JSONResponse
from vllm import (
PoolingParams,
PoolingRequestOutput,
PromptType,
SamplingParams,
envs,
)
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
ChatTemplateConfig,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.typing import AnyPoolingRequest, AnyPoolingResponse
from vllm.inputs import ProcessorInputs
from vllm.lora.request import LoRARequest
from vllm.renderers import BaseRenderer
from vllm.renderers.inputs.preprocess import extract_prompt_components
from vllm.sampling_params import BeamSearchParams
from vllm.tracing import (
contains_trace_headers,
extract_trace_headers,
log_tracing_disabled_warning,
)
from vllm.utils import random_uuid
from vllm.utils.async_utils import merge_async_iterators
from ...utils import create_error_response
from .io_processor import PoolingIOProcessor
PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest)
@dataclass(kw_only=True)
class PoolingServeContext(Generic[PoolingRequestT]):
request: PoolingRequestT
raw_request: Request | None = None
model_name: str
request_id: str
created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None
engine_prompts: list[ProcessorInputs] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None
)
final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
class PoolingServing:
request_id_prefix: ClassVar[str]
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
):
super().__init__()
self.engine_client = engine_client
self.models = models
self.model_config = models.model_config
self.max_model_len = self.model_config.max_model_len
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
self.log_error_stack = log_error_stack
self.chat_template_config = ChatTemplateConfig(
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
trust_request_chat_template=trust_request_chat_template,
)
self.io_processor = self.init_io_processor(
model_config=models.model_config,
renderer=models.renderer,
chat_template_config=self.chat_template_config,
)
def init_io_processor(
self,
model_config: ModelConfig,
renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig,
) -> PoolingIOProcessor:
raise NotImplementedError
async def __call__(
self,
request: AnyPoolingRequest,
raw_request: Request,
) -> JSONResponse:
try:
model_name = self.models.model_name()
request_id = (
f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
)
await self._check_model(request)
ctx = PoolingServeContext(
request=request,
raw_request=raw_request,
model_name=model_name,
request_id=request_id,
)
self._validate_request(ctx)
self._maybe_get_adapters(ctx)
await self._preprocess(ctx)
await self._prepare_generators(ctx)
await self._collect_batch(ctx)
response = await self._build_response(ctx)
return JSONResponse(content=response.model_dump())
except Exception as e:
error_response = create_error_response(e)
return JSONResponse(
content=error_response.model_dump(),
status_code=error_response.error.code,
)
async def _preprocess(
self,
ctx: PoolingServeContext,
):
ctx.engine_prompts = await self.io_processor.pre_process_online_async(
ctx.request
)
async def _prepare_generators(
self,
ctx: PoolingServeContext,
):
if ctx.engine_prompts is None:
raise ValueError("Engine prompts not available")
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
trace_headers = (
None
if ctx.raw_request is None
else await self._get_trace_headers(ctx.raw_request.headers)
)
pooling_params = self.io_processor.create_pooling_params(ctx.request)
for i, engine_prompt in enumerate(ctx.engine_prompts):
request_id_item = f"{ctx.request_id}-{i}"
self._log_inputs(
request_id_item,
engine_prompt,
params=pooling_params,
lora_request=ctx.lora_request,
)
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=ctx.lora_request,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
)
generators.append(generator)
ctx.result_generator = merge_async_iterators(*generators)
async def _collect_batch(
self,
ctx: PoolingServeContext,
):
if ctx.engine_prompts is None:
raise ValueError("Engine prompts not available")
if ctx.result_generator is None:
raise ValueError("Result generator not available")
num_prompts = len(ctx.engine_prompts)
final_res_batch: list[PoolingRequestOutput | None]
final_res_batch = [None] * num_prompts
async for i, res in ctx.result_generator:
final_res_batch[i] = res
if None in final_res_batch:
raise ValueError("Failed to generate results for all prompts")
ctx.final_res_batch = [res for res in final_res_batch if res is not None]
async def _build_response(
self,
ctx: PoolingServeContext,
) -> AnyPoolingResponse:
raise NotImplementedError
@staticmethod
def _base_request_id(
raw_request: Request | None, default: str | None = None
) -> str | None:
"""Pulls the request id to use from a header, if provided"""
if raw_request is not None and (
(req_id := raw_request.headers.get("X-Request-Id")) is not None
):
return req_id
return random_uuid() if default is None else default
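A small illustration of the header lookup above, building a bare Starlette request with only the fields the lookup touches (the header value is illustrative):
from starlette.requests import Request

from vllm.entrypoints.pooling.base.serving import PoolingServing

scope = {
    "type": "http",
    "method": "POST",
    "path": "/classify",
    "headers": [(b"x-request-id", b"abc-123")],
    "query_string": b"",
}
assert PoolingServing._base_request_id(Request(scope)) == "abc-123"
# Without the header, a fresh random uuid is returned (unless a default is given).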
def _is_model_supported(self, model_name: str | None) -> bool:
if not model_name:
return True
return self.models.is_base_model(model_name)
async def _check_model(
self,
request: AnyPoolingRequest,
) -> ErrorResponse | None:
if self._is_model_supported(request.model):
return None
if request.model in self.models.lora_requests:
return None
if (
envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
and request.model
and (load_result := await self.models.resolve_lora(request.model))
):
if isinstance(load_result, LoRARequest):
return None
if (
isinstance(load_result, ErrorResponse)
and load_result.error.code == HTTPStatus.BAD_REQUEST.value
):
raise ValueError(load_result.error.message)
return None
def _validate_request(self, ctx: PoolingServeContext) -> None:
truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)
if (
truncate_prompt_tokens is not None
and truncate_prompt_tokens > self.max_model_len
):
raise ValueError(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please, select a smaller truncation size."
)
return None
async def _get_trace_headers(
self,
headers: Headers,
) -> Mapping[str, str] | None:
is_tracing_enabled = await self.engine_client.is_tracing_enabled()
if is_tracing_enabled:
return extract_trace_headers(headers)
if contains_trace_headers(headers):
log_tracing_disabled_warning()
return None
def _maybe_get_adapters(
self,
ctx: PoolingServeContext,
supports_default_mm_loras: bool = False,
):
request = ctx.request
if request.model in self.models.lora_requests:
ctx.lora_request = self.models.lora_requests[request.model]
# Currently only support default modality specific loras
# if we have exactly one lora matched on the request.
if supports_default_mm_loras:
default_mm_lora = self._get_active_default_mm_loras(request)
if default_mm_lora is not None:
ctx.lora_request = default_mm_lora
if self._is_model_supported(request.model):
return None
# if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.")
def _get_active_default_mm_loras(
self, request: AnyPoolingRequest
) -> LoRARequest | None:
"""Determine if there are any active default multimodal loras."""
# TODO: Currently this is only enabled for chat completions
# to be better aligned with only being enabled for .generate
# when run offline. It would be nice to support additional
# task types in the future.
message_types = self._get_message_types(request)
default_mm_loras = set()
for lora in self.models.lora_requests.values():
# Best effort match for default multimodal lora adapters;
# There is probably a better way to do this, but currently
# this matches against the set of 'types' in any content lists
# up until '_', e.g., to match audio_url -> audio
if lora.lora_name in message_types:
default_mm_loras.add(lora)
# Currently only support default modality specific loras if
# we have exactly one lora matched on the request.
if len(default_mm_loras) == 1:
return default_mm_loras.pop()
return None
def _get_message_types(self, request: AnyPoolingRequest) -> set[str]:
"""Retrieve the set of types from message content dicts up
until `_`; we use this to match potential multimodal data
with default per modality loras.
"""
message_types: set[str] = set()
if not hasattr(request, "messages"):
return message_types
messages = request.messages
if messages is None or isinstance(messages, (str, bytes)):
return message_types
for message in messages:
if (
isinstance(message, dict)
and "content" in message
and isinstance(message["content"], list)
):
for content_dict in message["content"]:
if "type" in content_dict:
message_types.add(content_dict["type"].split("_")[0])
return message_types
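As a worked example of the prefix matching described in the comments above (the message payload and adapter name are purely illustrative):
# Hypothetical request content; "audio_url" reduces to the prefix "audio".
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "transcribe this clip"},
            {"type": "audio_url", "audio_url": {"url": "https://example.com/a.wav"}},
        ],
    }
]
message_types = {
    part["type"].split("_")[0]
    for message in messages
    if isinstance(message.get("content"), list)
    for part in message["content"]
    if "type" in part
}
assert message_types == {"text", "audio"}
# A LoRA registered with lora_name == "audio" would then be selected as the
# default modality-specific adapter, but only if it is the single match.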
def _log_inputs(
self,
request_id: str,
inputs: PromptType | ProcessorInputs,
params: SamplingParams | PoolingParams | BeamSearchParams | None,
lora_request: LoRARequest | None,
) -> None:
if self.request_logger is None:
return
components = extract_prompt_components(self.model_config, inputs)
self.request_logger.log_inputs(
request_id,
components.text,
components.token_ids,
components.embeds,
params=params,
lora_request=lora_request,
)

View File

@@ -3,16 +3,17 @@
from fastapi import APIRouter, Depends, Request
from starlette.responses import JSONResponse
from typing_extensions import assert_never
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.utils import load_aware_call, with_cancellation
from vllm.entrypoints.utils import (
create_error_response,
load_aware_call,
with_cancellation,
)
router = APIRouter()
@@ -24,25 +25,17 @@ def classify(request: Request) -> ServingClassification | None:
@router.post("/classify", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_classify(request: ClassificationRequest, raw_request: Request):
async def create_classify(
request: ClassificationRequest, raw_request: Request
) -> JSONResponse:
handler = classify(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
return base_server.create_error_response(
error_response = create_error_response(
message="The model does not support Classification API"
)
try:
generator = await handler.create_classify(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
content=error_response.model_dump(),
status_code=error_response.error.code,
)
elif isinstance(generator, ClassificationResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
return await handler(request, raw_request)

View File

@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Any
from vllm import PromptType
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
)
from vllm.inputs import ProcessorInputs
from vllm.renderers.inputs import TokPrompt
class ClassifyIOProcessor(PoolingIOProcessor):
def pre_process_online(
self, request: ClassificationCompletionRequest | ClassificationChatRequest
) -> list[TokPrompt] | None:
if isinstance(request, ClassificationChatRequest):
self._validate_chat_template(
request_chat_template=request.chat_template,
chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
_, engine_prompts = self._preprocess_chat_online(
request,
request.messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
)
elif isinstance(request, ClassificationCompletionRequest):
engine_prompts = self._preprocess_completion_online(
request,
prompt_input=request.input,
prompt_embeds=None,
)
else:
raise ValueError("Invalid classification request type")
return engine_prompts
def pre_process_offline(
self,
prompts: PromptType | Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None,
) -> Sequence[ProcessorInputs]:
return self._preprocess_completion_offline(
prompts=prompts, tokenization_kwargs=tokenization_kwargs
)

View File

@@ -40,7 +40,6 @@ class ClassificationCompletionRequest(
def to_pooling_params(self):
return PoolingParams(
task="classify",
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)
@@ -63,7 +62,6 @@ class ClassificationChatRequest(
def to_pooling_params(self):
return PoolingParams(
task="classify",
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)

View File

@@ -1,116 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Final, TypeAlias
from typing import TypeAlias
import jinja2
import numpy as np
from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
from vllm import ClassificationOutput
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.serving import PoolingServeContext, PoolingServing
from vllm.logger import init_logger
from vllm.renderers import BaseRenderer
from .io_processor import ClassifyIOProcessor
from .protocol import (
ClassificationData,
ClassificationRequest,
ClassificationResponse,
)
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput
logger = init_logger(__name__)
ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest]
ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationRequest]
class ServingClassification(OpenAIServing):
class ServingClassification(PoolingServing):
request_id_prefix = "classify"
def __init__(
def init_io_processor(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
model_config: ModelConfig,
renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig,
) -> ClassifyIOProcessor:
return ClassifyIOProcessor(
model_config=model_config,
renderer=renderer,
chat_template_config=chat_template_config,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def _preprocess(
async def _build_response(
self,
ctx: ClassificationServeContext,
) -> ErrorResponse | None:
"""
Process classification inputs: tokenize text, resolve adapters,
and prepare model-specific inputs.
"""
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)
) -> ClassificationResponse:
final_res_batch_checked = await self.io_processor.post_process_async(
ctx.final_res_batch
)
if isinstance(ctx.request, ClassificationChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret:
return error_check_ret
_, ctx.engine_prompts = await self._preprocess_chat(
ctx.request,
ctx.request.messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
)
elif isinstance(ctx.request, ClassificationCompletionRequest):
ctx.engine_prompts = await self._preprocess_completion(
ctx.request,
prompt_input=ctx.request.input,
prompt_embeds=None,
)
else:
return self.create_error_response("Invalid classification request type")
return None
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
def _build_response(
self,
ctx: ClassificationServeContext,
) -> ClassificationResponse | ErrorResponse:
"""
Convert model outputs to a formatted classification response
with probabilities and labels.
"""
id2label = getattr(self.model_config.hf_config, "id2label", {})
items: list[ClassificationData] = []
num_prompt_tokens = 0
final_res_batch_checked = ctx.final_res_batch
items: list[ClassificationData] = []
for idx, final_res in enumerate(final_res_batch_checked):
classify_res = ClassificationOutput.from_base(final_res.outputs)
@@ -141,20 +82,3 @@ class ServingClassification(OpenAIServing):
data=items,
usage=usage,
)
async def create_classify(
self,
request: ClassificationRequest,
raw_request: Request,
) -> ClassificationResponse | ErrorResponse:
model_name = self.models.model_name()
request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
ctx = ClassificationServeContext(
request=request,
raw_request=raw_request,
model_name=model_name,
request_id=request_id,
)
return await self.handle(ctx) # type: ignore[return-value]

View File

@@ -14,12 +14,9 @@ from vllm.entrypoints.pooling.base.protocol import (
EmbedRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.logger import init_logger
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid
logger = init_logger(__name__)
def _get_max_total_output_tokens(
model_config: ModelConfig,
@@ -60,18 +57,10 @@ class EmbeddingCompletionRequest(
)
def to_pooling_params(self):
if self.normalize is not None:
logger.warning_once(
"`normalize` is deprecated and will be removed in v0.17. "
"Please pass `use_activation` instead."
)
self.use_activation = self.normalize
return PoolingParams(
task="embed",
dimensions=self.dimensions,
use_activation=self.use_activation,
truncate_prompt_tokens=self.truncate_prompt_tokens,
)
@@ -97,18 +86,10 @@ class EmbeddingChatRequest(
)
def to_pooling_params(self):
if self.normalize is not None:
logger.warning_once(
"`normalize` is deprecated and will be removed in v0.17. "
"Please pass `use_activation` instead."
)
self.use_activation = self.normalize
return PoolingParams(
task="embed",
dimensions=self.dimensions,
use_activation=self.use_activation,
truncate_prompt_tokens=self.truncate_prompt_tokens,
)

View File

@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.renderers import BaseRenderer
from vllm.tasks import SupportedTask
def init_pooling_io_processors(
supported_tasks: tuple[SupportedTask, ...],
model_config: ModelConfig,
renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig,
) -> dict[str, PoolingIOProcessor]:
pooling_io_processors: dict[str, PoolingIOProcessor] = {}
if "classify" in supported_tasks:
from vllm.entrypoints.pooling.classify.io_processor import (
ClassifyIOProcessor,
)
pooling_io_processors["classify"] = ClassifyIOProcessor(
model_config=model_config,
renderer=renderer,
chat_template_config=chat_template_config,
)
return pooling_io_processors

View File

@@ -16,13 +16,10 @@ from vllm.entrypoints.pooling.base.protocol import (
EncodingRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.logger import init_logger
from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask
from vllm.utils import random_uuid
logger = init_logger(__name__)
class PoolingCompletionRequest(
PoolingBasicRequestMixin,
@@ -45,16 +42,8 @@ class PoolingCompletionRequest(
)
def to_pooling_params(self):
if self.normalize is not None:
logger.warning_once(
"`normalize` is deprecated and will be removed in v0.17. "
"Please pass `use_activation` instead."
)
self.use_activation = self.normalize
return PoolingParams(
task=self.task,
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
dimensions=self.dimensions,
)
@@ -78,16 +67,8 @@ class PoolingChatRequest(
)
def to_pooling_params(self):
if self.normalize is not None:
logger.warning_once(
"`normalize` is deprecated and will be removed in v0.17. "
"Please pass `use_activation` instead."
)
self.use_activation = self.normalize
return PoolingParams(
task=self.task,
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
dimensions=self.dimensions,
)

View File

@@ -37,7 +37,6 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
def to_pooling_params(self, task: PoolingTask = "score"):
return PoolingParams(
task=task,
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)
@@ -113,7 +112,6 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
def to_pooling_params(self, task: PoolingTask = "score"):
return PoolingParams(
task=task,
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)

View File

@@ -31,7 +31,7 @@ from vllm.entrypoints.pooling.score.utils import (
ScoreInputs,
_cosine_similarity,
compress_token_type_ids,
compute_maxsim_score,
compute_maxsim_scores,
get_score_prompt,
parse_score_data_single,
validate_score_input,
@@ -56,6 +56,7 @@ class ServingScores(OpenAIServing):
request_logger: RequestLogger | None,
score_template: str | None = None,
log_error_stack: bool = False,
use_gpu_for_pooling_score: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
@@ -64,6 +65,7 @@ class ServingScores(OpenAIServing):
log_error_stack=log_error_stack,
)
self.score_template = score_template
self.use_gpu_for_pooling_score = use_gpu_for_pooling_score
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
@@ -311,19 +313,18 @@ class ServingScores(OpenAIServing):
# Compute MaxSim scores
from vllm.outputs import PoolingOutput
maxsim_scores = compute_maxsim_scores(
[emb.outputs.data for emb in emb_data_1],
[emb.outputs.data for emb in emb_data_2],
use_gpu_for_pooling_score=self.use_gpu_for_pooling_score,
)
scores: list[PoolingRequestOutput] = []
padding: list[int] = []
if (pad_token_id := tokenizer.pad_token_id) is not None:
padding = [pad_token_id]
for emb_1, emb_2 in zip(emb_data_1, emb_data_2):
# emb_1.outputs.data: [query_len, dim]
# emb_2.outputs.data: [doc_len, dim]
q_emb = emb_1.outputs.data
d_emb = emb_2.outputs.data
maxsim_score = compute_maxsim_score(q_emb, d_emb)
for emb_1, emb_2, maxsim_score in zip(emb_data_1, emb_data_2, maxsim_scores):
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
scores.append(

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from collections.abc import Iterable, Sequence
from typing import Any, TypeAlias, cast
import torch
@@ -25,6 +25,7 @@ from vllm.inputs.data import PromptType, TextPrompt
from vllm.model_executor.models.interfaces import supports_score_template
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
from vllm.outputs import PoolingRequestOutput
from vllm.platforms import current_platform
from vllm.renderers.hf import safe_apply_chat_template
from vllm.tokenizers import TokenizerLike
@@ -53,6 +54,91 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor:
return token_scores.amax(dim=-1).sum()
def _should_use_gpu_for_maxsim(use_gpu_for_pooling_score: bool) -> bool:
return use_gpu_for_pooling_score and not current_platform.is_cpu()
def compute_maxsim_scores(
q_embs: Sequence[torch.Tensor],
d_embs: Sequence[torch.Tensor],
max_batch_size: int = 16,
max_score_matrix_elements: int = 16_000_000,
use_gpu_for_pooling_score: bool = False,
) -> list[torch.Tensor]:
"""Compute ColBERT MaxSim scores in padded mini-batches."""
if len(q_embs) != len(d_embs):
raise ValueError("q_embs and d_embs must have the same length")
num_pairs = len(q_embs)
if num_pairs == 0:
return []
for q_emb, d_emb in zip(q_embs, d_embs):
if q_emb.ndim != 2 or d_emb.ndim != 2:
raise ValueError("Each embedding tensor must be 2-D")
if q_emb.shape[1] != d_emb.shape[1]:
raise ValueError("Query and document embeddings must have same dim")
compute_device = torch.device(
current_platform.device_type
if _should_use_gpu_for_maxsim(use_gpu_for_pooling_score)
else "cpu"
)
scores: list[torch.Tensor] = []
start = 0
while start < num_pairs:
end = min(start + max_batch_size, num_pairs)
max_q = max(int(x.shape[0]) for x in q_embs[start:end])
max_d = max(int(x.shape[0]) for x in d_embs[start:end])
# keep score matrix bounded to avoid oversized allocations.
while (
end - start > 1
and (end - start) * max_q * max_d > max_score_matrix_elements
):
end -= 1
max_q = max(int(x.shape[0]) for x in q_embs[start:end])
max_d = max(int(x.shape[0]) for x in d_embs[start:end])
batch_q = q_embs[start:end]
batch_d = d_embs[start:end]
batch_size = end - start
dim = int(batch_q[0].shape[1])
dtype = batch_q[0].dtype
q_batch = torch.zeros(
(batch_size, max_q, dim), dtype=dtype, device=compute_device
)
d_batch = torch.zeros(
(batch_size, max_d, dim), dtype=dtype, device=compute_device
)
q_mask = torch.zeros(
(batch_size, max_q), dtype=torch.bool, device=compute_device
)
d_mask = torch.zeros(
(batch_size, max_d), dtype=torch.bool, device=compute_device
)
# copy to padded tensors
for i, (q_emb, d_emb) in enumerate(zip(batch_q, batch_d)):
q_len = int(q_emb.shape[0])
d_len = int(d_emb.shape[0])
q_batch[i, :q_len] = q_emb.to(device=compute_device, dtype=dtype)
d_batch[i, :d_len] = d_emb.to(device=compute_device, dtype=dtype)
q_mask[i, :q_len] = True
d_mask[i, :d_len] = True
token_scores = torch.bmm(q_batch, d_batch.transpose(1, 2))
token_scores.masked_fill_(~d_mask.unsqueeze(1), float("-inf"))
max_per_query = token_scores.amax(dim=-1)
max_per_query.masked_fill_(~q_mask, 0)
batch_scores = max_per_query.sum(dim=-1).to("cpu")
scores.extend(batch_scores.unbind(0))
start = end
return [cast(torch.Tensor, score) for score in scores]
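A minimal usage sketch for the batched helper above, assuming two query/document pairs with different token counts (shapes and values are arbitrary):
import torch

from vllm.entrypoints.pooling.score.utils import compute_maxsim_scores

# Each entry is a (num_tokens, dim) late-interaction embedding matrix.
q_embs = [torch.randn(4, 8), torch.randn(6, 8)]
d_embs = [torch.randn(10, 8), torch.randn(3, 8)]

scores = compute_maxsim_scores(q_embs, d_embs)
# One scalar per pair: the max dot product over document tokens is taken for
# each query token, then summed over the query tokens.
assert len(scores) == 2 and all(s.ndim == 0 for s in scores)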
class ScoreMultiModalParam(TypedDict, total=False):
"""
A specialized parameter type for scoring multimodal content

View File

@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TypeAlias
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
RerankRequest,
ScoreRequest,
ScoreResponse,
)
PoolingCompletionLikeRequest: TypeAlias = (
EmbeddingCompletionRequest
| ClassificationCompletionRequest
| RerankRequest
| ScoreRequest
| PoolingCompletionRequest
)
PoolingChatLikeRequest: TypeAlias = (
EmbeddingChatRequest | ClassificationChatRequest | PoolingChatRequest
)
AnyPoolingRequest: TypeAlias = (
PoolingCompletionLikeRequest | PoolingChatLikeRequest | IOProcessorRequest
)
AnyPoolingResponse: TypeAlias = (
ClassificationResponse
| EmbeddingResponse
| EmbeddingBytesResponse
| PoolingResponse
| ScoreResponse
)

View File

@@ -13,6 +13,7 @@ from fastapi.responses import JSONResponse, Response
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.serve.instrumentator.basic import base
from vllm.entrypoints.serve.instrumentator.health import health
from vllm.tasks import POOLING_TASKS, SupportedTask
@@ -20,7 +21,7 @@ from vllm.tasks import POOLING_TASKS, SupportedTask
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
RequestType = Any
GetHandlerFn = Callable[[Request], OpenAIServing | None]
GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServing | None]
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]

View File

@@ -5,7 +5,10 @@ import asyncio
import dataclasses
import functools
import os
import sys
import traceback
from argparse import Namespace
from http import HTTPStatus
from logging import Logger
from string import Template
from typing import TYPE_CHECKING
@@ -17,17 +20,23 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm import envs
from vllm.engine.arg_utils import EngineArgs
from vllm.exceptions import VLLMValidationError
from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING:
from vllm.entrypoints.openai.engine.protocol import StreamOptions
from vllm.entrypoints.openai.engine.protocol import (
ErrorInfo,
ErrorResponse,
StreamOptions,
)
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
else:
StreamOptions = object
ErrorResponse = object
ErrorInfo = object
LoRAModulePath = object
StreamOptions = object
logger = init_logger(__name__)
@@ -291,3 +300,59 @@ def log_version_and_model(lgr: Logger, version: str, model_name: str) -> None:
message = logo_template.substitute(colors)
lgr.info(message, version, model_name)
def create_error_response(
message: str | Exception,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
param: str | None = None,
log_error_stack: bool = False,
) -> "ErrorResponse":
exc: Exception | None = None
from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse
if isinstance(message, Exception):
exc = message
if isinstance(exc, VLLMValidationError):
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = exc.parameter
elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
# Common validation errors from user input
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = None
elif isinstance(exc, NotImplementedError):
err_type = "NotImplementedError"
status_code = HTTPStatus.NOT_IMPLEMENTED
param = None
elif exc.__class__.__name__ == "TemplateError":
# jinja2.TemplateError (avoid importing jinja2)
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = None
else:
err_type = "InternalServerError"
status_code = HTTPStatus.INTERNAL_SERVER_ERROR
param = None
message = str(exc)
if log_error_stack:
exc_type, _, _ = sys.exc_info()
if exc_type is not None:
traceback.print_exc()
else:
traceback.print_stack()
return ErrorResponse(
error=ErrorInfo(
message=sanitize_message(message),
type=err_type,
code=status_code.value,
param=param,
)
)
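A minimal sketch of how an endpoint is expected to consume this helper; the raised exception is illustrative, and the 400 mapping follows the ValueError branch above:
from http import HTTPStatus

from fastapi.responses import JSONResponse

from vllm.entrypoints.utils import create_error_response

try:
    raise ValueError("truncate_prompt_tokens value is greater than max_model_len.")
except ValueError as exc:
    error = create_error_response(exc)
    response = JSONResponse(
        content=error.model_dump(), status_code=error.error.code
    )

assert response.status_code == HTTPStatus.BAD_REQUEST.value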