Upgrade to vLLM 0.17.0 (corex v4.1 overlay)
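
This overlay adds an Anthropic-style token-counting endpoint, POST /v1/messages/count_tokens, alongside the existing /v1/messages route. A rough usage sketch follows; the base URL and model name are placeholders and not part of this change:

    import requests  # assumes a locally running vLLM API server (hypothetical deployment)

    resp = requests.post(
        "http://localhost:8000/v1/messages/count_tokens",  # placeholder host/port
        json={
            "model": "my-model",  # placeholder model name
            "messages": [{"role": "user", "content": "Hello, world"}],
        },
    )
    # On success the body is an AnthropicCountTokensResponse; errors are translated
    # into AnthropicErrorResponse, as the diff below shows.
    print(resp.json()["input_tokens"])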
@@ -8,6 +8,8 @@ from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

from vllm.entrypoints.anthropic.protocol import (
AnthropicCountTokensRequest,
AnthropicCountTokensResponse,
AnthropicError,
AnthropicErrorResponse,
AnthropicMessagesRequest,
@@ -31,6 +33,18 @@ def messages(request: Request) -> AnthropicServingMessages:
return request.app.state.anthropic_serving_messages


def translate_error_response(response: ErrorResponse) -> JSONResponse:
anthropic_error = AnthropicErrorResponse(
error=AnthropicError(
type=response.error.type,
message=response.error.message,
)
)
return JSONResponse(
status_code=response.error.code, content=anthropic_error.model_dump()
)


@router.post(
"/v1/messages",
dependencies=[Depends(validate_json_request)],
@@ -44,17 +58,6 @@ def messages(request: Request) -> AnthropicServingMessages:
@with_cancellation
@load_aware_call
async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
def translate_error_response(response: ErrorResponse) -> JSONResponse:
anthropic_error = AnthropicErrorResponse(
error=AnthropicError(
type=response.error.type,
message=response.error.message,
)
)
return JSONResponse(
status_code=response.error.code, content=anthropic_error.model_dump()
)

handler = messages(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
@@ -88,5 +91,46 @@ async def create_messages(request: AnthropicMessagesRequest, raw_request: Reques
return StreamingResponse(content=generator, media_type="text/event-stream")


@router.post(
"/v1/messages/count_tokens",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"model": AnthropicCountTokensResponse},
HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
},
)
@load_aware_call
@with_cancellation
async def count_tokens(request: AnthropicCountTokensRequest, raw_request: Request):
handler = messages(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
error = base_server.create_error_response(
message="The model does not support Messages API"
)
return translate_error_response(error)

try:
response = await handler.count_tokens(request, raw_request)
except Exception as e:
logger.exception("Error in count_tokens: %s", e)
return JSONResponse(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
content=AnthropicErrorResponse(
error=AnthropicError(
type="internal_error",
message=str(e),
)
).model_dump(),
)

if isinstance(response, ErrorResponse):
return translate_error_response(response)

return JSONResponse(content=response.model_dump(exclude_none=True))


def attach_router(app: FastAPI):
app.include_router(router)
@@ -34,7 +34,7 @@ class AnthropicUsage(BaseModel):
class AnthropicContentBlock(BaseModel):
"""Content block in message"""

type: Literal["text", "image", "tool_use", "tool_result"]
type: Literal["text", "image", "tool_use", "tool_result", "thinking"]
text: str | None = None
# For image content
source: dict[str, Any] | None = None
@@ -45,6 +45,9 @@ class AnthropicContentBlock(BaseModel):
input: dict[str, Any] | None = None
content: str | list[dict[str, Any]] | None = None
is_error: bool | None = None
# For thinking content
thinking: str | None = None
signature: str | None = None


class AnthropicMessage(BaseModel):
@@ -74,7 +77,7 @@ class AnthropicTool(BaseModel):
class AnthropicToolChoice(BaseModel):
"""Tool Choice definition"""

type: Literal["auto", "any", "tool"]
type: Literal["auto", "any", "tool", "none"]
name: str | None = None

@model_validator(mode="after")
@@ -118,9 +121,14 @@ class AnthropicMessagesRequest(BaseModel):
class AnthropicDelta(BaseModel):
"""Delta for streaming responses"""

type: Literal["text_delta", "input_json_delta"] | None = None
type: (
Literal["text_delta", "input_json_delta", "thinking_delta", "signature_delta"]
| None
) = None
text: str | None = None
thinking: str | None = None
partial_json: str | None = None
signature: str | None = None

# Message delta
stop_reason: (
@@ -167,3 +175,33 @@ class AnthropicMessagesResponse(BaseModel):
def model_post_init(self, __context):
if not self.id:
self.id = f"msg_{int(time.time() * 1000)}"


class AnthropicContextManagement(BaseModel):
"""Context management information for token counting."""

original_input_tokens: int


class AnthropicCountTokensRequest(BaseModel):
"""Anthropic messages.count_tokens request"""

model: str
messages: list[AnthropicMessage]
system: str | list[AnthropicContentBlock] | None = None
tool_choice: AnthropicToolChoice | None = None
tools: list[AnthropicTool] | None = None

@field_validator("model")
@classmethod
def validate_model(cls, v):
if not v:
raise ValueError("Model is required")
return v


class AnthropicCountTokensResponse(BaseModel):
"""Anthropic messages.count_tokens response"""

input_tokens: int
context_management: AnthropicContextManagement | None = None
@@ -8,6 +8,7 @@
import json
import logging
import time
import uuid
from collections.abc import AsyncGenerator
from typing import Any
@@ -16,6 +17,9 @@ from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.protocol import (
AnthropicContentBlock,
AnthropicContextManagement,
AnthropicCountTokensRequest,
AnthropicCountTokensResponse,
AnthropicDelta,
AnthropicError,
AnthropicMessagesRequest,
@@ -85,94 +89,225 @@ class AnthropicServingMessages(OpenAIServingChat):
"tool_calls": "tool_use",
}

@staticmethod
def _convert_image_source_to_url(source: dict[str, Any]) -> str:
"""Convert an Anthropic image source to an OpenAI-compatible URL.

Anthropic supports two image source types:
- base64: {"type": "base64", "media_type": "image/jpeg", "data": "..."}
- url: {"type": "url", "url": "https://..."}

For base64 sources, this constructs a proper data URI that
downstream processors (e.g. vLLM's media connector) can handle.
"""
source_type = source.get("type")
if source_type == "url":
return source.get("url", "")
# Default to base64 processing if type is "base64"
# or missing, ensuring a proper data URI is always
# constructed for non-URL sources.
media_type = source.get("media_type", "image/jpeg")
data = source.get("data", "")
return f"data:{media_type};base64,{data}"
@classmethod
def _convert_anthropic_to_openai_request(
self, anthropic_request: AnthropicMessagesRequest
cls, anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest
) -> ChatCompletionRequest:
"""Convert Anthropic message format to OpenAI format"""
openai_messages = []
openai_messages: list[dict[str, Any]] = []

# Add system message if provided
if anthropic_request.system:
if isinstance(anthropic_request.system, str):
openai_messages.append(
{"role": "system", "content": anthropic_request.system}
)
else:
system_prompt = ""
for block in anthropic_request.system:
if block.type == "text" and block.text:
system_prompt += block.text
openai_messages.append({"role": "system", "content": system_prompt})
cls._convert_system_message(anthropic_request, openai_messages)
cls._convert_messages(anthropic_request.messages, openai_messages)
req = cls._build_base_request(anthropic_request, openai_messages)
cls._handle_streaming_options(req, anthropic_request)
cls._convert_tool_choice(anthropic_request, req)
cls._convert_tools(anthropic_request, req)
return req
for msg in anthropic_request.messages:
@classmethod
def _convert_system_message(
cls,
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
openai_messages: list[dict[str, Any]],
) -> None:
"""Convert Anthropic system message to OpenAI format"""
if not anthropic_request.system:
return

if isinstance(anthropic_request.system, str):
openai_messages.append(
{"role": "system", "content": anthropic_request.system}
)
else:
system_prompt = ""
for block in anthropic_request.system:
if block.type == "text" and block.text:
system_prompt += block.text
openai_messages.append({"role": "system", "content": system_prompt})
@classmethod
def _convert_messages(
cls, messages: list, openai_messages: list[dict[str, Any]]
) -> None:
"""Convert Anthropic messages to OpenAI format"""
for msg in messages:
openai_msg: dict[str, Any] = {"role": msg.role}  # type: ignore
if isinstance(msg.content, str):
openai_msg["content"] = msg.content
else:
# Handle complex content blocks
content_parts: list[dict[str, Any]] = []
tool_calls: list[dict[str, Any]] = []

for block in msg.content:
if block.type == "text" and block.text:
content_parts.append({"type": "text", "text": block.text})
elif block.type == "image" and block.source:
content_parts.append(
{
"type": "image_url",
"image_url": {"url": block.source.get("data", "")},
}
)
elif block.type == "tool_use":
# Convert tool use to function call format
tool_call = {
"id": block.id or f"call_{int(time.time())}",
"type": "function",
"function": {
"name": block.name or "",
"arguments": json.dumps(block.input or {}),
},
}
tool_calls.append(tool_call)
elif block.type == "tool_result":
if msg.role == "user":
openai_messages.append(
{
"role": "tool",
"tool_call_id": block.tool_use_id or "",
"content": str(block.content)
if block.content
else "",
}
)
else:
# Assistant tool result becomes regular text
tool_result_text = (
str(block.content) if block.content else ""
)
content_parts.append(
{
"type": "text",
"text": f"Tool result: {tool_result_text}",
}
)
# Add tool calls to the message if any
if tool_calls:
openai_msg["tool_calls"] = tool_calls  # type: ignore

# Add content parts if any
if content_parts:
if len(content_parts) == 1 and content_parts[0]["type"] == "text":
openai_msg["content"] = content_parts[0]["text"]
else:
openai_msg["content"] = content_parts  # type: ignore
elif not tool_calls:
continue
cls._convert_message_content(msg, openai_msg, openai_messages)

openai_messages.append(openai_msg)
req = ChatCompletionRequest(
@classmethod
def _convert_message_content(
cls,
msg,
openai_msg: dict[str, Any],
openai_messages: list[dict[str, Any]],
) -> None:
"""Convert complex message content blocks"""
content_parts: list[dict[str, Any]] = []
tool_calls: list[dict[str, Any]] = []
reasoning_parts: list[str] = []
for block in msg.content:
cls._convert_block(
block,
msg.role,
content_parts,
tool_calls,
reasoning_parts,
openai_messages,
)

if reasoning_parts:
openai_msg["reasoning"] = "".join(reasoning_parts)

if tool_calls:
openai_msg["tool_calls"] = tool_calls  # type: ignore

if content_parts:
if len(content_parts) == 1 and content_parts[0]["type"] == "text":
openai_msg["content"] = content_parts[0]["text"]
else:
openai_msg["content"] = content_parts  # type: ignore
elif not tool_calls and not reasoning_parts:
return
@classmethod
def _convert_block(
cls,
block,
role: str,
content_parts: list[dict[str, Any]],
tool_calls: list[dict[str, Any]],
reasoning_parts: list[str],
openai_messages: list[dict[str, Any]],
) -> None:
"""Convert individual content block"""
if block.type == "text" and block.text:
content_parts.append({"type": "text", "text": block.text})
elif block.type == "image" and block.source:
image_url = cls._convert_image_source_to_url(block.source)
content_parts.append({"type": "image_url", "image_url": {"url": image_url}})
elif block.type == "thinking" and block.thinking is not None:
reasoning_parts.append(block.thinking)
elif block.type == "tool_use":
cls._convert_tool_use_block(block, tool_calls)
elif block.type == "tool_result":
cls._convert_tool_result_block(block, role, openai_messages, content_parts)
@classmethod
def _convert_tool_use_block(cls, block, tool_calls: list[dict[str, Any]]) -> None:
"""Convert tool_use block to OpenAI function call format"""
tool_call = {
"id": block.id or f"call_{int(time.time())}",
"type": "function",
"function": {
"name": block.name or "",
"arguments": json.dumps(block.input or {}),
},
}
tool_calls.append(tool_call)
@classmethod
def _convert_tool_result_block(
cls,
block,
role: str,
openai_messages: list[dict[str, Any]],
content_parts: list[dict[str, Any]],
) -> None:
"""Convert tool_result block to OpenAI format"""
if role == "user":
cls._convert_user_tool_result(block, openai_messages)
else:
tool_result_text = str(block.content) if block.content else ""
content_parts.append(
{"type": "text", "text": f"Tool result: {tool_result_text}"}
)
@classmethod
def _convert_user_tool_result(
cls, block, openai_messages: list[dict[str, Any]]
) -> None:
"""Convert user tool_result with text and image support"""
tool_text = ""
tool_image_urls: list[str] = []

if isinstance(block.content, str):
tool_text = block.content
elif isinstance(block.content, list):
text_parts: list[str] = []
for item in block.content:
if not isinstance(item, dict):
continue
item_type = item.get("type")
if item_type == "text":
text_parts.append(item.get("text", ""))
elif item_type == "image":
source = item.get("source", {})
url = cls._convert_image_source_to_url(source)
if url:
tool_image_urls.append(url)
tool_text = "\n".join(text_parts)

openai_messages.append(
{
"role": "tool",
"tool_call_id": block.tool_use_id or "",
"content": tool_text or "",
}
)

if tool_image_urls:
openai_messages.append(
{
"role": "user",
"content": [  # type: ignore[dict-item]
{"type": "image_url", "image_url": {"url": img}}
for img in tool_image_urls
],
}
)
@classmethod
def _build_base_request(
cls,
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
openai_messages: list[dict[str, Any]],
) -> ChatCompletionRequest:
"""Build base ChatCompletionRequest"""
if isinstance(anthropic_request, AnthropicCountTokensRequest):
return ChatCompletionRequest(
model=anthropic_request.model,
messages=openai_messages,
)

return ChatCompletionRequest(
model=anthropic_request.model,
messages=openai_messages,
max_tokens=anthropic_request.max_tokens,
@@ -183,19 +318,40 @@ class AnthropicServingMessages(OpenAIServingChat):
top_k=anthropic_request.top_k,
)
@classmethod
def _handle_streaming_options(
cls,
req: ChatCompletionRequest,
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
) -> None:
"""Handle streaming configuration"""
if isinstance(anthropic_request, AnthropicCountTokensRequest):
return
if anthropic_request.stream:
req.stream = anthropic_request.stream
req.stream_options = StreamOptions.validate(
req.stream_options = StreamOptions.model_validate(
{"include_usage": True, "continuous_usage_stats": True}
)
@classmethod
def _convert_tool_choice(
cls,
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
req: ChatCompletionRequest,
) -> None:
"""Convert Anthropic tool_choice to OpenAI format"""
if anthropic_request.tool_choice is None:
req.tool_choice = None
elif anthropic_request.tool_choice.type == "auto":
return

tool_choice_type = anthropic_request.tool_choice.type
if tool_choice_type == "auto":
req.tool_choice = "auto"
elif anthropic_request.tool_choice.type == "any":
elif tool_choice_type == "any":
req.tool_choice = "required"
elif anthropic_request.tool_choice.type == "tool":
elif tool_choice_type == "none":
req.tool_choice = "none"
elif tool_choice_type == "tool":
req.tool_choice = ChatCompletionNamedToolChoiceParam.model_validate(
{
"type": "function",
@@ -203,9 +359,17 @@ class AnthropicServingMessages(OpenAIServingChat):
}
)
tools = []
@classmethod
def _convert_tools(
cls,
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
req: ChatCompletionRequest,
) -> None:
"""Convert Anthropic tools to OpenAI format"""
if anthropic_request.tools is None:
return req
return

tools = []
for tool in anthropic_request.tools:
tools.append(
ChatCompletionToolsParam.model_validate(
@@ -219,10 +383,10 @@ class AnthropicServingMessages(OpenAIServingChat):
}
)
)

if req.tool_choice is None:
req.tool_choice = "auto"
req.tools = tools
return req
async def create_messages(
|
||||
self,
|
||||
@@ -263,23 +427,32 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
output_tokens=generator.usage.completion_tokens,
|
||||
),
|
||||
)
|
||||
if generator.choices[0].finish_reason == "stop":
|
||||
choice = generator.choices[0]
|
||||
if choice.finish_reason == "stop":
|
||||
result.stop_reason = "end_turn"
|
||||
elif generator.choices[0].finish_reason == "length":
|
||||
elif choice.finish_reason == "length":
|
||||
result.stop_reason = "max_tokens"
|
||||
elif generator.choices[0].finish_reason == "tool_calls":
|
||||
elif choice.finish_reason == "tool_calls":
|
||||
result.stop_reason = "tool_use"
|
||||
|
||||
content: list[AnthropicContentBlock] = [
|
||||
AnthropicContentBlock(
|
||||
type="text",
|
||||
text=generator.choices[0].message.content
|
||||
if generator.choices[0].message.content
|
||||
else "",
|
||||
content: list[AnthropicContentBlock] = []
|
||||
if choice.message.reasoning:
|
||||
content.append(
|
||||
AnthropicContentBlock(
|
||||
type="thinking",
|
||||
thinking=choice.message.reasoning,
|
||||
signature=uuid.uuid4().hex,
|
||||
)
|
||||
)
|
||||
if choice.message.content:
|
||||
content.append(
|
||||
AnthropicContentBlock(
|
||||
type="text",
|
||||
text=choice.message.content,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
for tool_call in generator.choices[0].message.tool_calls:
|
||||
for tool_call in choice.message.tool_calls:
|
||||
anthropic_tool_call = AnthropicContentBlock(
|
||||
type="tool_use",
|
||||
id=tool_call.id,
|
||||
@@ -297,10 +470,85 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
generator: AsyncGenerator[str, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
try:
|
||||
|
||||
class _ActiveBlockState:
|
||||
def __init__(self) -> None:
|
||||
self.content_block_index = 0
|
||||
self.block_type: str | None = None
|
||||
self.block_index: int | None = None
|
||||
self.block_signature: str | None = None
|
||||
self.signature_emitted: bool = False
|
||||
self.tool_use_id: str | None = None
|
||||
|
||||
def reset(self) -> None:
|
||||
self.block_type = None
|
||||
self.block_index = None
|
||||
self.block_signature = None
|
||||
self.signature_emitted = False
|
||||
self.tool_use_id = None
|
||||
|
||||
def start(self, block: AnthropicContentBlock) -> None:
|
||||
self.block_type = block.type
|
||||
self.block_index = self.content_block_index
|
||||
if block.type == "thinking":
|
||||
self.block_signature = uuid.uuid4().hex
|
||||
self.signature_emitted = False
|
||||
self.tool_use_id = None
|
||||
elif block.type == "tool_use":
|
||||
self.block_signature = None
|
||||
self.signature_emitted = True
|
||||
self.tool_use_id = block.id
|
||||
else:
|
||||
self.block_signature = None
|
||||
self.signature_emitted = True
|
||||
self.tool_use_id = None
|
||||
|
||||
first_item = True
|
||||
finish_reason = None
|
||||
content_block_index = 0
|
||||
content_block_started = False
|
||||
state = _ActiveBlockState()
|
||||
# Map from tool call index to tool_use_id
|
||||
tool_index_to_id: dict[int, str] = {}
|
||||
|
||||
def stop_active_block():
|
||||
events: list[str] = []
|
||||
if state.block_type is None:
|
||||
return events
|
||||
if (
|
||||
state.block_type == "thinking"
|
||||
and state.block_signature is not None
|
||||
and not state.signature_emitted
|
||||
):
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=state.block_index,
|
||||
type="content_block_delta",
|
||||
delta=AnthropicDelta(
|
||||
type="signature_delta",
|
||||
signature=state.block_signature,
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
events.append(wrap_data_with_event(data, "content_block_delta"))
|
||||
state.signature_emitted = True
|
||||
stop_chunk = AnthropicStreamEvent(
|
||||
index=state.block_index,
|
||||
type="content_block_stop",
|
||||
)
|
||||
data = stop_chunk.model_dump_json(exclude_unset=True)
|
||||
events.append(wrap_data_with_event(data, "content_block_stop"))
|
||||
state.reset()
|
||||
state.content_block_index += 1
|
||||
return events
|
||||
|
||||
def start_block(block: AnthropicContentBlock):
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=state.content_block_index,
|
||||
type="content_block_start",
|
||||
content_block=block,
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
event = wrap_data_with_event(data, "content_block_start")
|
||||
state.start(block)
|
||||
return event
|
||||
|
||||
async for item in generator:
|
||||
if item.startswith("data:"):
|
||||
@@ -326,6 +574,8 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
id=origin_chunk.id,
|
||||
content=[],
|
||||
model=origin_chunk.model,
|
||||
stop_reason=None,
|
||||
stop_sequence=None,
|
||||
usage=AnthropicUsage(
|
||||
input_tokens=origin_chunk.usage.prompt_tokens
|
||||
if origin_chunk.usage
|
||||
@@ -341,13 +591,8 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
|
||||
# last chunk including usage info
|
||||
if len(origin_chunk.choices) == 0:
|
||||
if content_block_started:
|
||||
stop_chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_stop",
|
||||
)
|
||||
data = stop_chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_stop")
|
||||
for event in stop_active_block():
|
||||
yield event
|
||||
stop_reason = self.stop_reason_map.get(
|
||||
finish_reason or "stop"
|
||||
)
|
||||
@@ -369,96 +614,139 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
|
||||
if origin_chunk.choices[0].finish_reason is not None:
|
||||
finish_reason = origin_chunk.choices[0].finish_reason
|
||||
continue
|
||||
|
||||
# content
|
||||
if origin_chunk.choices[0].delta.content is not None:
|
||||
if not content_block_started:
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_start",
|
||||
content_block=AnthropicContentBlock(
|
||||
type="text", text=""
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_start")
|
||||
content_block_started = True
|
||||
|
||||
if origin_chunk.choices[0].delta.content == "":
|
||||
continue
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_delta",
|
||||
delta=AnthropicDelta(
|
||||
type="text_delta",
|
||||
text=origin_chunk.choices[0].delta.content,
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_delta")
|
||||
continue
|
||||
|
||||
# tool calls
|
||||
elif len(origin_chunk.choices[0].delta.tool_calls) > 0:
|
||||
tool_call = origin_chunk.choices[0].delta.tool_calls[0]
|
||||
if tool_call.id is not None:
|
||||
if content_block_started:
|
||||
stop_chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_stop",
|
||||
)
|
||||
data = stop_chunk.model_dump_json(
|
||||
exclude_unset=True
|
||||
)
|
||||
yield wrap_data_with_event(
|
||||
data, "content_block_stop"
|
||||
)
|
||||
content_block_started = False
|
||||
content_block_index += 1
|
||||
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_start",
|
||||
content_block=AnthropicContentBlock(
|
||||
type="tool_use",
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name
|
||||
if tool_call.function
|
||||
else None,
|
||||
input={},
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_start")
|
||||
content_block_started = True
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_delta",
|
||||
delta=AnthropicDelta(
|
||||
type="input_json_delta",
|
||||
partial_json=tool_call.function.arguments,
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(
|
||||
data, "content_block_delta"
|
||||
)
|
||||
# continue
|
||||
|
||||
# thinking / text content
|
||||
reasoning_delta = origin_chunk.choices[0].delta.reasoning
|
||||
if reasoning_delta is not None:
|
||||
if reasoning_delta == "":
|
||||
pass
|
||||
else:
|
||||
if state.block_type != "thinking":
|
||||
for event in stop_active_block():
|
||||
yield event
|
||||
start_event = start_block(
|
||||
AnthropicContentBlock(
|
||||
type="thinking", thinking=""
|
||||
)
|
||||
)
|
||||
yield start_event
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
index=(
|
||||
state.block_index
|
||||
if state.block_index is not None
|
||||
else state.content_block_index
|
||||
),
|
||||
type="content_block_delta",
|
||||
delta=AnthropicDelta(
|
||||
type="input_json_delta",
|
||||
partial_json=tool_call.function.arguments
|
||||
if tool_call.function
|
||||
else None,
|
||||
type="thinking_delta",
|
||||
thinking=reasoning_delta,
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_delta")
|
||||
|
||||
if origin_chunk.choices[0].delta.content is not None:
|
||||
if origin_chunk.choices[0].delta.content == "":
|
||||
pass
|
||||
else:
|
||||
if state.block_type != "text":
|
||||
for event in stop_active_block():
|
||||
yield event
|
||||
start_event = start_block(
|
||||
AnthropicContentBlock(type="text", text="")
|
||||
)
|
||||
yield start_event
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=(
|
||||
state.block_index
|
||||
if state.block_index is not None
|
||||
else state.content_block_index
|
||||
),
|
||||
type="content_block_delta",
|
||||
delta=AnthropicDelta(
|
||||
type="text_delta",
|
||||
text=origin_chunk.choices[0].delta.content,
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_delta")
|
||||
|
||||
# tool calls - process all tool calls in the delta
|
||||
if len(origin_chunk.choices[0].delta.tool_calls) > 0:
|
||||
for tool_call in origin_chunk.choices[0].delta.tool_calls:
|
||||
if tool_call.id is not None:
|
||||
# Update mapping for incremental updates
|
||||
tool_index_to_id[tool_call.index] = tool_call.id
|
||||
# Only create new block if different tool call
|
||||
# AND has a name
|
||||
tool_name = (
|
||||
tool_call.function.name
|
||||
if tool_call.function
|
||||
else None
|
||||
)
|
||||
if (
|
||||
state.tool_use_id != tool_call.id
|
||||
and tool_name is not None
|
||||
):
|
||||
for event in stop_active_block():
|
||||
yield event
|
||||
start_event = start_block(
|
||||
AnthropicContentBlock(
|
||||
type="tool_use",
|
||||
id=tool_call.id,
|
||||
name=tool_name,
|
||||
input={},
|
||||
)
|
||||
)
|
||||
yield start_event
|
||||
# Handle initial arguments if present
|
||||
if (
|
||||
tool_call.function
|
||||
and tool_call.function.arguments
|
||||
and state.tool_use_id == tool_call.id
|
||||
):
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=(
|
||||
state.block_index
|
||||
if state.block_index is not None
|
||||
else state.content_block_index
|
||||
),
|
||||
type="content_block_delta",
|
||||
delta=AnthropicDelta(
|
||||
type="input_json_delta",
|
||||
partial_json=tool_call.function.arguments,
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(
|
||||
data, "content_block_delta"
|
||||
)
|
||||
else:
|
||||
# Incremental update - use index to find tool_use_id
|
||||
tool_use_id = tool_index_to_id.get(tool_call.index)
|
||||
if (
|
||||
tool_use_id is not None
|
||||
and tool_call.function
|
||||
and tool_call.function.arguments
|
||||
and state.tool_use_id == tool_use_id
|
||||
):
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=(
|
||||
state.block_index
|
||||
if state.block_index is not None
|
||||
else state.content_block_index
|
||||
),
|
||||
type="content_block_delta",
|
||||
delta=AnthropicDelta(
|
||||
type="input_json_delta",
|
||||
partial_json=tool_call.function.arguments,
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(
|
||||
data, "content_block_delta"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
error_response = AnthropicStreamEvent(
|
||||
@@ -481,3 +769,31 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
data = error_response.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "error")
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def count_tokens(
|
||||
self,
|
||||
request: AnthropicCountTokensRequest,
|
||||
raw_request: Request | None = None,
|
||||
) -> AnthropicCountTokensResponse | ErrorResponse:
|
||||
"""Implements Anthropic's messages.count_tokens endpoint."""
|
||||
chat_req = self._convert_anthropic_to_openai_request(request)
|
||||
result = await self.render_chat_request(chat_req)
|
||||
if isinstance(result, ErrorResponse):
|
||||
return result
|
||||
|
||||
_, engine_prompts = result
|
||||
|
||||
input_tokens = sum( # type: ignore
|
||||
len(prompt["prompt_token_ids"]) # type: ignore[typeddict-item, misc]
|
||||
for prompt in engine_prompts
|
||||
if "prompt_token_ids" in prompt
|
||||
)
|
||||
|
||||
response = AnthropicCountTokensResponse(
|
||||
input_tokens=input_tokens,
|
||||
context_management=AnthropicContextManagement(
|
||||
original_input_tokens=input_tokens
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@@ -7,6 +7,7 @@ import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import Counter, defaultdict
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property, lru_cache, partial
|
||||
from itertools import accumulate
|
||||
from pathlib import Path
|
||||
@@ -1024,6 +1025,13 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
self._add_placeholder("video", placeholder)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatTemplateConfig:
|
||||
chat_template: str | None = None
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
|
||||
trust_request_chat_template: bool = False
|
||||
|
||||
|
||||
def validate_chat_template(chat_template: Path | str | None):
|
||||
"""Raises if the provided chat template appears invalid."""
|
||||
if chat_template is None:
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand
|
||||
from vllm.entrypoints.cli.benchmark.mm_processor import (
|
||||
BenchmarkMMProcessorSubcommand,
|
||||
)
|
||||
from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand
|
||||
from vllm.entrypoints.cli.benchmark.startup import BenchmarkStartupSubcommand
|
||||
from vllm.entrypoints.cli.benchmark.sweep import BenchmarkSweepSubcommand
|
||||
from vllm.entrypoints.cli.benchmark.throughput import BenchmarkThroughputSubcommand
|
||||
|
||||
# Keep this package init import-free.
|
||||
#
|
||||
# The `vllm` console script imports `vllm.entrypoints.cli.main`, which causes
|
||||
# Python to import this package before loading the `main` submodule.
|
||||
# Eagerly importing benchmark subcommands here makes every `vllm serve ...`
|
||||
# startup depend on optional benchmark-only modules.
|
||||
#
|
||||
# Benchmark subcommands are loaded on demand in
|
||||
# `vllm.entrypoints.cli.benchmark.main`.
|
||||
__all__: list[str] = [
|
||||
"BenchmarkLatencySubcommand",
|
||||
"BenchmarkMMProcessorSubcommand",
|
||||
"BenchmarkServingSubcommand",
|
||||
"BenchmarkStartupSubcommand",
|
||||
"BenchmarkSweepSubcommand",
|
||||
"BenchmarkThroughputSubcommand",
|
||||
]
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
import logging
|
||||
import typing
|
||||
|
||||
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
|
||||
@@ -15,30 +13,6 @@ if typing.TYPE_CHECKING:
|
||||
else:
|
||||
FlexibleArgumentParser = argparse.ArgumentParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _load_benchmark_subcommands() -> None:
|
||||
modules = [
|
||||
"vllm.entrypoints.cli.benchmark.latency",
|
||||
"vllm.entrypoints.cli.benchmark.mm_processor",
|
||||
"vllm.entrypoints.cli.benchmark.serve",
|
||||
"vllm.entrypoints.cli.benchmark.startup",
|
||||
"vllm.entrypoints.cli.benchmark.sweep",
|
||||
"vllm.entrypoints.cli.benchmark.throughput",
|
||||
]
|
||||
|
||||
for module_name in modules:
|
||||
try:
|
||||
importlib.import_module(module_name)
|
||||
except ModuleNotFoundError as e:
|
||||
logger.warning(
|
||||
"Skipping benchmark subcommand module %s because an optional "
|
||||
"dependency could not be imported: %r",
|
||||
module_name,
|
||||
e,
|
||||
)
|
||||
|
||||
|
||||
class BenchmarkSubcommand(CLISubcommand):
|
||||
"""The `bench` subcommand for the vLLM CLI."""
|
||||
@@ -64,8 +38,6 @@ class BenchmarkSubcommand(CLISubcommand):
|
||||
)
|
||||
bench_subparsers = bench_parser.add_subparsers(required=True, dest="bench_type")
|
||||
|
||||
_load_benchmark_subcommands()
|
||||
|
||||
for cmd_cls in BenchmarkSubcommandBase.__subclasses__():
|
||||
cmd_subparser = bench_subparsers.add_parser(
|
||||
cmd_cls.name,
|
||||
|
||||
@@ -220,6 +220,12 @@ def run_multi_api_server(args: argparse.Namespace):
|
||||
num_api_servers: int = args.api_server_count
|
||||
assert num_api_servers > 0
|
||||
|
||||
if num_api_servers > 1 and getattr(args, "use_gpu_for_pooling_score", False):
|
||||
# TODO(wentao): remove this once well tested
|
||||
raise ValueError(
|
||||
"--use-gpu-for-pooling-score cannot be used with api_server_count > 1 now"
|
||||
)
|
||||
|
||||
if num_api_servers > 1:
|
||||
setup_multiprocess_prometheus()
|
||||
|
||||
@@ -246,8 +252,12 @@ def run_multi_api_server(args: argparse.Namespace):
|
||||
|
||||
api_server_manager: APIServerProcessManager | None = None
|
||||
|
||||
from vllm.v1.engine.utils import get_engine_zmq_addresses
|
||||
|
||||
addresses = get_engine_zmq_addresses(vllm_config, num_api_servers)
|
||||
|
||||
with launch_core_engines(
|
||||
vllm_config, executor_class, log_stats, num_api_servers
|
||||
vllm_config, executor_class, log_stats, addresses, num_api_servers
|
||||
) as (local_engine_manager, coordinator, addresses):
|
||||
# Construct common args for the APIServerProcessManager up-front.
|
||||
api_server_manager_kwargs = dict(
|
||||
|
||||
@@ -101,11 +101,15 @@ class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer):
|
||||
sampling_params = self._sampling_params_from_proto(
|
||||
request.sampling_params, stream=request.stream
|
||||
)
|
||||
tokenization_kwargs = self._tokenization_kwargs_from_proto(
|
||||
request.sampling_params
|
||||
)
|
||||
|
||||
async for output in self.async_llm.generate(
|
||||
prompt=prompt,
|
||||
sampling_params=sampling_params,
|
||||
request_id=request_id,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
):
|
||||
# Convert vLLM output to protobuf
|
||||
# For streaming, always send chunks
|
||||
@@ -308,9 +312,6 @@ class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer):
|
||||
seed=params.seed if params.HasField("seed") else None,
|
||||
include_stop_str_in_output=params.include_stop_str_in_output,
|
||||
logit_bias=dict(params.logit_bias) if params.logit_bias else None,
|
||||
truncate_prompt_tokens=params.truncate_prompt_tokens
|
||||
if params.HasField("truncate_prompt_tokens")
|
||||
else None,
|
||||
structured_outputs=structured_outputs,
|
||||
# detokenize must be True if stop strings are used
|
||||
detokenize=bool(stop),
|
||||
@@ -319,6 +320,14 @@ class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer):
|
||||
else RequestOutputKind.FINAL_ONLY,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _tokenization_kwargs_from_proto(
|
||||
params: vllm_engine_pb2.SamplingParams,
|
||||
) -> dict[str, int] | None:
|
||||
if params.HasField("truncate_prompt_tokens"):
|
||||
return {"truncate_prompt_tokens": params.truncate_prompt_tokens}
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _chunk_response(output: RequestOutput) -> vllm_engine_pb2.GenerateResponse:
|
||||
"""
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
import warnings
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import cloudpickle
|
||||
@@ -41,8 +41,11 @@ from vllm.distributed.weight_transfer.base import (
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatTemplateConfig,
|
||||
ChatTemplateContentFormatOption,
|
||||
load_chat_template,
|
||||
)
|
||||
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
|
||||
from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreData,
|
||||
ScoreMultiModalParam,
|
||||
@@ -146,6 +149,7 @@ class LLM:
|
||||
a tag name, or a commit id.
|
||||
tokenizer_revision: The specific tokenizer version to use. It can be a
|
||||
branch name, a tag name, or a commit id.
|
||||
chat_template: The chat template to apply.
|
||||
seed: The seed to initialize the random number generator for sampling.
|
||||
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
|
||||
reserve for the model weights, activations, and KV cache. Higher
|
||||
@@ -233,6 +237,7 @@ class LLM:
|
||||
quantization: QuantizationMethods | None = None,
|
||||
revision: str | None = None,
|
||||
tokenizer_revision: str | None = None,
|
||||
chat_template: Path | str | None = None,
|
||||
seed: int = 0,
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
swap_space: float = 4,
|
||||
@@ -385,9 +390,16 @@ class LLM:
|
||||
|
||||
self.model_config = self.llm_engine.model_config
|
||||
self.renderer = self.llm_engine.renderer
|
||||
self.chat_template = load_chat_template(chat_template)
|
||||
self.io_processor = self.llm_engine.io_processor
|
||||
self.input_processor = self.llm_engine.input_processor
|
||||
|
||||
self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template)
|
||||
self.init_pooling_io_processors = init_pooling_io_processors(
|
||||
supported_tasks=supported_tasks,
|
||||
model_config=self.model_config,
|
||||
renderer=self.renderer,
|
||||
chat_template_config=self.chat_template_config,
|
||||
)
|
||||
# Cache for __repr__ to avoid repeated collective_rpc calls
|
||||
self._cached_repr: str | None = None
|
||||
|
||||
@@ -1030,7 +1042,6 @@ class LLM:
|
||||
prompts: PromptType | Sequence[PromptType] | DataPrompt,
|
||||
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
|
||||
*,
|
||||
truncate_prompt_tokens: int | None = None,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None = None,
|
||||
pooling_task: PoolingTask | None = None,
|
||||
@@ -1088,21 +1099,7 @@ class LLM:
|
||||
"pooling model."
|
||||
)
|
||||
|
||||
if truncate_prompt_tokens is not None:
|
||||
warnings.warn(
|
||||
"The `truncate_prompt_tokens` parameter in `LLM.encode()` "
|
||||
"is deprecated and will be removed in v0.16. "
|
||||
"Please pass it via `tokenization_kwargs` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
tokenization_kwargs = merge_kwargs(
|
||||
tokenization_kwargs,
|
||||
dict(truncate_prompt_tokens=truncate_prompt_tokens),
|
||||
)
|
||||
|
||||
if use_io_processor := (isinstance(prompts, dict) and "data" in prompts):
|
||||
if isinstance(prompts, dict) and "data" in prompts:
|
||||
if self.io_processor is None:
|
||||
raise ValueError(
|
||||
"No IOProcessor plugin installed. Please refer "
|
||||
@@ -1136,6 +1133,31 @@ class LLM:
|
||||
for p in params_seq:
|
||||
if p.task is None:
|
||||
p.task = "plugin"
|
||||
|
||||
outputs = self._run_completion(
|
||||
prompts=prompts_seq,
|
||||
params=params_seq,
|
||||
output_type=PoolingRequestOutput,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
# get the post-processed model outputs
|
||||
assert self.io_processor is not None
|
||||
processed_outputs = self.io_processor.post_process(outputs)
|
||||
|
||||
return [
|
||||
PoolingRequestOutput[Any](
|
||||
request_id="",
|
||||
outputs=processed_outputs,
|
||||
num_cached_tokens=getattr(
|
||||
processed_outputs, "num_cached_tokens", 0
|
||||
),
|
||||
prompt_token_ids=[],
|
||||
finished=True,
|
||||
)
|
||||
]
|
||||
else:
|
||||
if pooling_params is None:
|
||||
# Use default pooling params.
|
||||
@@ -1153,39 +1175,42 @@ class LLM:
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
outputs = self._run_completion(
|
||||
prompts=prompts_seq,
|
||||
params=params_seq,
|
||||
output_type=PoolingRequestOutput,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
if use_io_processor:
|
||||
# get the post-processed model outputs
|
||||
assert self.io_processor is not None
|
||||
processed_outputs = self.io_processor.post_process(outputs)
|
||||
|
||||
return [
|
||||
PoolingRequestOutput[Any](
|
||||
request_id="",
|
||||
outputs=processed_outputs,
|
||||
num_cached_tokens=getattr(
|
||||
processed_outputs, "num_cached_tokens", 0
|
||||
),
|
||||
prompt_token_ids=[],
|
||||
finished=True,
|
||||
if pooling_task in self.init_pooling_io_processors:
|
||||
io_processor = self.init_pooling_io_processors[pooling_task]
|
||||
processor_inputs = io_processor.pre_process_offline(
|
||||
prompts_seq, tokenization_kwargs
|
||||
)
|
||||
]
|
||||
seq_lora_requests = self._lora_request_to_seq(
|
||||
lora_request, len(prompts_seq)
|
||||
)
|
||||
seq_priority = self._priority_to_seq(None, len(prompts))
|
||||
|
||||
self._render_and_add_requests(
|
||||
prompts=processor_inputs,
|
||||
params=params_seq,
|
||||
lora_requests=seq_lora_requests,
|
||||
priorities=seq_priority,
|
||||
)
|
||||
|
||||
outputs = self._run_engine(
|
||||
use_tqdm=use_tqdm, output_type=PoolingRequestOutput
|
||||
)
|
||||
outputs = io_processor.post_process(outputs)
|
||||
else:
|
||||
outputs = self._run_completion(
|
||||
prompts=prompts_seq,
|
||||
params=params_seq,
|
||||
output_type=PoolingRequestOutput,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
return outputs
|
||||
|
||||
def embed(
|
||||
self,
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
*,
|
||||
truncate_prompt_tokens: int | None = None,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None = None,
|
||||
@@ -1221,12 +1246,6 @@ class LLM:
|
||||
"Try converting the model using `--convert embed`."
|
||||
)
|
||||
|
||||
if truncate_prompt_tokens is not None:
|
||||
tokenization_kwargs = merge_kwargs(
|
||||
tokenization_kwargs,
|
||||
dict(truncate_prompt_tokens=truncate_prompt_tokens),
|
||||
)
|
||||
|
||||
items = self.encode(
|
||||
prompts,
|
||||
use_tqdm=use_tqdm,
|
||||
@@ -1294,7 +1313,6 @@ class LLM:
|
||||
/,
|
||||
*,
|
||||
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
|
||||
truncate_prompt_tokens: int | None = None,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
@@ -1319,13 +1337,11 @@ class LLM:
|
||||
A list of `PoolingRequestOutput` objects containing the
|
||||
pooled hidden states in the same order as the input prompts.
|
||||
"""
|
||||
|
||||
return self.encode(
|
||||
prompts,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
pooling_params=pooling_params,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
pooling_task="token_classify",
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
@@ -1771,23 +1787,15 @@ class LLM:
|
||||
seq_prompts = prompt_to_seq(prompts)
|
||||
seq_params = self._params_to_seq(params, len(seq_prompts))
|
||||
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts))
|
||||
seq_tok_kwargs = [
|
||||
merge_kwargs(
|
||||
tokenization_kwargs,
|
||||
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
|
||||
)
|
||||
for param in seq_params
|
||||
]
|
||||
seq_priority = self._priority_to_seq(priority, len(prompts))
|
||||
|
||||
return self._render_and_add_requests(
|
||||
prompts=(
|
||||
self._preprocess_cmpl_one(prompt, tok_kwargs)
|
||||
for prompt, tok_kwargs in zip(
|
||||
maybe_tqdm(
|
||||
seq_prompts, use_tqdm=use_tqdm, desc="Rendering prompts"
|
||||
),
|
||||
seq_tok_kwargs,
|
||||
self._preprocess_cmpl_one(prompt, tokenization_kwargs)
|
||||
for prompt in maybe_tqdm(
|
||||
seq_prompts,
|
||||
use_tqdm=use_tqdm,
|
||||
desc="Rendering prompts",
|
||||
)
|
||||
),
|
||||
params=seq_params,
|
||||
@@ -1841,13 +1849,6 @@ class LLM:
|
||||
seq_convs = conversation_to_seq(messages)
|
||||
seq_params = self._params_to_seq(params, len(seq_convs))
|
||||
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_convs))
|
||||
seq_tok_kwargs = [
|
||||
merge_kwargs(
|
||||
tokenization_kwargs,
|
||||
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
|
||||
)
|
||||
for param in seq_params
|
||||
]
|
||||
|
||||
return self._render_and_run_requests(
|
||||
prompts=(
|
||||
@@ -1859,16 +1860,13 @@ class LLM:
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=continue_final_message,
|
||||
tools=tools,
|
||||
tokenization_kwargs=tok_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
for conversation, tok_kwargs in zip(
|
||||
maybe_tqdm(
|
||||
seq_convs,
|
||||
use_tqdm=use_tqdm,
|
||||
desc="Rendering conversations",
|
||||
),
|
||||
seq_tok_kwargs,
|
||||
for conversation in maybe_tqdm(
|
||||
seq_convs,
|
||||
use_tqdm=use_tqdm,
|
||||
desc="Rendering conversations",
|
||||
)
|
||||
),
|
||||
params=seq_params,
|
||||
|
||||
@@ -18,6 +18,20 @@ class RequestLogger:
|
||||
def __init__(self, *, max_log_len: int | None) -> None:
|
||||
self.max_log_len = max_log_len
|
||||
|
||||
if not logger.isEnabledFor(logging.INFO):
|
||||
logger.warning_once(
|
||||
"`--enable-log-requests` is set but "
|
||||
"the minimum log level is higher than INFO. "
|
||||
"No request information will be logged."
|
||||
)
|
||||
elif not logger.isEnabledFor(logging.DEBUG):
|
||||
logger.info_once(
|
||||
"`--enable-log-requests` is set but "
|
||||
"the minimum log level is higher than DEBUG. "
|
||||
"Only limited information will be logged to minimize overhead. "
|
||||
"To view more details, set `VLLM_LOGGING_LEVEL=DEBUG`."
|
||||
)
|
||||
|
||||
def log_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
|
||||
@@ -38,6 +38,7 @@ from vllm.logprobs import Logprob
|
||||
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
|
||||
from vllm.sampling_params import (
|
||||
BeamSearchParams,
|
||||
RepetitionDetectionParams,
|
||||
RequestOutputKind,
|
||||
SamplingParams,
|
||||
StructuredOutputsParams,
|
||||
@@ -336,6 +337,16 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
),
|
||||
)
|
||||
|
||||
repetition_detection: RepetitionDetectionParams | None = Field(
|
||||
default=None,
|
||||
description="Parameters for detecting repetitive N-gram patterns "
|
||||
"in output tokens. If such repetition is detected, generation will "
|
||||
"be ended early. LLMs can sometimes generate repetitive, unhelpful "
|
||||
"token patterns, stopping only when they hit the maximum output length "
|
||||
"(e.g. 'abcdabcdabcd...' or '\emoji \emoji \emoji ...'). This feature "
|
||||
"can detect such behavior and terminate early, saving time and tokens.",
|
||||
)
|
||||
|
||||
# --8<-- [end:chat-completion-extra-params]
|
||||
|
||||
def build_chat_params(
|
||||
@@ -490,7 +501,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
skip_special_tokens=self.skip_special_tokens,
|
||||
spaces_between_special_tokens=self.spaces_between_special_tokens,
|
||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
output_kind=RequestOutputKind.DELTA
|
||||
if self.stream
|
||||
else RequestOutputKind.FINAL_ONLY,
|
||||
@@ -500,8 +510,37 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
allowed_token_ids=self.allowed_token_ids,
|
||||
extra_args=extra_args or None,
|
||||
skip_clone=True, # Created fresh per request, safe to skip clone
|
||||
repetition_detection=self.repetition_detection,
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_response_format(cls, data):
|
||||
response_format = data.get("response_format")
|
||||
if response_format is None:
|
||||
return data
|
||||
|
||||
rf_type = (
|
||||
response_format.get("type")
|
||||
if isinstance(response_format, dict)
|
||||
else getattr(response_format, "type", None)
|
||||
)
|
||||
|
||||
if rf_type == "json_schema":
|
||||
json_schema = (
|
||||
response_format.get("json_schema")
|
||||
if isinstance(response_format, dict)
|
||||
else getattr(response_format, "json_schema", None)
|
||||
)
|
||||
if json_schema is None:
|
||||
raise VLLMValidationError(
|
||||
"When response_format type is 'json_schema', the "
|
||||
"'json_schema' field must be provided.",
|
||||
parameter="response_format",
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_stream_options(cls, data):
|
||||
|
||||
@@ -1249,13 +1249,23 @@ class OpenAIServingChat(OpenAIServing):
|
||||
)
|
||||
|
||||
# get the expected call based on partial JSON
|
||||
# parsing which "autocompletes" the JSON
|
||||
expected_call = json.dumps(
|
||||
tool_parser.prev_tool_call_arr[index].get(
|
||||
"arguments", {}
|
||||
),
|
||||
ensure_ascii=False,
|
||||
# parsing which "autocompletes" the JSON.
|
||||
# Tool parsers (e.g. Qwen3Coder) store
|
||||
# arguments as a JSON string in
|
||||
# prev_tool_call_arr. Calling json.dumps()
|
||||
# on an already-serialized string would
|
||||
# double-serialize it (e.g. '{"k":1}' becomes
|
||||
# '"{\\"k\\":1}"'), which then causes the
|
||||
# replace() below to fail and append the
|
||||
# entire double-serialized string as a
|
||||
# spurious final delta.
|
||||
args = tool_parser.prev_tool_call_arr[index].get(
|
||||
"arguments", {}
|
||||
)
|
||||
if isinstance(args, str):
|
||||
expected_call = args
|
||||
else:
|
||||
expected_call = json.dumps(args, ensure_ascii=False)
|
||||
|
||||
# get what we've streamed so far for arguments
|
||||
# for the current tool
|
||||
|
||||
@@ -143,7 +143,8 @@ class BaseFrontendArgs:
|
||||
templates and other tokenizer configuration."""
|
||||
enable_log_outputs: bool = False
|
||||
"""If set to True, log model outputs (generations).
|
||||
Requires --enable-log-requests."""
|
||||
Requires `--enable-log-requests`. As with `--enable-log-requests`,
|
||||
information is only logged at INFO level at maximum."""
|
||||
enable_log_deltas: bool = True
|
||||
"""If set to False, output deltas will not be logged. Relevant only if
|
||||
--enable-log-outputs is set.
|
||||
@@ -277,6 +278,10 @@ class FrontendArgs(BaseFrontendArgs):
|
||||
Enable offline FastAPI documentation for air-gapped environments.
|
||||
Uses vendored static assets bundled with vLLM.
|
||||
"""
|
||||
use_gpu_for_pooling_score: bool = False
|
||||
"""If set, run pooling score MaxSim on GPU in the API server process.
|
||||
Can significantly improve late-interaction scoring performance.
|
||||
https://github.com/vllm-project/vllm/pull/35330"""
|
||||
|
||||
@classmethod
|
||||
def _customize_cli_kwargs(
|
||||
|
||||
@@ -26,6 +26,7 @@ from vllm.logprobs import Logprob
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.sampling_params import (
|
||||
BeamSearchParams,
|
||||
RepetitionDetectionParams,
|
||||
RequestOutputKind,
|
||||
SamplingParams,
|
||||
StructuredOutputsParams,
|
||||
@@ -166,6 +167,16 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
),
|
||||
)
|
||||
|
||||
repetition_detection: RepetitionDetectionParams | None = Field(
|
||||
default=None,
|
||||
description="Parameters for detecting repetitive N-gram patterns "
|
||||
"in output tokens. If such repetition is detected, generation will "
|
||||
"be ended early. LLMs can sometimes generate repetitive, unhelpful "
|
||||
"token patterns, stopping only when they hit the maximum output length "
|
||||
"(e.g. 'abcdabcdabcd...' or '\emoji \emoji \emoji ...'). This feature "
|
||||
"can detect such behavior and terminate early, saving time and tokens.",
|
||||
)
|
||||
|
||||
# --8<-- [end:completion-extra-params]
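The field above is described only conceptually; a toy illustration of N-gram repetition detection (not vLLM's implementation, all names made up):

def has_repeating_ngram(tokens: list, n: int = 4, min_repeats: int = 3) -> bool:
    # True if the last n-token window repeats min_repeats times back to back.
    if len(tokens) < n * min_repeats:
        return False
    tail = tokens[-n * min_repeats:]
    first = tail[:n]
    return all(tail[i * n:(i + 1) * n] == first for i in range(min_repeats))

print(has_repeating_ngram(list("abcdabcdabcd")))  # True, so generation could stop early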
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
@@ -259,7 +270,7 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
structured_outputs_kwargs["json"] = json_schema.json_schema
|
||||
elif response_format.type == "structural_tag":
|
||||
structural_tag = response_format
|
||||
assert structural_tag is not None and isinstance(
|
||||
assert isinstance(
|
||||
structural_tag,
|
||||
(
|
||||
LegacyStructuralTagResponseFormat,
|
||||
@@ -302,7 +313,6 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
skip_special_tokens=self.skip_special_tokens,
|
||||
spaces_between_special_tokens=self.spaces_between_special_tokens,
|
||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
output_kind=RequestOutputKind.DELTA
|
||||
if self.stream
|
||||
else RequestOutputKind.FINAL_ONLY,
|
||||
@@ -311,8 +321,37 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
allowed_token_ids=self.allowed_token_ids,
|
||||
extra_args=extra_args or None,
|
||||
skip_clone=True, # Created fresh per request, safe to skip clone
|
||||
repetition_detection=self.repetition_detection,
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_response_format(cls, data):
|
||||
response_format = data.get("response_format")
|
||||
if response_format is None:
|
||||
return data
|
||||
|
||||
rf_type = (
|
||||
response_format.get("type")
|
||||
if isinstance(response_format, dict)
|
||||
else getattr(response_format, "type", None)
|
||||
)
|
||||
|
||||
if rf_type == "json_schema":
|
||||
json_schema = (
|
||||
response_format.get("json_schema")
|
||||
if isinstance(response_format, dict)
|
||||
else getattr(response_format, "json_schema", None)
|
||||
)
|
||||
if json_schema is None:
|
||||
raise VLLMValidationError(
|
||||
"When response_format type is 'json_schema', the "
|
||||
"'json_schema' field must be provided.",
|
||||
parameter="response_format",
|
||||
)
|
||||
|
||||
return data
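A quick sanity check of the validator above, using a hypothetical request payload:

payload = {
    "model": "my-model",
    "prompt": "hi",
    "response_format": {"type": "json_schema"},  # note: no "json_schema" body
}
# Validating this as a CompletionRequest would raise VLLMValidationError with
# parameter="response_format"; adding e.g.
# "json_schema": {"name": "answer", "schema": {"type": "object"}} makes it pass.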
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_structured_outputs_count(cls, data):
|
||||
|
||||
@@ -62,11 +62,6 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
|
||||
TranscriptionResponse,
|
||||
TranslationRequest,
|
||||
)
|
||||
from vllm.entrypoints.pooling.classify.protocol import (
|
||||
ClassificationChatRequest,
|
||||
ClassificationCompletionRequest,
|
||||
ClassificationResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
EmbeddingBytesResponse,
|
||||
EmbeddingChatRequest,
|
||||
@@ -161,7 +156,6 @@ CompletionLikeRequest: TypeAlias = (
|
||||
| TokenizeCompletionRequest
|
||||
| DetokenizeRequest
|
||||
| EmbeddingCompletionRequest
|
||||
| ClassificationCompletionRequest
|
||||
| RerankRequest
|
||||
| ScoreRequest
|
||||
| PoolingCompletionRequest
|
||||
@@ -171,7 +165,6 @@ ChatLikeRequest: TypeAlias = (
|
||||
ChatCompletionRequest
|
||||
| TokenizeChatRequest
|
||||
| EmbeddingChatRequest
|
||||
| ClassificationChatRequest
|
||||
| PoolingChatRequest
|
||||
)
|
||||
|
||||
@@ -194,12 +187,10 @@ AnyResponse: TypeAlias = (
|
||||
| TranscriptionResponse
|
||||
| TokenizeResponse
|
||||
| PoolingResponse
|
||||
| ClassificationResponse
|
||||
| ScoreResponse
|
||||
| GenerateResponse
|
||||
)
|
||||
|
||||
|
||||
RequestT = TypeVar("RequestT", bound=AnyRequest)
|
||||
|
||||
|
||||
@@ -223,8 +214,8 @@ class ServeContext(Generic[RequestT]):
|
||||
|
||||
class OpenAIServing:
|
||||
request_id_prefix: ClassVar[str] = """
|
||||
A short string prepended to every request’s ID (e.g. "embd", "classify")
|
||||
so you can easily tell “this ID came from Embedding vs Classification.”
|
||||
A short string prepended to every request’s ID (e.g. "embd")
|
||||
so you can easily tell “this ID came from Embedding.”
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -456,7 +447,7 @@ class OpenAIServing:
|
||||
) -> ErrorResponse | None:
|
||||
"""
|
||||
Default preprocessing hook. Subclasses may override
|
||||
to prepare `ctx` (classification, embedding, etc.).
|
||||
to prepare `ctx` (embedding, etc.).
|
||||
"""
|
||||
return None
|
||||
|
||||
@@ -817,7 +808,7 @@ class OpenAIServing:
|
||||
token_num = len(input_ids)
|
||||
max_model_len = self.model_config.max_model_len
|
||||
|
||||
# Note: EmbeddingRequest, ClassificationRequest,
|
||||
# Note: EmbeddingRequest,
|
||||
# and ScoreRequest doesn't have max_tokens
|
||||
if isinstance(
|
||||
request,
|
||||
@@ -828,8 +819,6 @@ class OpenAIServing:
|
||||
ScoreTextRequest,
|
||||
ScoreQueriesDocumentsRequest,
|
||||
RerankRequest,
|
||||
ClassificationCompletionRequest,
|
||||
ClassificationChatRequest,
|
||||
),
|
||||
):
|
||||
# Note: input length can be up to the entire model context length
|
||||
@@ -839,8 +828,6 @@ class OpenAIServing:
|
||||
ScoreDataRequest: "score",
|
||||
ScoreTextRequest: "score",
|
||||
ScoreQueriesDocumentsRequest: "score",
|
||||
ClassificationCompletionRequest: "classification",
|
||||
ClassificationChatRequest: "classification",
|
||||
}
|
||||
operation = operations.get(type(request), "embedding generation")
|
||||
raise VLLMValidationError(
|
||||
|
||||
@@ -328,8 +328,9 @@ class ResponsesRequest(OpenAIBaseModel):
|
||||
# Also check text.format for OpenAI-style json_schema
|
||||
if self.text is not None and self.text.format is not None:
|
||||
if structured_outputs is not None:
|
||||
raise ValueError(
|
||||
"Cannot specify both structured_outputs and text.format"
|
||||
raise VLLMValidationError(
|
||||
"Cannot specify both structured_outputs and text.format",
|
||||
parameter="structured_outputs",
|
||||
)
|
||||
response_format = self.text.format
|
||||
if (
|
||||
@@ -378,14 +379,19 @@ class ResponsesRequest(OpenAIBaseModel):
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_background(cls, data):
|
||||
if not data.get("background"):
|
||||
return data
|
||||
if not data.get("store", True):
|
||||
raise ValueError("background can only be used when `store` is true")
|
||||
raise VLLMValidationError(
|
||||
"background can only be used when `store` is true",
|
||||
parameter="background",
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_prompt(cls, data):
|
||||
if data.get("prompt") is not None:
|
||||
raise VLLMValidationError(
|
||||
@@ -394,16 +400,19 @@ class ResponsesRequest(OpenAIBaseModel):
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_cache_salt_support(cls, data):
|
||||
if data.get("cache_salt") is not None and (
|
||||
not isinstance(data["cache_salt"], str) or not data["cache_salt"]
|
||||
):
|
||||
raise ValueError(
|
||||
"Parameter 'cache_salt' must be a non-empty string if provided."
|
||||
raise VLLMValidationError(
|
||||
"Parameter 'cache_salt' must be a non-empty string if provided.",
|
||||
parameter="cache_salt",
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def function_call_parsing(cls, data):
|
||||
"""Parse function_call dictionaries into ResponseFunctionToolCall objects.
|
||||
This ensures Pydantic can properly resolve union types in the input field.
|
||||
|
||||
@@ -85,6 +85,8 @@ from vllm.entrypoints.openai.responses.protocol import (
|
||||
ResponseCreatedEvent,
|
||||
ResponseInProgressEvent,
|
||||
ResponseInputOutputMessage,
|
||||
ResponseReasoningPartAddedEvent,
|
||||
ResponseReasoningPartDoneEvent,
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
ResponseUsage,
|
||||
@@ -1339,6 +1341,19 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
),
|
||||
)
|
||||
)
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseReasoningPartAddedEvent(
|
||||
type="response.reasoning_part.added",
|
||||
sequence_number=-1,
|
||||
output_index=current_output_index,
|
||||
item_id=current_item_id,
|
||||
content_index=current_content_index,
|
||||
part=ResponseReasoningTextContent(
|
||||
text="",
|
||||
type="reasoning_text",
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseOutputItemAddedEvent(
|
||||
@@ -1354,22 +1369,21 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
),
|
||||
)
|
||||
)
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseContentPartAddedEvent(
|
||||
type="response.content_part.added",
|
||||
sequence_number=-1,
|
||||
output_index=current_output_index,
|
||||
item_id=current_item_id,
|
||||
content_index=current_content_index,
|
||||
part=ResponseOutputText(
|
||||
type="output_text",
|
||||
text="",
|
||||
annotations=[],
|
||||
logprobs=[],
|
||||
),
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseContentPartAddedEvent(
|
||||
type="response.content_part.added",
|
||||
sequence_number=-1,
|
||||
output_index=current_output_index,
|
||||
item_id=current_item_id,
|
||||
content_index=current_content_index,
|
||||
part=ResponseOutputText(
|
||||
type="output_text",
|
||||
text="",
|
||||
annotations=[],
|
||||
logprobs=[],
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
current_content_index += 1
|
||||
first_delta_sent = True
|
||||
# todo(kebe7jun) tool call support
|
||||
|
||||
@@ -1397,6 +1411,19 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
text=reason_content,
|
||||
)
|
||||
)
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseReasoningPartDoneEvent(
|
||||
type="response.reasoning_part.done",
|
||||
sequence_number=-1,
|
||||
item_id=current_item_id,
|
||||
output_index=current_output_index,
|
||||
content_index=current_content_index,
|
||||
part=ResponseReasoningTextContent(
|
||||
text=reason_content,
|
||||
type="reasoning_text",
|
||||
),
|
||||
)
|
||||
)
|
||||
current_content_index = 0
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
type="reasoning",
|
||||
@@ -1418,6 +1445,8 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
item=reasoning_item,
|
||||
)
|
||||
)
|
||||
current_output_index += 1
|
||||
current_item_id = str(uuid.uuid4())
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseOutputItemAddedEvent(
|
||||
type="response.output_item.added",
|
||||
@@ -1432,8 +1461,6 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
),
|
||||
)
|
||||
)
|
||||
current_output_index += 1
|
||||
current_item_id = str(uuid.uuid4())
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseContentPartAddedEvent(
|
||||
type="response.content_part.added",
|
||||
@@ -1449,7 +1476,6 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
),
|
||||
)
|
||||
)
|
||||
current_content_index += 1
|
||||
# reset previous delta messages
|
||||
previous_delta_messages = []
|
||||
|
||||
@@ -1485,7 +1511,6 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
),
|
||||
)
|
||||
)
|
||||
current_content_index += 1
|
||||
|
||||
previous_delta_messages.append(delta_message)
|
||||
if previous_delta_messages:
|
||||
@@ -1505,7 +1530,19 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
text=reason_content,
|
||||
)
|
||||
)
|
||||
current_content_index += 1
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseReasoningPartDoneEvent(
|
||||
type="response.reasoning_part.done",
|
||||
sequence_number=-1,
|
||||
item_id=current_item_id,
|
||||
output_index=current_output_index,
|
||||
content_index=current_content_index,
|
||||
part=ResponseReasoningTextContent(
|
||||
text=reason_content,
|
||||
type="reasoning_text",
|
||||
),
|
||||
)
|
||||
)
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
type="reasoning",
|
||||
content=[
|
||||
@@ -1543,7 +1580,6 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
item_id=current_item_id,
|
||||
)
|
||||
)
|
||||
current_content_index += 1
|
||||
part = ResponseOutputText(
|
||||
text=final_content,
|
||||
type="output_text",
|
||||
@@ -1559,7 +1595,6 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
part=part,
|
||||
)
|
||||
)
|
||||
current_content_index += 1
|
||||
item = ResponseOutputMessage(
|
||||
type="message",
|
||||
role="assistant",
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Final, Literal, TypeAlias, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
from soundfile import LibsndfileError
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
import vllm.envs as envs
|
||||
@@ -57,6 +58,14 @@ try:
|
||||
except ImportError:
|
||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||
|
||||
# Public libsndfile error codes exposed via `soundfile.LibsndfileError.code`
# (soundfile is librosa's main audio backend). Used to decide whether an audio
# loading failure is a client error (invalid audio file) rather than a server error.
|
||||
# 1 = unrecognised format (file is not a supported audio container)
|
||||
# 3 = malformed file (corrupt or structurally invalid audio)
|
||||
# 4 = unsupported encoding (codec not supported by this libsndfile build)
|
||||
_BAD_SF_CODES = {1, 3, 4}
|
||||
|
||||
SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
|
||||
SpeechToTextResponseVerbose: TypeAlias = (
|
||||
TranscriptionResponseVerbose | TranslationResponseVerbose
|
||||
@@ -315,9 +324,15 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
)
|
||||
|
||||
with io.BytesIO(audio_data) as bytes_:
|
||||
# NOTE resample to model SR here for efficiency. This is also a
|
||||
# pre-requisite for chunking, as it assumes Whisper SR.
|
||||
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
|
||||
try:
|
||||
# NOTE resample to model SR here for efficiency. This is also a
|
||||
# pre-requisite for chunking, as it assumes Whisper SR.
|
||||
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
|
||||
except LibsndfileError as exc:
|
||||
# Distinguish client errors (invalid audio) from server errors
|
||||
if exc.code in _BAD_SF_CODES:
|
||||
raise ValueError("Invalid or unsupported audio file.") from exc
|
||||
raise
|
||||
|
||||
duration = librosa.get_duration(y=y, sr=sr)
|
||||
do_split_audio = (
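A small standalone sketch of the error split introduced by the try/except above (codes mirror _BAD_SF_CODES; the client/server mapping is how the serving layer typically reports ValueError vs. an unhandled exception):

from soundfile import LibsndfileError

BAD_SF_CODES = {1, 3, 4}  # unrecognised format, malformed file, unsupported encoding

def is_client_audio_error(exc: LibsndfileError) -> bool:
    # Bad uploads are re-raised as ValueError and reported as a client error;
    # any other libsndfile failure still propagates as an internal error.
    return exc.code in BAD_SF_CODES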
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"The 'vllm.entrypoints.openai.translations' module has been renamed to "
|
||||
"'vllm.entrypoints.openai.speech_to_text'. Please update your imports. "
|
||||
"This backward-compatible alias will be removed in version 0.17+.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
@@ -1,14 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"'vllm.entrypoints.openai.translations.api_router' has been moved to "
|
||||
"'vllm.entrypoints.openai.speech_to_text.api_router'. Please update your "
|
||||
"imports. This backward-compatible alias will be removed in version 0.17+.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from vllm.entrypoints.openai.speech_to_text.api_router import * # noqa: F401,F403,E402
|
||||
@@ -1,14 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"'vllm.entrypoints.openai.translations.protocol' has been moved to "
|
||||
"'vllm.entrypoints.openai.speech_to_text.protocol'. Please update your "
|
||||
"imports. This backward-compatible alias will be removed in version 0.17+.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from vllm.entrypoints.openai.speech_to_text.protocol import * # noqa: F401,F403,E402
|
||||
@@ -1,14 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"'vllm.entrypoints.openai.translations.serving' has been moved to "
|
||||
"'vllm.entrypoints.openai.speech_to_text.serving'. Please update your "
|
||||
"imports. This backward-compatible alias will be removed in version 0.17+.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from vllm.entrypoints.openai.speech_to_text.serving import * # noqa: F401,F403,E402
|
||||
@@ -1,15 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"'vllm.entrypoints.openai.translations.speech_to_text' has been moved to "
|
||||
"'vllm.entrypoints.openai.speech_to_text.speech_to_text'. Please update "
|
||||
"your imports. This backward-compatible alias will be removed in version "
|
||||
"0.17+.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from vllm.entrypoints.openai.speech_to_text.speech_to_text import * # noqa: F401,F403,E402
|
||||
@@ -115,6 +115,7 @@ def init_pooling_state(
|
||||
request_logger=request_logger,
|
||||
score_template=resolved_chat_template,
|
||||
log_error_stack=args.log_error_stack,
|
||||
use_gpu_for_pooling_score=getattr(args, "use_gpu_for_pooling_score", False),
|
||||
)
|
||||
if any(t in supported_tasks for t in ("embed", "score", "token_embed"))
|
||||
else None
|
||||
|
||||
vllm/entrypoints/pooling/base/io_processor.py (new file, 189 lines)
@@ -0,0 +1,189 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Final
|
||||
|
||||
from vllm import PoolingRequestOutput, PromptType
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatTemplateConfig,
|
||||
ChatTemplateContentFormatOption,
|
||||
ConversationMessage,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.serving import RendererChatRequest, RendererRequest
|
||||
from vllm.inputs import ProcessorInputs, SingletonPrompt
|
||||
from vllm.renderers import BaseRenderer, merge_kwargs
|
||||
from vllm.renderers.inputs import TokPrompt
|
||||
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers import ToolParser
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
|
||||
|
||||
class PoolingIOProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
renderer: BaseRenderer,
|
||||
chat_template_config: ChatTemplateConfig,
|
||||
):
|
||||
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
self.model_config = model_config
|
||||
self.renderer = renderer
|
||||
|
||||
self.chat_template = chat_template_config.chat_template
|
||||
self.chat_template_content_format: Final = (
|
||||
chat_template_config.chat_template_content_format
|
||||
)
|
||||
self.trust_request_chat_template = (
|
||||
chat_template_config.trust_request_chat_template
|
||||
)
|
||||
|
||||
def pre_process_online(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
async def pre_process_online_async(self, *args, **kwargs):
|
||||
return self.pre_process_online(*args, **kwargs)
|
||||
|
||||
def pre_process_offline(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
async def pre_process_offline_async(self, *args, **kwargs):
|
||||
return self.pre_process_offline(*args, **kwargs)
|
||||
|
||||
def post_process(
|
||||
self, outputs: list[PoolingRequestOutput]
|
||||
) -> list[PoolingRequestOutput]:
|
||||
return outputs
|
||||
|
||||
async def post_process_async(
|
||||
self, outputs: list[PoolingRequestOutput]
|
||||
) -> list[PoolingRequestOutput]:
|
||||
return self.post_process(outputs)
|
||||
|
||||
def create_pooling_params(self, request):
|
||||
return request.to_pooling_params()
|
||||
|
||||
def _preprocess_completion_online(
|
||||
self,
|
||||
request: RendererRequest,
|
||||
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
|
||||
prompt_embeds: bytes | list[bytes] | None,
|
||||
) -> list[TokPrompt]:
|
||||
renderer = self.renderer
|
||||
model_config = self.model_config
|
||||
|
||||
prompts = list[SingletonPrompt | bytes]()
|
||||
if prompt_embeds is not None: # embeds take higher priority
|
||||
prompts.extend(prompt_to_seq(prompt_embeds))
|
||||
if prompt_input is not None:
|
||||
prompts.extend(prompt_to_seq(prompt_input))
|
||||
|
||||
parsed_prompts = [
|
||||
(
|
||||
prompt
|
||||
if isinstance(prompt, bytes)
|
||||
else parse_model_prompt(model_config, prompt)
|
||||
)
|
||||
for prompt in prompts
|
||||
]
|
||||
tok_params = request.build_tok_params(model_config)
|
||||
|
||||
return renderer.render_cmpl(
|
||||
parsed_prompts,
|
||||
tok_params,
|
||||
prompt_extras={
|
||||
k: v
|
||||
for k in ("mm_processor_kwargs", "cache_salt")
|
||||
if (v := getattr(request, k, None)) is not None
|
||||
},
|
||||
)
|
||||
|
||||
def _preprocess_chat_online(
|
||||
self,
|
||||
request: RendererChatRequest,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
default_template: str | None,
|
||||
default_template_content_format: ChatTemplateContentFormatOption,
|
||||
default_template_kwargs: dict[str, Any] | None,
|
||||
tool_dicts: list[dict[str, Any]] | None = None,
|
||||
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
|
||||
) -> tuple[list[ConversationMessage], list[TokPrompt]]:
|
||||
renderer = self.renderer
|
||||
|
||||
default_template_kwargs = merge_kwargs(
|
||||
default_template_kwargs,
|
||||
dict(
|
||||
tools=tool_dicts,
|
||||
tokenize=is_mistral_tokenizer(renderer.tokenizer),
|
||||
),
|
||||
)
|
||||
|
||||
tok_params = request.build_tok_params(self.model_config)
|
||||
chat_params = request.build_chat_params(
|
||||
default_template, default_template_content_format
|
||||
).with_defaults(default_template_kwargs)
|
||||
|
||||
(conversation,), (engine_prompt,) = renderer.render_chat(
|
||||
[messages],
|
||||
chat_params,
|
||||
tok_params,
|
||||
prompt_extras={
|
||||
k: v
|
||||
for k in ("mm_processor_kwargs", "cache_salt")
|
||||
if (v := getattr(request, k, None)) is not None
|
||||
},
|
||||
)
|
||||
|
||||
return conversation, [engine_prompt]
|
||||
|
||||
def _preprocess_completion_offline(
|
||||
self,
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> Sequence[ProcessorInputs]:
|
||||
renderer = self.renderer
|
||||
model_config = self.model_config
|
||||
|
||||
prompts = prompt_to_seq(prompts)
|
||||
|
||||
parsed_prompts = [
|
||||
(
|
||||
prompt
|
||||
if isinstance(prompt, bytes)
|
||||
else parse_model_prompt(model_config, prompt)
|
||||
)
|
||||
for prompt in prompts
|
||||
]
|
||||
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
|
||||
**(tokenization_kwargs or {})
|
||||
)
|
||||
|
||||
return renderer.render_cmpl(
|
||||
parsed_prompts,
|
||||
tok_params,
|
||||
)
|
||||
|
||||
def _validate_chat_template(
|
||||
self,
|
||||
request_chat_template: str | None,
|
||||
chat_template_kwargs: dict[str, Any] | None,
|
||||
trust_request_chat_template: bool,
|
||||
):
|
||||
if not trust_request_chat_template and (
|
||||
request_chat_template is not None
|
||||
or (
|
||||
chat_template_kwargs
|
||||
and chat_template_kwargs.get("chat_template") is not None
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
"Chat template is passed with request, but "
|
||||
"--trust-request-chat-template is not set. "
|
||||
"Refused request with untrusted chat template."
|
||||
)
|
||||
return None
|
||||
@@ -190,10 +190,6 @@ class EmbedRequestMixin(EncodingRequestMixin):
|
||||
description="Whether to use activation for the pooler outputs. "
|
||||
"`None` uses the pooler's default, which is `True` in most cases.",
|
||||
)
|
||||
normalize: bool | None = Field(
|
||||
default=None,
|
||||
description="Deprecated; please pass `use_activation` instead",
|
||||
)
|
||||
# --8<-- [end:embed-extra-params]
|
||||
|
||||
|
||||
|
||||
vllm/entrypoints/pooling/base/serving.py (new file, 378 lines)
@@ -0,0 +1,378 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from http import HTTPStatus
|
||||
from typing import ClassVar, Generic, TypeVar
|
||||
|
||||
from fastapi import Request
|
||||
from pydantic import ConfigDict
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from vllm import (
|
||||
PoolingParams,
|
||||
PoolingRequestOutput,
|
||||
PromptType,
|
||||
SamplingParams,
|
||||
envs,
|
||||
)
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatTemplateConfig,
|
||||
ChatTemplateContentFormatOption,
|
||||
)
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.pooling.typing import AnyPoolingRequest, AnyPoolingResponse
|
||||
from vllm.inputs import ProcessorInputs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.renderers.inputs.preprocess import extract_prompt_components
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.tracing import (
|
||||
contains_trace_headers,
|
||||
extract_trace_headers,
|
||||
log_tracing_disabled_warning,
|
||||
)
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
|
||||
from ...utils import create_error_response
|
||||
from .io_processor import PoolingIOProcessor
|
||||
|
||||
PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest)
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class PoolingServeContext(Generic[PoolingRequestT]):
|
||||
request: PoolingRequestT
|
||||
raw_request: Request | None = None
|
||||
model_name: str
|
||||
request_id: str
|
||||
created_time: int = field(default_factory=lambda: int(time.time()))
|
||||
lora_request: LoRARequest | None = None
|
||||
engine_prompts: list[ProcessorInputs] | None = None
|
||||
|
||||
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
|
||||
None
|
||||
)
|
||||
final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class PoolingServing:
|
||||
request_id_prefix: ClassVar[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
chat_template: str | None = None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
|
||||
trust_request_chat_template: bool = False,
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.engine_client = engine_client
|
||||
self.models = models
|
||||
self.model_config = models.model_config
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
self.request_logger = request_logger
|
||||
self.return_tokens_as_token_ids = return_tokens_as_token_ids
|
||||
self.log_error_stack = log_error_stack
|
||||
self.chat_template_config = ChatTemplateConfig(
|
||||
chat_template=chat_template,
|
||||
chat_template_content_format=chat_template_content_format,
|
||||
trust_request_chat_template=trust_request_chat_template,
|
||||
)
|
||||
self.io_processor = self.init_io_processor(
|
||||
model_config=models.model_config,
|
||||
renderer=models.renderer,
|
||||
chat_template_config=self.chat_template_config,
|
||||
)
|
||||
|
||||
def init_io_processor(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
renderer: BaseRenderer,
|
||||
chat_template_config: ChatTemplateConfig,
|
||||
) -> PoolingIOProcessor:
|
||||
raise NotImplementedError
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
request: AnyPoolingRequest,
|
||||
raw_request: Request,
|
||||
) -> JSONResponse:
|
||||
try:
|
||||
model_name = self.models.model_name()
|
||||
request_id = (
|
||||
f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
|
||||
)
|
||||
|
||||
await self._check_model(request)
|
||||
|
||||
ctx = PoolingServeContext(
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
model_name=model_name,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
self._validate_request(ctx)
|
||||
self._maybe_get_adapters(ctx)
|
||||
await self._preprocess(ctx)
|
||||
await self._prepare_generators(ctx)
|
||||
await self._collect_batch(ctx)
|
||||
response = await self._build_response(ctx)
|
||||
return JSONResponse(content=response.model_dump())
|
||||
except Exception as e:
|
||||
error_response = create_error_response(e)
|
||||
return JSONResponse(
|
||||
content=error_response.model_dump(),
|
||||
status_code=error_response.error.code,
|
||||
)
|
||||
|
||||
async def _preprocess(
|
||||
self,
|
||||
ctx: PoolingServeContext,
|
||||
):
|
||||
ctx.engine_prompts = await self.io_processor.pre_process_online_async(
|
||||
ctx.request
|
||||
)
|
||||
|
||||
async def _prepare_generators(
|
||||
self,
|
||||
ctx: PoolingServeContext,
|
||||
):
|
||||
if ctx.engine_prompts is None:
|
||||
raise ValueError("Engine prompts not available")
|
||||
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
|
||||
trace_headers = (
|
||||
None
|
||||
if ctx.raw_request is None
|
||||
else await self._get_trace_headers(ctx.raw_request.headers)
|
||||
)
|
||||
|
||||
pooling_params = self.io_processor.create_pooling_params(ctx.request)
|
||||
|
||||
for i, engine_prompt in enumerate(ctx.engine_prompts):
|
||||
request_id_item = f"{ctx.request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=ctx.lora_request,
|
||||
)
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=ctx.lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=getattr(ctx.request, "priority", 0),
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
|
||||
ctx.result_generator = merge_async_iterators(*generators)
|
||||
|
||||
async def _collect_batch(
|
||||
self,
|
||||
ctx: PoolingServeContext,
|
||||
):
|
||||
if ctx.engine_prompts is None:
|
||||
raise ValueError("Engine prompts not available")
|
||||
|
||||
if ctx.result_generator is None:
|
||||
raise ValueError("Result generator not available")
|
||||
|
||||
num_prompts = len(ctx.engine_prompts)
|
||||
final_res_batch: list[PoolingRequestOutput | None]
|
||||
final_res_batch = [None] * num_prompts
|
||||
|
||||
async for i, res in ctx.result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
if None in final_res_batch:
|
||||
raise ValueError("Failed to generate results for all prompts")
|
||||
|
||||
ctx.final_res_batch = [res for res in final_res_batch if res is not None]
|
||||
|
||||
async def _build_response(
|
||||
self,
|
||||
ctx: PoolingServeContext,
|
||||
) -> AnyPoolingResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def _base_request_id(
|
||||
raw_request: Request | None, default: str | None = None
|
||||
) -> str | None:
|
||||
"""Pulls the request id to use from a header, if provided"""
|
||||
if raw_request is not None and (
|
||||
(req_id := raw_request.headers.get("X-Request-Id")) is not None
|
||||
):
|
||||
return req_id
|
||||
|
||||
return random_uuid() if default is None else default
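Illustrative client usage of the header-based request id (hypothetical host and payload; requires a running server):

import requests

resp = requests.post(
    "http://localhost:8000/classify",
    json={"model": "my-model", "input": "some text"},
    headers={"X-Request-Id": "my-trace-id"},  # picked up by _base_request_id
)
# The handler then uses "classify-my-trace-id" instead of "classify-<random uuid>".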
|
||||
|
||||
def _is_model_supported(self, model_name: str | None) -> bool:
|
||||
if not model_name:
|
||||
return True
|
||||
return self.models.is_base_model(model_name)
|
||||
|
||||
async def _check_model(
|
||||
self,
|
||||
request: AnyPoolingRequest,
|
||||
) -> ErrorResponse | None:
|
||||
if self._is_model_supported(request.model):
|
||||
return None
|
||||
if request.model in self.models.lora_requests:
|
||||
return None
|
||||
if (
|
||||
envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
|
||||
and request.model
|
||||
and (load_result := await self.models.resolve_lora(request.model))
|
||||
):
|
||||
if isinstance(load_result, LoRARequest):
|
||||
return None
|
||||
if (
|
||||
isinstance(load_result, ErrorResponse)
|
||||
and load_result.error.code == HTTPStatus.BAD_REQUEST.value
|
||||
):
|
||||
raise ValueError(load_result.error.message)
|
||||
return None
|
||||
|
||||
def _validate_request(self, ctx: PoolingServeContext) -> None:
|
||||
truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)
|
||||
|
||||
if (
|
||||
truncate_prompt_tokens is not None
|
||||
and truncate_prompt_tokens > self.max_model_len
|
||||
):
|
||||
raise ValueError(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please select a smaller truncation size."
|
||||
)
|
||||
return None
|
||||
|
||||
async def _get_trace_headers(
|
||||
self,
|
||||
headers: Headers,
|
||||
) -> Mapping[str, str] | None:
|
||||
is_tracing_enabled = await self.engine_client.is_tracing_enabled()
|
||||
|
||||
if is_tracing_enabled:
|
||||
return extract_trace_headers(headers)
|
||||
|
||||
if contains_trace_headers(headers):
|
||||
log_tracing_disabled_warning()
|
||||
|
||||
return None
|
||||
|
||||
def _maybe_get_adapters(
|
||||
self,
|
||||
ctx: PoolingServeContext,
|
||||
supports_default_mm_loras: bool = False,
|
||||
):
|
||||
request = ctx.request
|
||||
if request.model in self.models.lora_requests:
|
||||
ctx.lora_request = self.models.lora_requests[request.model]
|
||||
|
||||
# Currently only support default modality specific loras
|
||||
# if we have exactly one lora matched on the request.
|
||||
if supports_default_mm_loras:
|
||||
default_mm_lora = self._get_active_default_mm_loras(request)
|
||||
if default_mm_lora is not None:
|
||||
ctx.lora_request = default_mm_lora
|
||||
|
||||
if self._is_model_supported(request.model):
|
||||
return None
|
||||
|
||||
# if _check_model has been called earlier, this will be unreachable
|
||||
raise ValueError(f"The model `{request.model}` does not exist.")
|
||||
|
||||
def _get_active_default_mm_loras(
|
||||
self, request: AnyPoolingRequest
|
||||
) -> LoRARequest | None:
|
||||
"""Determine if there are any active default multimodal loras."""
|
||||
# TODO: Currently this is only enabled for chat completions
|
||||
# to be better aligned with only being enabled for .generate
|
||||
# when run offline. It would be nice to support additional
|
||||
# task types in the future.
|
||||
message_types = self._get_message_types(request)
|
||||
default_mm_loras = set()
|
||||
|
||||
for lora in self.models.lora_requests.values():
|
||||
# Best effort match for default multimodal lora adapters;
|
||||
# There is probably a better way to do this, but currently
|
||||
# this matches against the set of 'types' in any content lists
|
||||
# up until '_', e.g., to match audio_url -> audio
|
||||
if lora.lora_name in message_types:
|
||||
default_mm_loras.add(lora)
|
||||
|
||||
# Currently only support default modality specific loras if
|
||||
# we have exactly one lora matched on the request.
|
||||
if len(default_mm_loras) == 1:
|
||||
return default_mm_loras.pop()
|
||||
return None
|
||||
|
||||
def _get_message_types(self, request: AnyPoolingRequest) -> set[str]:
|
||||
"""Retrieve the set of types from message content dicts up
|
||||
until `_`; we use this to match potential multimodal data
|
||||
with default per modality loras.
|
||||
"""
|
||||
message_types: set[str] = set()
|
||||
|
||||
if not hasattr(request, "messages"):
|
||||
return message_types
|
||||
|
||||
messages = request.messages
|
||||
if messages is None or isinstance(messages, (str, bytes)):
|
||||
return message_types
|
||||
|
||||
for message in messages:
|
||||
if (
|
||||
isinstance(message, dict)
|
||||
and "content" in message
|
||||
and isinstance(message["content"], list)
|
||||
):
|
||||
for content_dict in message["content"]:
|
||||
if "type" in content_dict:
|
||||
message_types.add(content_dict["type"].split("_")[0])
|
||||
return message_types
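Worked example of the modality matching above (illustrative message content):

messages = [{
    "role": "user",
    "content": [
        {"type": "audio_url", "audio_url": {"url": "https://example.com/a.wav"}},
        {"type": "text", "text": "transcribe this"},
    ],
}]
# _get_message_types(...) yields {"audio", "text"} ("audio_url" -> "audio"), so a
# LoRA adapter named "audio" would be selected as the default modality-specific LoRA.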
|
||||
|
||||
def _log_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
inputs: PromptType | ProcessorInputs,
|
||||
params: SamplingParams | PoolingParams | BeamSearchParams | None,
|
||||
lora_request: LoRARequest | None,
|
||||
) -> None:
|
||||
if self.request_logger is None:
|
||||
return
|
||||
|
||||
components = extract_prompt_components(self.model_config, inputs)
|
||||
|
||||
self.request_logger.log_inputs(
|
||||
request_id,
|
||||
components.text,
|
||||
components.token_ids,
|
||||
components.embeds,
|
||||
params=params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
@@ -3,16 +3,17 @@
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from starlette.responses import JSONResponse
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.utils import validate_json_request
|
||||
from vllm.entrypoints.pooling.classify.protocol import (
|
||||
ClassificationRequest,
|
||||
ClassificationResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.classify.serving import ServingClassification
|
||||
from vllm.entrypoints.utils import load_aware_call, with_cancellation
|
||||
from vllm.entrypoints.utils import (
|
||||
create_error_response,
|
||||
load_aware_call,
|
||||
with_cancellation,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -24,25 +25,17 @@ def classify(request: Request) -> ServingClassification | None:
|
||||
@router.post("/classify", dependencies=[Depends(validate_json_request)])
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_classify(request: ClassificationRequest, raw_request: Request):
|
||||
async def create_classify(
|
||||
request: ClassificationRequest, raw_request: Request
|
||||
) -> JSONResponse:
|
||||
handler = classify(raw_request)
|
||||
if handler is None:
|
||||
base_server = raw_request.app.state.openai_serving_tokenization
|
||||
return base_server.create_error_response(
|
||||
error_response = create_error_response(
|
||||
message="The model does not support Classification API"
|
||||
)
|
||||
|
||||
try:
|
||||
generator = await handler.create_classify(request, raw_request)
|
||||
except Exception as e:
|
||||
generator = handler.create_error_response(e)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
content=error_response.model_dump(),
|
||||
status_code=error_response.error.code,
|
||||
)
|
||||
|
||||
elif isinstance(generator, ClassificationResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
return await handler(request, raw_request)
|
||||
|
||||
vllm/entrypoints/pooling/classify/io_processor.py (new file, 50 lines)
@@ -0,0 +1,50 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from vllm import PromptType
|
||||
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
|
||||
from vllm.entrypoints.pooling.classify.protocol import (
|
||||
ClassificationChatRequest,
|
||||
ClassificationCompletionRequest,
|
||||
)
|
||||
from vllm.inputs import ProcessorInputs
|
||||
from vllm.renderers.inputs import TokPrompt
|
||||
|
||||
|
||||
class ClassifyIOProcessor(PoolingIOProcessor):
|
||||
def pre_process_online(
|
||||
self, request: ClassificationCompletionRequest | ClassificationChatRequest
|
||||
) -> list[TokPrompt] | None:
|
||||
if isinstance(request, ClassificationChatRequest):
|
||||
self._validate_chat_template(
|
||||
request_chat_template=request.chat_template,
|
||||
chat_template_kwargs=request.chat_template_kwargs,
|
||||
trust_request_chat_template=self.trust_request_chat_template,
|
||||
)
|
||||
_, engine_prompts = self._preprocess_chat_online(
|
||||
request,
|
||||
request.messages,
|
||||
default_template=self.chat_template,
|
||||
default_template_content_format=self.chat_template_content_format,
|
||||
default_template_kwargs=None,
|
||||
)
|
||||
elif isinstance(request, ClassificationCompletionRequest):
|
||||
engine_prompts = self._preprocess_completion_online(
|
||||
request,
|
||||
prompt_input=request.input,
|
||||
prompt_embeds=None,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid classification request type")
|
||||
return engine_prompts
|
||||
|
||||
def pre_process_offline(
|
||||
self,
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> Sequence[ProcessorInputs]:
|
||||
return self._preprocess_completion_offline(
|
||||
prompts=prompts, tokenization_kwargs=tokenization_kwargs
|
||||
)
|
||||
@@ -40,7 +40,6 @@ class ClassificationCompletionRequest(
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
task="classify",
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
use_activation=self.use_activation,
|
||||
)
|
||||
|
||||
@@ -63,7 +62,6 @@ class ClassificationChatRequest(
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
task="classify",
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
use_activation=self.use_activation,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,116 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Final, TypeAlias
|
||||
from typing import TypeAlias
|
||||
|
||||
import jinja2
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
|
||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.pooling.classify.protocol import (
|
||||
ClassificationChatRequest,
|
||||
ClassificationCompletionRequest,
|
||||
from vllm import ClassificationOutput
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateConfig
|
||||
from vllm.entrypoints.openai.engine.protocol import UsageInfo
|
||||
from vllm.entrypoints.pooling.base.serving import PoolingServeContext, PoolingServing
|
||||
from vllm.logger import init_logger
|
||||
from vllm.renderers import BaseRenderer
|
||||
|
||||
from .io_processor import ClassifyIOProcessor
|
||||
from .protocol import (
|
||||
ClassificationData,
|
||||
ClassificationRequest,
|
||||
ClassificationResponse,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import ClassificationOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest]
|
||||
ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationRequest]
|
||||
|
||||
|
||||
class ServingClassification(OpenAIServing):
|
||||
class ServingClassification(PoolingServing):
|
||||
request_id_prefix = "classify"
|
||||
|
||||
def __init__(
|
||||
def init_io_processor(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
chat_template: str | None = None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
|
||||
trust_request_chat_template: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack,
|
||||
model_config: ModelConfig,
|
||||
renderer: BaseRenderer,
|
||||
chat_template_config: ChatTemplateConfig,
|
||||
) -> ClassifyIOProcessor:
|
||||
return ClassifyIOProcessor(
|
||||
model_config=model_config,
|
||||
renderer=renderer,
|
||||
chat_template_config=chat_template_config,
|
||||
)
|
||||
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
self.trust_request_chat_template = trust_request_chat_template
|
||||
|
||||
async def _preprocess(
|
||||
async def _build_response(
|
||||
self,
|
||||
ctx: ClassificationServeContext,
|
||||
) -> ErrorResponse | None:
|
||||
"""
|
||||
Process classification inputs: tokenize text, resolve adapters,
|
||||
and prepare model-specific inputs.
|
||||
"""
|
||||
try:
|
||||
ctx.lora_request = self._maybe_get_adapters(ctx.request)
|
||||
) -> ClassificationResponse:
|
||||
final_res_batch_checked = await self.io_processor.post_process_async(
|
||||
ctx.final_res_batch
|
||||
)
|
||||
|
||||
if isinstance(ctx.request, ClassificationChatRequest):
|
||||
error_check_ret = self._validate_chat_template(
|
||||
request_chat_template=ctx.request.chat_template,
|
||||
chat_template_kwargs=ctx.request.chat_template_kwargs,
|
||||
trust_request_chat_template=self.trust_request_chat_template,
|
||||
)
|
||||
if error_check_ret:
|
||||
return error_check_ret
|
||||
|
||||
_, ctx.engine_prompts = await self._preprocess_chat(
|
||||
ctx.request,
|
||||
ctx.request.messages,
|
||||
default_template=self.chat_template,
|
||||
default_template_content_format=self.chat_template_content_format,
|
||||
default_template_kwargs=None,
|
||||
)
|
||||
elif isinstance(ctx.request, ClassificationCompletionRequest):
|
||||
ctx.engine_prompts = await self._preprocess_completion(
|
||||
ctx.request,
|
||||
prompt_input=ctx.request.input,
|
||||
prompt_embeds=None,
|
||||
)
|
||||
else:
|
||||
return self.create_error_response("Invalid classification request type")
|
||||
|
||||
return None
|
||||
|
||||
except (ValueError, TypeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
def _build_response(
|
||||
self,
|
||||
ctx: ClassificationServeContext,
|
||||
) -> ClassificationResponse | ErrorResponse:
|
||||
"""
|
||||
Convert model outputs to a formatted classification response
|
||||
with probabilities and labels.
|
||||
"""
|
||||
id2label = getattr(self.model_config.hf_config, "id2label", {})
|
||||
|
||||
items: list[ClassificationData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
final_res_batch_checked = ctx.final_res_batch
|
||||
|
||||
items: list[ClassificationData] = []
|
||||
for idx, final_res in enumerate(final_res_batch_checked):
|
||||
classify_res = ClassificationOutput.from_base(final_res.outputs)
|
||||
|
||||
@@ -141,20 +82,3 @@ class ServingClassification(OpenAIServing):
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
async def create_classify(
|
||||
self,
|
||||
request: ClassificationRequest,
|
||||
raw_request: Request,
|
||||
) -> ClassificationResponse | ErrorResponse:
|
||||
model_name = self.models.model_name()
|
||||
request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
|
||||
|
||||
ctx = ClassificationServeContext(
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
model_name=model_name,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
return await self.handle(ctx) # type: ignore[return-value]
|
||||
|
||||
@@ -14,12 +14,9 @@ from vllm.entrypoints.pooling.base.protocol import (
|
||||
EmbedRequestMixin,
|
||||
PoolingBasicRequestMixin,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_max_total_output_tokens(
|
||||
model_config: ModelConfig,
|
||||
@@ -60,18 +57,10 @@ class EmbeddingCompletionRequest(
|
||||
)
|
||||
|
||||
def to_pooling_params(self):
|
||||
if self.normalize is not None:
|
||||
logger.warning_once(
|
||||
"`normalize` is deprecated and will be removed in v0.17. "
|
||||
"Please pass `use_activation` instead."
|
||||
)
|
||||
self.use_activation = self.normalize
|
||||
|
||||
return PoolingParams(
|
||||
task="embed",
|
||||
dimensions=self.dimensions,
|
||||
use_activation=self.use_activation,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
)
|
||||
|
||||
|
||||
@@ -97,18 +86,10 @@ class EmbeddingChatRequest(
|
||||
)
|
||||
|
||||
def to_pooling_params(self):
|
||||
if self.normalize is not None:
|
||||
logger.warning_once(
|
||||
"`normalize` is deprecated and will be removed in v0.17. "
|
||||
"Please pass `use_activation` instead."
|
||||
)
|
||||
self.use_activation = self.normalize
|
||||
|
||||
return PoolingParams(
|
||||
task="embed",
|
||||
dimensions=self.dimensions,
|
||||
use_activation=self.use_activation,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
)
|
||||
|
||||
|
||||
|
||||
vllm/entrypoints/pooling/io_processor_factories.py (new file, 31 lines)
@@ -0,0 +1,31 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateConfig
|
||||
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.tasks import SupportedTask
|
||||
|
||||
|
||||
def init_pooling_io_processors(
|
||||
supported_tasks: tuple[SupportedTask, ...],
|
||||
model_config: ModelConfig,
|
||||
renderer: BaseRenderer,
|
||||
chat_template_config: ChatTemplateConfig,
|
||||
) -> dict[str, PoolingIOProcessor]:
|
||||
pooling_io_processors: dict[str, PoolingIOProcessor] = {}
|
||||
|
||||
if "classify" in supported_tasks:
|
||||
from vllm.entrypoints.pooling.classify.io_processor import (
|
||||
ClassifyIOProcessor,
|
||||
)
|
||||
|
||||
pooling_io_processors["classify"] = ClassifyIOProcessor(
|
||||
model_config=model_config,
|
||||
renderer=renderer,
|
||||
chat_template_config=chat_template_config,
|
||||
)
|
||||
|
||||
return pooling_io_processors
|
||||
@@ -16,13 +16,10 @@ from vllm.entrypoints.pooling.base.protocol import (
|
||||
EncodingRequestMixin,
|
||||
PoolingBasicRequestMixin,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class PoolingCompletionRequest(
|
||||
PoolingBasicRequestMixin,
|
||||
@@ -45,16 +42,8 @@ class PoolingCompletionRequest(
|
||||
)
|
||||
|
||||
def to_pooling_params(self):
|
||||
if self.normalize is not None:
|
||||
logger.warning_once(
|
||||
"`normalize` is deprecated and will be removed in v0.17. "
|
||||
"Please pass `use_activation` instead."
|
||||
)
|
||||
self.use_activation = self.normalize
|
||||
|
||||
return PoolingParams(
|
||||
task=self.task,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
use_activation=self.use_activation,
|
||||
dimensions=self.dimensions,
|
||||
)
|
||||
@@ -78,16 +67,8 @@ class PoolingChatRequest(
|
||||
)
|
||||
|
||||
def to_pooling_params(self):
|
||||
if self.normalize is not None:
|
||||
logger.warning_once(
|
||||
"`normalize` is deprecated and will be removed in v0.17. "
|
||||
"Please pass `use_activation` instead."
|
||||
)
|
||||
self.use_activation = self.normalize
|
||||
|
||||
return PoolingParams(
|
||||
task=self.task,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
use_activation=self.use_activation,
|
||||
dimensions=self.dimensions,
|
||||
)
|
||||
|
||||
@@ -37,7 +37,6 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
|
||||
def to_pooling_params(self, task: PoolingTask = "score"):
|
||||
return PoolingParams(
|
||||
task=task,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
use_activation=self.use_activation,
|
||||
)
|
||||
|
||||
@@ -113,7 +112,6 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
|
||||
def to_pooling_params(self, task: PoolingTask = "score"):
|
||||
return PoolingParams(
|
||||
task=task,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
use_activation=self.use_activation,
|
||||
)
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreInputs,
|
||||
_cosine_similarity,
|
||||
compress_token_type_ids,
|
||||
compute_maxsim_score,
|
||||
compute_maxsim_scores,
|
||||
get_score_prompt,
|
||||
parse_score_data_single,
|
||||
validate_score_input,
|
||||
@@ -56,6 +56,7 @@ class ServingScores(OpenAIServing):
|
||||
request_logger: RequestLogger | None,
|
||||
score_template: str | None = None,
|
||||
log_error_stack: bool = False,
|
||||
use_gpu_for_pooling_score: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
@@ -64,6 +65,7 @@ class ServingScores(OpenAIServing):
|
||||
log_error_stack=log_error_stack,
|
||||
)
|
||||
self.score_template = score_template
|
||||
self.use_gpu_for_pooling_score = use_gpu_for_pooling_score
|
||||
|
||||
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
@@ -311,19 +313,18 @@ class ServingScores(OpenAIServing):
|
||||
# Compute MaxSim scores
|
||||
from vllm.outputs import PoolingOutput
|
||||
|
||||
maxsim_scores = compute_maxsim_scores(
|
||||
[emb.outputs.data for emb in emb_data_1],
|
||||
[emb.outputs.data for emb in emb_data_2],
|
||||
use_gpu_for_pooling_score=self.use_gpu_for_pooling_score,
|
||||
)
|
||||
|
||||
scores: list[PoolingRequestOutput] = []
|
||||
padding: list[int] = []
|
||||
if (pad_token_id := tokenizer.pad_token_id) is not None:
|
||||
padding = [pad_token_id]
|
||||
|
||||
for emb_1, emb_2 in zip(emb_data_1, emb_data_2):
|
||||
# emb_1.outputs.data: [query_len, dim]
|
||||
# emb_2.outputs.data: [doc_len, dim]
|
||||
q_emb = emb_1.outputs.data
|
||||
d_emb = emb_2.outputs.data
|
||||
|
||||
maxsim_score = compute_maxsim_score(q_emb, d_emb)
|
||||
|
||||
for emb_1, emb_2, maxsim_score in zip(emb_data_1, emb_data_2, maxsim_scores):
|
||||
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
|
||||
|
||||
scores.append(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Any, TypeAlias, cast
|
||||
|
||||
import torch
|
||||
@@ -25,6 +25,7 @@ from vllm.inputs.data import PromptType, TextPrompt
|
||||
from vllm.model_executor.models.interfaces import supports_score_template
|
||||
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.renderers.hf import safe_apply_chat_template
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
@@ -53,6 +54,91 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens
    return token_scores.amax(dim=-1).sum()


def _should_use_gpu_for_maxsim(use_gpu_for_pooling_score: bool) -> bool:
    return use_gpu_for_pooling_score and not current_platform.is_cpu()


def compute_maxsim_scores(
    q_embs: Sequence[torch.Tensor],
    d_embs: Sequence[torch.Tensor],
    max_batch_size: int = 16,
    max_score_matrix_elements: int = 16_000_000,
    use_gpu_for_pooling_score: bool = False,
) -> list[torch.Tensor]:
    """Compute ColBERT MaxSim scores in padded mini-batches."""
    if len(q_embs) != len(d_embs):
        raise ValueError("q_embs and d_embs must have the same length")

    num_pairs = len(q_embs)
    if num_pairs == 0:
        return []

    for q_emb, d_emb in zip(q_embs, d_embs):
        if q_emb.ndim != 2 or d_emb.ndim != 2:
            raise ValueError("Each embedding tensor must be 2-D")
        if q_emb.shape[1] != d_emb.shape[1]:
            raise ValueError("Query and document embeddings must have same dim")

    compute_device = torch.device(
        current_platform.device_type
        if _should_use_gpu_for_maxsim(use_gpu_for_pooling_score)
        else "cpu"
    )
    scores: list[torch.Tensor] = []
    start = 0
    while start < num_pairs:
        end = min(start + max_batch_size, num_pairs)
        max_q = max(int(x.shape[0]) for x in q_embs[start:end])
        max_d = max(int(x.shape[0]) for x in d_embs[start:end])

        # keep score matrix bounded to avoid oversized allocations.
        while (
            end - start > 1
            and (end - start) * max_q * max_d > max_score_matrix_elements
        ):
            end -= 1
            max_q = max(int(x.shape[0]) for x in q_embs[start:end])
            max_d = max(int(x.shape[0]) for x in d_embs[start:end])

        batch_q = q_embs[start:end]
        batch_d = d_embs[start:end]
        batch_size = end - start
        dim = int(batch_q[0].shape[1])
        dtype = batch_q[0].dtype

        q_batch = torch.zeros(
            (batch_size, max_q, dim), dtype=dtype, device=compute_device
        )
        d_batch = torch.zeros(
            (batch_size, max_d, dim), dtype=dtype, device=compute_device
        )
        q_mask = torch.zeros(
            (batch_size, max_q), dtype=torch.bool, device=compute_device
        )
        d_mask = torch.zeros(
            (batch_size, max_d), dtype=torch.bool, device=compute_device
        )

        # copy to padded tensors
        for i, (q_emb, d_emb) in enumerate(zip(batch_q, batch_d)):
            q_len = int(q_emb.shape[0])
            d_len = int(d_emb.shape[0])
            q_batch[i, :q_len] = q_emb.to(device=compute_device, dtype=dtype)
            d_batch[i, :d_len] = d_emb.to(device=compute_device, dtype=dtype)
            q_mask[i, :q_len] = True
            d_mask[i, :d_len] = True

        token_scores = torch.bmm(q_batch, d_batch.transpose(1, 2))
        token_scores.masked_fill_(~d_mask.unsqueeze(1), float("-inf"))
        max_per_query = token_scores.amax(dim=-1)
        max_per_query.masked_fill_(~q_mask, 0)
        batch_scores = max_per_query.sum(dim=-1).to("cpu")
        scores.extend(batch_scores.unbind(0))
        start = end

    return [cast(torch.Tensor, score) for score in scores]


class ScoreMultiModalParam(TypedDict, total=False):
    """
    A specialized parameter type for scoring multimodal content

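For reference, a minimal standalone sketch of the MaxSim computation these helpers implement; the shapes and random values below are illustrative and not part of the commit:

import torch

# Two hypothetical (query, document) pairs with different token counts.
q_embs = [torch.randn(4, 8), torch.randn(3, 8)]   # each [query_len, dim]
d_embs = [torch.randn(10, 8), torch.randn(7, 8)]  # each [doc_len, dim]

# Per-pair MaxSim, mirroring compute_maxsim_score: for every query token take
# the maximum similarity over document tokens, then sum over query tokens.
reference = [
    (q @ d.transpose(0, 1)).amax(dim=-1).sum() for q, d in zip(q_embs, d_embs)
]

# With this change installed, the padded, batched helper should agree with the
# per-pair loop up to floating-point error:
# from vllm.entrypoints.pooling.score.utils import compute_maxsim_scores
# batched = compute_maxsim_scores(q_embs, d_embs)
# assert all(torch.allclose(a, b) for a, b in zip(batched, reference))
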
51 vllm/entrypoints/pooling/typing.py Normal file
@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TypeAlias

from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
    EmbeddingBytesResponse,
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
    PoolingChatRequest,
    PoolingCompletionRequest,
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
    ScoreRequest,
    ScoreResponse,
)

PoolingCompletionLikeRequest: TypeAlias = (
    EmbeddingCompletionRequest
    | ClassificationCompletionRequest
    | RerankRequest
    | ScoreRequest
    | PoolingCompletionRequest
)

PoolingChatLikeRequest: TypeAlias = (
    EmbeddingChatRequest | ClassificationChatRequest | PoolingChatRequest
)

AnyPoolingRequest: TypeAlias = (
    PoolingCompletionLikeRequest | PoolingChatLikeRequest | IOProcessorRequest
)

AnyPoolingResponse: TypeAlias = (
    ClassificationResponse
    | EmbeddingResponse
    | EmbeddingBytesResponse
    | PoolingResponse
    | ScoreResponse
)

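A hedged sketch of how these new aliases might be consumed; the function below is purely illustrative (only the imported names come from the new module), and isinstance checks on unions require Python 3.10 or newer:

from vllm.entrypoints.pooling.typing import (
    AnyPoolingRequest,
    PoolingChatLikeRequest,
    PoolingCompletionLikeRequest,
)


def describe_request(request: AnyPoolingRequest) -> str:
    # Illustrative dispatch over the union members defined above.
    if isinstance(request, PoolingChatLikeRequest):
        return "chat-like pooling request"
    if isinstance(request, PoolingCompletionLikeRequest):
        return "completion-like pooling request"
    return "IO-processor request"
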
@@ -13,6 +13,7 @@ from fastapi.responses import JSONResponse, Response
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.serve.instrumentator.basic import base
from vllm.entrypoints.serve.instrumentator.health import health
from vllm.tasks import POOLING_TASKS, SupportedTask

@@ -20,7 +21,7 @@ from vllm.tasks import POOLING_TASKS, SupportedTask
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
RequestType = Any
GetHandlerFn = Callable[[Request], OpenAIServing | None]
GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServing | None]
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]

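The widened GetHandlerFn alias now admits pooling handlers as well; a hypothetical getter matching that signature could look like the sketch below (the app.state attribute name is an assumption, not taken from this diff):

from fastapi import Request

from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.pooling.base.serving import PoolingServing


def get_pooling_handler(request: Request) -> OpenAIServing | PoolingServing | None:
    # Handlers are looked up on app.state; "pooling_serving" is illustrative.
    return getattr(request.app.state, "pooling_serving", None)
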
@@ -5,7 +5,10 @@ import asyncio
import dataclasses
import functools
import os
import sys
import traceback
from argparse import Namespace
from http import HTTPStatus
from logging import Logger
from string import Template
from typing import TYPE_CHECKING

@@ -17,17 +20,23 @@ from starlette.background import BackgroundTask, BackgroundTasks

from vllm import envs
from vllm.engine.arg_utils import EngineArgs
from vllm.exceptions import VLLMValidationError
from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser

if TYPE_CHECKING:
    from vllm.entrypoints.openai.engine.protocol import StreamOptions
    from vllm.entrypoints.openai.engine.protocol import (
        ErrorInfo,
        ErrorResponse,
        StreamOptions,
    )
    from vllm.entrypoints.openai.models.protocol import LoRAModulePath
else:
    StreamOptions = object
    ErrorResponse = object
    ErrorInfo = object
    LoRAModulePath = object

    StreamOptions = object

logger = init_logger(__name__)

@@ -291,3 +300,59 @@ def log_version_and_model(lgr: Logger, version: str, model_name: str) -> None:
    message = logo_template.substitute(colors)

    lgr.info(message, version, model_name)


def create_error_response(
    message: str | Exception,
    err_type: str = "BadRequestError",
    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    param: str | None = None,
    log_error_stack: bool = False,
) -> "ErrorResponse":
    exc: Exception | None = None

    from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse

    if isinstance(message, Exception):
        exc = message

        if isinstance(exc, VLLMValidationError):
            err_type = "BadRequestError"
            status_code = HTTPStatus.BAD_REQUEST
            param = exc.parameter
        elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
            # Common validation errors from user input
            err_type = "BadRequestError"
            status_code = HTTPStatus.BAD_REQUEST
            param = None
        elif isinstance(exc, NotImplementedError):
            err_type = "NotImplementedError"
            status_code = HTTPStatus.NOT_IMPLEMENTED
            param = None
        elif exc.__class__.__name__ == "TemplateError":
            # jinja2.TemplateError (avoid importing jinja2)
            err_type = "BadRequestError"
            status_code = HTTPStatus.BAD_REQUEST
            param = None
        else:
            err_type = "InternalServerError"
            status_code = HTTPStatus.INTERNAL_SERVER_ERROR
            param = None

        message = str(exc)

    if log_error_stack:
        exc_type, _, _ = sys.exc_info()
        if exc_type is not None:
            traceback.print_exc()
        else:
            traceback.print_stack()

    return ErrorResponse(
        error=ErrorInfo(
            message=sanitize_message(message),
            type=err_type,
            code=status_code.value,
            param=param,
        )
    )

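For quick reference, a standalone restatement of the exception-to-status mapping encoded above; the VLLMValidationError branch and message sanitization are omitted, and this sketch does not import the helper itself:

from http import HTTPStatus


def status_for(exc: Exception) -> tuple[str, HTTPStatus]:
    # Mirrors the branch order above. Note that NotImplementedError subclasses
    # RuntimeError, so with this ordering it is handled by the first branch.
    if isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
        return "BadRequestError", HTTPStatus.BAD_REQUEST
    if isinstance(exc, NotImplementedError):
        return "NotImplementedError", HTTPStatus.NOT_IMPLEMENTED
    if exc.__class__.__name__ == "TemplateError":
        return "BadRequestError", HTTPStatus.BAD_REQUEST
    return "InternalServerError", HTTPStatus.INTERNAL_SERVER_ERROR


assert status_for(ValueError("bad prompt")) == ("BadRequestError", HTTPStatus.BAD_REQUEST)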